From ee4dc0fbb0255bffc5b95a6c23ffcee3fd0eb1a2 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Mon, 29 Jun 2020 15:39:19 -0400
Subject: [PATCH] Check leases when deciding what to spend on a mutable write

---
 src/_zkapauthorizer/_storage_client.py        | 30 ++++++-
 .../tests/test_storage_protocol.py            | 81 +++++++++++++++++++
 2 files changed, 107 insertions(+), 4 deletions(-)

diff --git a/src/_zkapauthorizer/_storage_client.py b/src/_zkapauthorizer/_storage_client.py
index 684389e..a9cced2 100644
--- a/src/_zkapauthorizer/_storage_client.py
+++ b/src/_zkapauthorizer/_storage_client.py
@@ -30,6 +30,9 @@ from functools import (
 )
 
 import attr
+from attr.validators import (
+    provides,
+)
 
 from zope.interface import (
     implementer,
@@ -39,6 +42,12 @@ from eliot.twisted import (
     inline_callbacks,
 )
 
+from twisted.internet.interfaces import (
+    IReactorTime,
+)
+from twisted.python.reflect import (
+    namedAny,
+)
 from twisted.internet.defer import (
     returnValue,
 )
@@ -263,6 +272,10 @@ class ZKAPAuthorizerStorageClient(object):
     _pass_value = pass_value_attribute()
     _get_rref = attr.ib()
     _get_passes = attr.ib()
+    _clock = attr.ib(
+        validator=provides(IReactorTime),
+        default=attr.Factory(partial(namedAny, "twisted.internet.reactor")),
+    )
 
     def _rref(self):
         rref = self._get_rref()
@@ -464,11 +477,20 @@ class ZKAPAuthorizerStorageClient(object):
             # on the storage server that will give us a really good estimate
             # of the current size of all of the specified shares (keys of
             # tw_vectors).
-            current_sizes = yield rref.callRemote(
-                "share_sizes",
-                storage_index,
-                set(tw_vectors),
+            [stats] = yield rref.callRemote(
+                "stat_shares",
+                [storage_index],
             )
+            # Filter down to only the shares that have an active lease.  If
+            # we're going to write to any other shares we will have to pay to
+            # renew their leases.
+            now = self._clock.seconds()
+            current_sizes = {
+                sharenum: stat.size
+                for (sharenum, stat)
+                in stats.items()
+                if stat.lease_expiration > now
+            }
             # Determine the cost of the new storage for the operation.
             num_passes = get_required_new_passes_for_mutable_write(
                 self._pass_value,
diff --git a/src/_zkapauthorizer/tests/test_storage_protocol.py b/src/_zkapauthorizer/tests/test_storage_protocol.py
index c6ebbca..bd1cfe9 100644
--- a/src/_zkapauthorizer/tests/test_storage_protocol.py
+++ b/src/_zkapauthorizer/tests/test_storage_protocol.py
@@ -992,6 +992,87 @@ class ShareTests(TestCase):
             Equals(leases_before),
         )
 
+    @given(
+        storage_index=storage_indexes(),
+        sharenum=sharenums(),
+        size=sizes(),
+        clock=clocks(),
+        write_enabler=write_enabler_secrets(),
+        renew_secret=lease_renew_secrets(),
+        cancel_secret=lease_cancel_secrets(),
+        test_and_write_vectors_for_shares=test_and_write_vectors_for_shares(),
+    )
+    def test_mutable_rewrite_renews_expired_lease(
+            self,
+            storage_index,
+            clock,
+            sharenum,
+            size,
+            write_enabler,
+            renew_secret,
+            cancel_secret,
+            test_and_write_vectors_for_shares,
+    ):
+        """
+        When mutable share data with an expired lease is rewritten using
+        *slot_testv_and_readv_and_writev* a new lease is paid for and granted.
+        """
+        # Hypothesis causes our storage server to be used many times.  Clean
+        # up between iterations.
+        cleanup_storage_server(self.anonymous_storage_server)
+
+        # Make the client and server use our clock.
+        self.server._clock = clock
+        self.client._clock = clock
+
+        secrets = (write_enabler, renew_secret, cancel_secret)
+
+        def write():
+            return self.client.slot_testv_and_readv_and_writev(
+                storage_index,
+                secrets=secrets,
+                tw_vectors={
+                    k: v.for_call()
+                    for (k, v)
+                    in test_and_write_vectors_for_shares.items()
+                },
+                r_vector=[],
+            )
+
+        # anonymous_storage_server uses time.time() to assign leases,
+        # unfortunately.
+        patch = MonkeyPatch("time.time", clock.seconds)
+        try:
+            patch.setUp()
+
+            # Create a share we can toy with.
+            self.assertThat(write(), is_successful_write())
+
+            # Advance time by more than a lease period so the lease is no
+            # longer valid.
+            clock.advance(self.server.LEASE_PERIOD.total_seconds() + 1)
+
+            self.assertThat(write(), is_successful_write())
+        finally:
+            patch.cleanUp()
+
+        # Not only should the write above succeed but the lease should now be
+        # marked as expiring one additional lease period into the future.
+        self.assertThat(
+            self.server.remote_stat_shares([storage_index]),
+            Equals([{
+                num: ShareStat(
+                    size=get_implied_data_length(
+                        test_and_write_vectors_for_shares[num].write_vector,
+                        test_and_write_vectors_for_shares[num].new_length,
+                    ),
+                    lease_expiration=int(clock.seconds() + self.server.LEASE_PERIOD.total_seconds()),
+                )
+                for num
+                in test_and_write_vectors_for_shares
+            }]),
+        )
+
     @given(
         storage_index=storage_indexes(),
         secrets=tuples(
-- 
GitLab