From 465f8c3ceec343cb966bfd5a3bfd8dd3b31f0678 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Wed, 20 May 2020 20:14:04 -0400
Subject: [PATCH] Add pass replacement logic to `call_with_passes`

Along with an in-memory implementation of a pass factory that works with it.
---
 src/_zkapauthorizer/_storage_client.py        | 117 +++++++++-----
 src/_zkapauthorizer/spending.py               | 136 ++++++++++++++++-
 src/_zkapauthorizer/tests/matchers.py         |  20 +++
 src/_zkapauthorizer/tests/storage_common.py   |  54 ++++++-
 .../tests/test_storage_client.py              | 143 +++++++++++++-----
 5 files changed, 386 insertions(+), 84 deletions(-)

diff --git a/src/_zkapauthorizer/_storage_client.py b/src/_zkapauthorizer/_storage_client.py
index 0f7238c..ad2254f 100644
--- a/src/_zkapauthorizer/_storage_client.py
+++ b/src/_zkapauthorizer/_storage_client.py
@@ -36,13 +36,12 @@ from zope.interface import (
 )
 
 from eliot.twisted import (
-    DeferredContext,
+    inline_callbacks,
 )
 
 from twisted.internet.defer import (
     inlineCallbacks,
     returnValue,
-    maybeDeferred,
 )
 from allmydata.interfaces import (
     IStorageServer,
@@ -65,6 +64,7 @@ from .storage_common import (
     get_required_new_passes_for_mutable_write,
 )
 
+
 class IncorrectStorageServerReference(Exception):
     """
     A Foolscap remote object which should reference a ZKAPAuthorizer storage
@@ -84,55 +84,90 @@ class IncorrectStorageServerReference(Exception):
         )
 
 
+def replace_invalid_passes_with_new_passes(passes, more_passes_required):
+    """
+    Replace all rejected passes in the given pass group with new ones.  Mark
+    any rejected passes as rejected.
+
+    :param IPassGroup passes: A group of passes, some of which may have been
+        rejected.
+
+    :param MorePassesRequired more_passes_required: An exception possibly
+        detailing the rejection of some passes from the group.
+
+    :return: ``None`` if no passes in the group were rejected and so there is
+        nothing to replace.  Otherwise, a new ``IPassGroup`` created from
+        ``passes`` but with rejected passes replaced with new ones.
+    """
+    num_failed = len(more_passes_required.signature_check_failed)
+    if num_failed == 0:
+        # If no signature checks failed then the call just didn't supply
+        # enough passes.  The exception tells us how many passes we should
+        # spend so we could try again with that number of passes but for
+        # now we'll just let the exception propagate.  The client should
+        # always figure out the number of passes right on the first try so
+        # this case is somewhat suspicious.  Err on the side of lack of
+        # service instead of burning extra passes.
+        #
+        # We *could* just `raise` here and only be called from an `except`
+        # suite... but let's not be so vulgar.
+        return None
+    SIGNATURE_CHECK_FAILED.log(count=num_failed)
+    rejected_passes, okay_passes = passes.split(more_passes_required.signature_check_failed)
+    rejected_passes.mark_invalid(u"signature check failed")
+    return okay_passes.expand(len(more_passes_required.signature_check_failed))
+
+
+@inline_callbacks
 def call_with_passes(method, num_passes, get_passes):
     """
     Call a method, passing the requested number of passes as the first
     argument, and try again if the call fails with an error related to some of
     the passes being rejected.
 
-    :param method: A callable which accepts a list of encoded passes as its
-        only argument and returns a ``Deferred``.  If the ``Deferred`` fires
-        with ``MorePassesRequired`` then the invalid passes will be discarded
-        and replacement passes will be requested for a new call of ``method``.
-        This will repeat until no passes remain, the method succeeds, or the
-        methods fails in a different way.
+    :param (IPassGroup -> Deferred) method: An operation to call with some passes.
+        If the returned ``Deferred`` fires with ``MorePassesRequired`` then
+        the invalid passes will be discarded and replacement passes will be
+        requested for a new call of ``method``.  This will repeat until no
+        passes remain, the method succeeds, or the methods fails in a
+        different way.
 
     :param int num_passes: The number of passes to pass to the call.
 
-    :param (unicode -> int -> [bytes]) get_passes: A function for getting
+    :param (int -> IPassGroup) get_passes: A function for getting
         passes.
 
-    :return: Whatever ``method`` returns.
+    :return: A ``Deferred`` that fires with whatever the ``Deferred`` returned
+        by ``method`` fires with (apart from ``MorePassesRequired`` failures
+        that trigger a retry).
     """
-    def get_more_passes(reason):
-        reason.trap(MorePassesRequired)
-        num_failed = len(reason.value.signature_check_failed)
-        if num_failed == 0:
-            # If no signature checks failed then the call just didn't supply
-            # enough passes.  The exception tells us how many passes we should
-            # spend so we could try again with that number of passes but for
-            # now we'll just let the exception propagate.  The client should
-            # always figure out the number of passes right on the first try so
-            # this case is somewhat suspicious.  Err on the side of lack of
-            # service instead of burning extra passes.
-            return reason
-        SIGNATURE_CHECK_FAILED.log(count=num_failed)
-        new_passes = get_passes(num_failed)
-        for idx, new_pass in zip(reason.value.signature_check_failed, new_passes):
-            passes[idx] = new_pass
-        return go(passes)
-
-    def go(passes):
-        # Capture the Eliot context for the errback.
-        d = DeferredContext(maybeDeferred(method, passes))
-        d.addErrback(get_more_passes)
-        # Return the underlying Deferred without finishing the action.
-        return d.result
-
-    with CALL_WITH_PASSES(count=num_passes).context():
+    with CALL_WITH_PASSES(count=num_passes):
         passes = get_passes(num_passes)
-        # Finish the Eliot action when this is done.
-        return DeferredContext(go(passes)).addActionFinish()
+        try:
+            # Try and repeat as necessary.
+            while True:
+                try:
+                    result = yield method(passes)
+                except MorePassesRequired as e:
+                    updated_passes = replace_invalid_passes_with_new_passes(
+                        passes,
+                        e,
+                    )
+                    if updated_passes is None:
+                        raise
+                    else:
+                        passes = updated_passes
+                else:
+                    # Commit the spend of the passes when the operation finally succeeds.
+                    passes.mark_spent()
+                    break
+        except:
+            # Something went wrong that we can't address with a retry.
+            passes.reset()
+            raise
+
+        # Give the operation's result to the caller.
+        returnValue(result)
 
 
 def with_rref(f):
@@ -149,16 +184,16 @@ def with_rref(f):
     return g
 
 
-def _get_encoded_passes(passes):
+def _get_encoded_passes(group):
     """
-    :param list[Pass] passes: A group of passes to encode.
+    :param IPassGroup group: A group of passes to encode.
 
     :return list[bytes]: The encoded form of the passes in the given group.
     """
     return list(
         t.pass_text.encode("ascii")
         for t
-        in passes
+        in group.passes
     )
 
 
diff --git a/src/_zkapauthorizer/spending.py b/src/_zkapauthorizer/spending.py
index ac0353c..2e44de9 100644
--- a/src/_zkapauthorizer/spending.py
+++ b/src/_zkapauthorizer/spending.py
@@ -16,12 +16,134 @@
 A module for logic controlling the manner in which ZKAPs are spent.
 """
 
+from zope.interface import (
+    Interface,
+    Attribute,
+    implementer,
+)
+
 import attr
 
 from .eliot import (
     GET_PASSES,
 )
 
+class IPassGroup(Interface):
+    """
+    A group of passed meant to be spent together.
+    """
+    passes = Attribute(":ivar list[Pass] passes: The passes themselves.")
+
+    def split(select_indices):
+        """
+        Create two new ``IPassGroup`` providers.  The first contains all passes in
+        this group at the given indices.  The second contains all the others.
+
+        :param list[int] select_indices: The indices of the passes to include
+            in the first resulting group.
+
+        :return (IPassGroup, IPassGroup): The two new groups.
+        """
+
+    def expand(by_amount):
+        """
+        Create a new ``IPassGroup`` provider which contains all of this group's
+        passes and some more.
+
+        :param int by_amount: The number of additional passes the resulting
+            group should contain.
+
+        :return IPassGroup: The new group.
+        """
+
+    def mark_spent():
+        """
+        The passes have been spent successfully.  Ensure none of them appear in
+        any ``IPassGroup`` provider created in the future.
+
+        :return: ``None``
+        """
+
+    def mark_invalid(reason):
+        """
+        The passes could not be spent.  Ensure none of them appear in any
+        ``IPassGroup`` provider created in the future.
+
+        :param unicode reason: A short description of the reason the passes
+            could not be spent.
+
+        :return: ``None``
+        """
+
+    def reset():
+        """
+        The passes have not been spent.  Return them to for use in a future
+        ``IPassGroup`` provider.
+
+        :return: ``None``
+        """
+
+
+class IPassFactory(Interface):
+    """
+    An object which can create passes.
+    """
+    def get(message, num_passes):
+        """
+        :param unicode message: A request-binding message for the resulting passes.
+
+        :param int num_passes: The number of passes to request.
+
+        :return IPassGroup: A group of passes bound to the given message and
+            of the requested size.
+        """
+
+
+@implementer(IPassGroup)
+@attr.s
+class PassGroup(object):
+    """
+    Track the state of a group of passes intended as payment for an operation.
+
+    :ivar unicode _message: The request binding message for this group of
+        passes.
+
+    :ivar IPassFactory _factory: The factory which created this pass group.
+
+    :ivar list[Pass] passes: The passes of which this group consists.
+    """
+    _message = attr.ib()
+    _factory = attr.ib()
+    passes = attr.ib()
+
+    def split(self, select_indices):
+        selected = []
+        unselected = []
+        for idx, p in enumerate(self.passes):
+            if idx in select_indices:
+                selected.append(p)
+            else:
+                unselected.append(p)
+        return (
+            attr.evolve(self, passes=selected),
+            attr.evolve(self, passes=unselected),
+        )
+
+    def expand(self, by_amount):
+        return attr.evolve(
+            self,
+            passes=self.passes + self._factory.get(self._message, by_amount).passes,
+        )
+
+    def mark_spent(self):
+        self._factory._mark_spent(self.passes)
+
+    def mark_invalid(self, reason):
+        self._factory._mark_invalid(reason, self.passes)
+
+    def reset(self):
+        self._factory._reset(self.passes)
+
 
 @attr.s
 class SpendingController(object):
@@ -39,4 +161,16 @@ class SpendingController(object):
             message=message,
             count=num_passes,
         )
-        return passes
+        return PassGroup(message, self, passes)
+
+    def _mark_spent(self, group):
+        # TODO
+        pass
+
+    def _mark_invalid(self, reason, group):
+        # TODO
+        pass
+
+    def _reset(self, group):
+        # TODO
+        pass
diff --git a/src/_zkapauthorizer/tests/matchers.py b/src/_zkapauthorizer/tests/matchers.py
index 6c7ab45..79b4feb 100644
--- a/src/_zkapauthorizer/tests/matchers.py
+++ b/src/_zkapauthorizer/tests/matchers.py
@@ -154,3 +154,23 @@ def leases_current(relevant_storage_indexes, now, min_lease_remaining):
             ),
         ),
     )
+
+
+def even():
+    """
+    Matches even integers.
+    """
+    return AfterPreprocessing(
+        lambda n: n % 2,
+        Equals(0),
+    )
+
+
+def odd():
+    """
+    Matches odd integers.
+    """
+    return AfterPreprocessing(
+        lambda n: n % 2,
+        Equals(1),
+    )
diff --git a/src/_zkapauthorizer/tests/storage_common.py b/src/_zkapauthorizer/tests/storage_common.py
index c882a49..ca2915f 100644
--- a/src/_zkapauthorizer/tests/storage_common.py
+++ b/src/_zkapauthorizer/tests/storage_common.py
@@ -34,6 +34,10 @@ from itertools import (
 
 import attr
 
+from zope.interface import (
+    implementer,
+)
+
 from twisted.python.filepath import (
     FilePath,
 )
@@ -55,6 +59,11 @@ from ..model import (
     Pass,
 )
 
+from ..spending import (
+    IPassFactory,
+    PassGroup,
+)
+
 # Hard-coded in Tahoe-LAFS
 LEASE_INTERVAL = 60 * 60 * 24 * 31
 
@@ -219,6 +228,7 @@ def pass_factory(get_passes=None):
     return _PassFactory(get_passes=get_passes)
 
 
+@implementer(IPassFactory)
 @attr.s
 class _PassFactory(object):
     """
@@ -227,14 +237,56 @@ class _PassFactory(object):
     :ivar (unicode -> int -> [bytes]) _get_passes: A function for getting
         passes.
 
+    :ivar set[int] in_use: All of the passes given out without a confirmed
+        terminal state.
+
+    :ivar dict[int, unicode] invalid: All of the passes given out and returned
+        using ``IPassGroup.invalid`` mapped to the reason given.
+
+    :ivar set[int] spent: All of the passes given out and returned via
+        ``IPassGroup.mark_spent``.
+
     :ivar set[int] issued: All of the passes ever given out.
 
+    :ivar list[int] returned: A list of passes which were given out but then
+        returned via ``IPassGroup.reset``.
     """
     _get_passes = attr.ib()
+
+    returned = attr.ib(default=attr.Factory(list), init=False)
+    in_use = attr.ib(default=attr.Factory(set), init=False)
+    invalid = attr.ib(default=attr.Factory(dict), init=False)
+    spent = attr.ib(default=attr.Factory(set), init=False)
     issued = attr.ib(default=attr.Factory(set), init=False)
 
     def get(self, message, num_passes):
         passes = []
+        if self.returned:
+            passes.extend(self.returned[:num_passes])
+            del self.returned[:num_passes]
+            num_passes -= len(passes)
         passes.extend(self._get_passes(message, num_passes))
         self.issued.update(passes)
-        return passes
+        self.in_use.update(passes)
+        return PassGroup(message, self, passes)
+
+    def _mark_spent(self, passes):
+        for p in passes:
+            if p not in self.in_use:
+                raise ValueError("Pass {} cannot be spent, it is not in use.".format(p))
+        self.spent.update(passes)
+        self.in_use.difference_update(passes)
+
+    def _mark_invalid(self, reason, passes):
+        for p in passes:
+            if p not in self.in_use:
+                raise ValueError("Pass {} cannot be invalid, it is not in use.".format(p))
+        self.invalid.update(dict.fromkeys(passes, reason))
+        self.in_use.difference_update(passes)
+
+    def _reset(self, passes):
+        for p in passes:
+            if p not in self.in_use:
+                raise ValueError("Pass {} cannot be reset, it is not in use.".format(p))
+        self.returned.extend(passes)
+        self.in_use.difference_update(passes)
diff --git a/src/_zkapauthorizer/tests/test_storage_client.py b/src/_zkapauthorizer/tests/test_storage_client.py
index 77018c2..ee5d1bc 100644
--- a/src/_zkapauthorizer/tests/test_storage_client.py
+++ b/src/_zkapauthorizer/tests/test_storage_client.py
@@ -16,11 +16,8 @@
 Tests for ``_zkapauthorizer._storage_client``.
 """
 
-import attr
-
-from itertools import (
-    count,
-    islice,
+from functools import (
+    partial,
 )
 
 from testtools import (
@@ -31,6 +28,10 @@ from testtools.matchers import (
     Is,
     Equals,
     AfterPreprocessing,
+    MatchesStructure,
+    HasLength,
+    MatchesAll,
+    AllMatch,
 )
 from testtools.twistedsupport import (
     succeeded,
@@ -49,6 +50,11 @@ from twisted.internet.defer import (
     fail,
 )
 
+from .matchers import (
+    even,
+    odd,
+)
+
 from ..api import (
     MorePassesRequired,
 )
@@ -60,30 +66,13 @@ from .._storage_server import (
     _ValidationResult,
 )
 
-def pass_counts():
-    return integers(min_value=1, max_value=2 ** 8)
-
-
-def pass_factory():
-    return _PassFactory()
-
-@attr.s
-class _PassFactory(object):
-    """
-    A stateful pass issuer.
-
-    :ivar list spent: All of the passes ever issued.
-
-    :ivar _fountain: A counter for making each new pass issued unique.
-    """
-    spent = attr.ib(default=attr.Factory(list))
+from .storage_common import (
+    pass_factory,
+)
 
-    _fountain = attr.ib(default=attr.Factory(count))
 
-    def get(self, num_passes):
-        passes = list(islice(self._fountain, num_passes))
-        self.spent.extend(passes)
-        return passes
+def pass_counts():
+    return integers(min_value=1, max_value=2 ** 8)
 
 
 class CallWithPassesTests(TestCase):
@@ -100,9 +89,9 @@ class CallWithPassesTests(TestCase):
         result = object()
         self.assertThat(
             call_with_passes(
-                lambda passes: succeed(result),
+                lambda group: succeed(result),
                 num_passes,
-                pass_factory().get,
+                partial(pass_factory().get, u"message"),
             ),
             succeeded(Is(result)),
         )
@@ -117,9 +106,9 @@ class CallWithPassesTests(TestCase):
         result = Exception()
         self.assertThat(
             call_with_passes(
-                lambda passes: fail(result),
+                lambda group: fail(result),
                 num_passes,
-                pass_factory().get,
+                partial(pass_factory().get, u"message"),
             ),
             failed(
                 AfterPreprocessing(
@@ -130,27 +119,71 @@ class CallWithPassesTests(TestCase):
         )
 
     @given(pass_counts())
-    def test_passes(self, num_passes):
+    def test_passes_issued(self, num_passes):
         """
-        ``call_with_passes`` calls the given method with a list of passes
-        containing ``num_passes`` created by the function passed for
+        ``call_with_passes`` calls the given method with an ``IPassGroup``
+        provider containing ``num_passes`` created by the function passed for
         ``get_passes``.
         """
         passes = pass_factory()
 
         self.assertThat(
             call_with_passes(
-                lambda passes: succeed(passes),
+                lambda group: succeed(group.passes),
                 num_passes,
-                passes.get,
+                partial(passes.get, u"message"),
             ),
             succeeded(
                 Equals(
-                    passes.spent,
+                    sorted(passes.issued),
                 ),
             ),
         )
 
+    @given(pass_counts())
+    def test_passes_spent_on_success(self, num_passes):
+        """
+        ``call_with_passes`` marks the passes it uses as spent if the operation
+        succeeds.
+        """
+        passes = pass_factory()
+
+        self.assertThat(
+            call_with_passes(
+                lambda group: None,
+                num_passes,
+                partial(passes.get, u"message"),
+            ),
+            succeeded(Always()),
+        )
+        self.assertThat(
+            passes.issued,
+            Equals(passes.spent),
+        )
+
+    @given(pass_counts())
+    def test_passes_returned_on_failure(self, num_passes):
+        """
+        ``call_with_passes`` returns the passes it uses if the operation fails.
+        """
+        passes = pass_factory()
+
+        self.assertThat(
+            call_with_passes(
+                lambda group: fail(Exception("Anything")),
+                num_passes,
+                partial(passes.get, u"message"),
+            ),
+            failed(Always()),
+        )
+        self.assertThat(
+            passes,
+            MatchesStructure(
+                issued=Equals(set(passes.returned)),
+                spent=Equals(set()),
+            ),
+        )
+
     @given(pass_counts())
     def test_retry_on_rejected_passes(self, num_passes):
         """
@@ -160,7 +193,8 @@ class CallWithPassesTests(TestCase):
         """
         passes = pass_factory()
 
-        def reject_even_pass_values(passes):
+        def reject_even_pass_values(group):
+            passes = group.passes
             good_passes = list(idx for (idx, p) in enumerate(passes) if p % 2)
             bad_passes = list(idx for (idx, p) in enumerate(passes) if idx not in good_passes)
             if len(good_passes) < num_passes:
@@ -174,10 +208,26 @@ class CallWithPassesTests(TestCase):
             call_with_passes(
                 reject_even_pass_values,
                 num_passes,
-                passes.get,
+                partial(passes.get, u"message"),
             ),
             succeeded(Always()),
         )
+        self.assertThat(
+            passes,
+            MatchesStructure(
+                returned=HasLength(0),
+                in_use=HasLength(0),
+                invalid=MatchesAll(
+                    HasLength(num_passes),
+                    AllMatch(even()),
+                ),
+                spent=MatchesAll(
+                    HasLength(num_passes),
+                    AllMatch(odd()),
+                ),
+                issued=Equals(passes.spent | set(passes.invalid.keys())),
+            ),
+        )
 
     @given(pass_counts())
     def test_pass_through_too_few_passes(self, num_passes):
@@ -188,7 +238,8 @@ class CallWithPassesTests(TestCase):
         """
         passes = pass_factory()
 
-        def reject_passes(passes):
+        def reject_passes(group):
+            passes = group.passes
             _ValidationResult(
                 valid=range(len(passes)),
                 signature_check_failed=[],
@@ -198,7 +249,7 @@ class CallWithPassesTests(TestCase):
             call_with_passes(
                 reject_passes,
                 num_passes,
-                passes.get,
+                partial(passes.get, u"message"),
             ),
             failed(
                 AfterPreprocessing(
@@ -213,3 +264,13 @@ class CallWithPassesTests(TestCase):
                 ),
             ),
         )
+
+        # The passes in the group that was rejected are also returned for
+        # later use.
+        self.assertThat(
+            passes,
+            MatchesStructure(
+                spent=HasLength(0),
+                returned=HasLength(num_passes),
+            ),
+        )
-- 
GitLab