diff --git a/src/_zkapauthorizer/_storage_client.py b/src/_zkapauthorizer/_storage_client.py index 0f7238ce8b8283a3e4f96075fcd13086dcd802fc..ad2254f142101574d27bb1712e3f6af7bbf82b04 100644 --- a/src/_zkapauthorizer/_storage_client.py +++ b/src/_zkapauthorizer/_storage_client.py @@ -36,13 +36,12 @@ from zope.interface import ( ) from eliot.twisted import ( - DeferredContext, + inline_callbacks, ) from twisted.internet.defer import ( inlineCallbacks, returnValue, - maybeDeferred, ) from allmydata.interfaces import ( IStorageServer, @@ -65,6 +64,7 @@ from .storage_common import ( get_required_new_passes_for_mutable_write, ) + class IncorrectStorageServerReference(Exception): """ A Foolscap remote object which should reference a ZKAPAuthorizer storage @@ -84,55 +84,90 @@ class IncorrectStorageServerReference(Exception): ) +def replace_invalid_passes_with_new_passes(passes, more_passes_required): + """ + Replace all rejected passes in the given pass group with new ones. Mark + any rejected passes as rejected. + + :param IPassGroup passes: A group of passes, some of which may have been + rejected. + + :param MorePassesRequired more_passes_required: An exception possibly + detailing the rejection of some passes from the group. + + :return: ``None`` if no passes in the group were rejected and so there is + nothing to replace. Otherwise, a new ``IPassGroup`` created from + ``passes`` but with rejected passes replaced with new ones. + """ + num_failed = len(more_passes_required.signature_check_failed) + if num_failed == 0: + # If no signature checks failed then the call just didn't supply + # enough passes. The exception tells us how many passes we should + # spend so we could try again with that number of passes but for + # now we'll just let the exception propagate. The client should + # always figure out the number of passes right on the first try so + # this case is somewhat suspicious. Err on the side of lack of + # service instead of burning extra passes. + # + # We *could* just `raise` here and only be called from an `except` + # suite... but let's not be so vulgar. + return None + SIGNATURE_CHECK_FAILED.log(count=num_failed) + rejected_passes, okay_passes = passes.split(more_passes_required.signature_check_failed) + rejected_passes.mark_invalid(u"signature check failed") + return okay_passes.expand(len(more_passes_required.signature_check_failed)) + + +@inline_callbacks def call_with_passes(method, num_passes, get_passes): """ Call a method, passing the requested number of passes as the first argument, and try again if the call fails with an error related to some of the passes being rejected. - :param method: A callable which accepts a list of encoded passes as its - only argument and returns a ``Deferred``. If the ``Deferred`` fires - with ``MorePassesRequired`` then the invalid passes will be discarded - and replacement passes will be requested for a new call of ``method``. - This will repeat until no passes remain, the method succeeds, or the - methods fails in a different way. + :param (IPassGroup -> Deferred) method: An operation to call with some passes. + If the returned ``Deferred`` fires with ``MorePassesRequired`` then + the invalid passes will be discarded and replacement passes will be + requested for a new call of ``method``. This will repeat until no + passes remain, the method succeeds, or the methods fails in a + different way. :param int num_passes: The number of passes to pass to the call. - :param (unicode -> int -> [bytes]) get_passes: A function for getting + :param (int -> IPassGroup) get_passes: A function for getting passes. - :return: Whatever ``method`` returns. + :return: A ``Deferred`` that fires with whatever the ``Deferred`` returned + by ``method`` fires with (apart from ``MorePassesRequired`` failures + that trigger a retry). """ - def get_more_passes(reason): - reason.trap(MorePassesRequired) - num_failed = len(reason.value.signature_check_failed) - if num_failed == 0: - # If no signature checks failed then the call just didn't supply - # enough passes. The exception tells us how many passes we should - # spend so we could try again with that number of passes but for - # now we'll just let the exception propagate. The client should - # always figure out the number of passes right on the first try so - # this case is somewhat suspicious. Err on the side of lack of - # service instead of burning extra passes. - return reason - SIGNATURE_CHECK_FAILED.log(count=num_failed) - new_passes = get_passes(num_failed) - for idx, new_pass in zip(reason.value.signature_check_failed, new_passes): - passes[idx] = new_pass - return go(passes) - - def go(passes): - # Capture the Eliot context for the errback. - d = DeferredContext(maybeDeferred(method, passes)) - d.addErrback(get_more_passes) - # Return the underlying Deferred without finishing the action. - return d.result - - with CALL_WITH_PASSES(count=num_passes).context(): + with CALL_WITH_PASSES(count=num_passes): passes = get_passes(num_passes) - # Finish the Eliot action when this is done. - return DeferredContext(go(passes)).addActionFinish() + try: + # Try and repeat as necessary. + while True: + try: + result = yield method(passes) + except MorePassesRequired as e: + updated_passes = replace_invalid_passes_with_new_passes( + passes, + e, + ) + if updated_passes is None: + raise + else: + passes = updated_passes + else: + # Commit the spend of the passes when the operation finally succeeds. + passes.mark_spent() + break + except: + # Something went wrong that we can't address with a retry. + passes.reset() + raise + + # Give the operation's result to the caller. + returnValue(result) def with_rref(f): @@ -149,16 +184,16 @@ def with_rref(f): return g -def _get_encoded_passes(passes): +def _get_encoded_passes(group): """ - :param list[Pass] passes: A group of passes to encode. + :param IPassGroup group: A group of passes to encode. :return list[bytes]: The encoded form of the passes in the given group. """ return list( t.pass_text.encode("ascii") for t - in passes + in group.passes ) diff --git a/src/_zkapauthorizer/spending.py b/src/_zkapauthorizer/spending.py index ac0353c6595cfa002c27c4ef0432df44cc8bc63f..2e44de96f4af2157d9db61f1b425c42e9d733f8d 100644 --- a/src/_zkapauthorizer/spending.py +++ b/src/_zkapauthorizer/spending.py @@ -16,12 +16,134 @@ A module for logic controlling the manner in which ZKAPs are spent. """ +from zope.interface import ( + Interface, + Attribute, + implementer, +) + import attr from .eliot import ( GET_PASSES, ) +class IPassGroup(Interface): + """ + A group of passed meant to be spent together. + """ + passes = Attribute(":ivar list[Pass] passes: The passes themselves.") + + def split(select_indices): + """ + Create two new ``IPassGroup`` providers. The first contains all passes in + this group at the given indices. The second contains all the others. + + :param list[int] select_indices: The indices of the passes to include + in the first resulting group. + + :return (IPassGroup, IPassGroup): The two new groups. + """ + + def expand(by_amount): + """ + Create a new ``IPassGroup`` provider which contains all of this group's + passes and some more. + + :param int by_amount: The number of additional passes the resulting + group should contain. + + :return IPassGroup: The new group. + """ + + def mark_spent(): + """ + The passes have been spent successfully. Ensure none of them appear in + any ``IPassGroup`` provider created in the future. + + :return: ``None`` + """ + + def mark_invalid(reason): + """ + The passes could not be spent. Ensure none of them appear in any + ``IPassGroup`` provider created in the future. + + :param unicode reason: A short description of the reason the passes + could not be spent. + + :return: ``None`` + """ + + def reset(): + """ + The passes have not been spent. Return them to for use in a future + ``IPassGroup`` provider. + + :return: ``None`` + """ + + +class IPassFactory(Interface): + """ + An object which can create passes. + """ + def get(message, num_passes): + """ + :param unicode message: A request-binding message for the resulting passes. + + :param int num_passes: The number of passes to request. + + :return IPassGroup: A group of passes bound to the given message and + of the requested size. + """ + + +@implementer(IPassGroup) +@attr.s +class PassGroup(object): + """ + Track the state of a group of passes intended as payment for an operation. + + :ivar unicode _message: The request binding message for this group of + passes. + + :ivar IPassFactory _factory: The factory which created this pass group. + + :ivar list[Pass] passes: The passes of which this group consists. + """ + _message = attr.ib() + _factory = attr.ib() + passes = attr.ib() + + def split(self, select_indices): + selected = [] + unselected = [] + for idx, p in enumerate(self.passes): + if idx in select_indices: + selected.append(p) + else: + unselected.append(p) + return ( + attr.evolve(self, passes=selected), + attr.evolve(self, passes=unselected), + ) + + def expand(self, by_amount): + return attr.evolve( + self, + passes=self.passes + self._factory.get(self._message, by_amount).passes, + ) + + def mark_spent(self): + self._factory._mark_spent(self.passes) + + def mark_invalid(self, reason): + self._factory._mark_invalid(reason, self.passes) + + def reset(self): + self._factory._reset(self.passes) + @attr.s class SpendingController(object): @@ -39,4 +161,16 @@ class SpendingController(object): message=message, count=num_passes, ) - return passes + return PassGroup(message, self, passes) + + def _mark_spent(self, group): + # TODO + pass + + def _mark_invalid(self, reason, group): + # TODO + pass + + def _reset(self, group): + # TODO + pass diff --git a/src/_zkapauthorizer/tests/matchers.py b/src/_zkapauthorizer/tests/matchers.py index 6c7ab457c04c6971965779bb4445517decf9e933..79b4febf24c6cad07c30b10f837aa3e9a74b92f4 100644 --- a/src/_zkapauthorizer/tests/matchers.py +++ b/src/_zkapauthorizer/tests/matchers.py @@ -154,3 +154,23 @@ def leases_current(relevant_storage_indexes, now, min_lease_remaining): ), ), ) + + +def even(): + """ + Matches even integers. + """ + return AfterPreprocessing( + lambda n: n % 2, + Equals(0), + ) + + +def odd(): + """ + Matches odd integers. + """ + return AfterPreprocessing( + lambda n: n % 2, + Equals(1), + ) diff --git a/src/_zkapauthorizer/tests/storage_common.py b/src/_zkapauthorizer/tests/storage_common.py index c882a4917a92f6954b00bfb4d64b1c7acfcf3f56..ca2915f963d7c16a66e979db984a89c1524aa518 100644 --- a/src/_zkapauthorizer/tests/storage_common.py +++ b/src/_zkapauthorizer/tests/storage_common.py @@ -34,6 +34,10 @@ from itertools import ( import attr +from zope.interface import ( + implementer, +) + from twisted.python.filepath import ( FilePath, ) @@ -55,6 +59,11 @@ from ..model import ( Pass, ) +from ..spending import ( + IPassFactory, + PassGroup, +) + # Hard-coded in Tahoe-LAFS LEASE_INTERVAL = 60 * 60 * 24 * 31 @@ -219,6 +228,7 @@ def pass_factory(get_passes=None): return _PassFactory(get_passes=get_passes) +@implementer(IPassFactory) @attr.s class _PassFactory(object): """ @@ -227,14 +237,56 @@ class _PassFactory(object): :ivar (unicode -> int -> [bytes]) _get_passes: A function for getting passes. + :ivar set[int] in_use: All of the passes given out without a confirmed + terminal state. + + :ivar dict[int, unicode] invalid: All of the passes given out and returned + using ``IPassGroup.invalid`` mapped to the reason given. + + :ivar set[int] spent: All of the passes given out and returned via + ``IPassGroup.mark_spent``. + :ivar set[int] issued: All of the passes ever given out. + :ivar list[int] returned: A list of passes which were given out but then + returned via ``IPassGroup.reset``. """ _get_passes = attr.ib() + + returned = attr.ib(default=attr.Factory(list), init=False) + in_use = attr.ib(default=attr.Factory(set), init=False) + invalid = attr.ib(default=attr.Factory(dict), init=False) + spent = attr.ib(default=attr.Factory(set), init=False) issued = attr.ib(default=attr.Factory(set), init=False) def get(self, message, num_passes): passes = [] + if self.returned: + passes.extend(self.returned[:num_passes]) + del self.returned[:num_passes] + num_passes -= len(passes) passes.extend(self._get_passes(message, num_passes)) self.issued.update(passes) - return passes + self.in_use.update(passes) + return PassGroup(message, self, passes) + + def _mark_spent(self, passes): + for p in passes: + if p not in self.in_use: + raise ValueError("Pass {} cannot be spent, it is not in use.".format(p)) + self.spent.update(passes) + self.in_use.difference_update(passes) + + def _mark_invalid(self, reason, passes): + for p in passes: + if p not in self.in_use: + raise ValueError("Pass {} cannot be invalid, it is not in use.".format(p)) + self.invalid.update(dict.fromkeys(passes, reason)) + self.in_use.difference_update(passes) + + def _reset(self, passes): + for p in passes: + if p not in self.in_use: + raise ValueError("Pass {} cannot be reset, it is not in use.".format(p)) + self.returned.extend(passes) + self.in_use.difference_update(passes) diff --git a/src/_zkapauthorizer/tests/test_storage_client.py b/src/_zkapauthorizer/tests/test_storage_client.py index 77018c2928f1c7cd991093ed416e12c5ce3a4dff..ee5d1bc001923beb5957c0a3205e2b7bb8b650bb 100644 --- a/src/_zkapauthorizer/tests/test_storage_client.py +++ b/src/_zkapauthorizer/tests/test_storage_client.py @@ -16,11 +16,8 @@ Tests for ``_zkapauthorizer._storage_client``. """ -import attr - -from itertools import ( - count, - islice, +from functools import ( + partial, ) from testtools import ( @@ -31,6 +28,10 @@ from testtools.matchers import ( Is, Equals, AfterPreprocessing, + MatchesStructure, + HasLength, + MatchesAll, + AllMatch, ) from testtools.twistedsupport import ( succeeded, @@ -49,6 +50,11 @@ from twisted.internet.defer import ( fail, ) +from .matchers import ( + even, + odd, +) + from ..api import ( MorePassesRequired, ) @@ -60,30 +66,13 @@ from .._storage_server import ( _ValidationResult, ) -def pass_counts(): - return integers(min_value=1, max_value=2 ** 8) - - -def pass_factory(): - return _PassFactory() - -@attr.s -class _PassFactory(object): - """ - A stateful pass issuer. - - :ivar list spent: All of the passes ever issued. - - :ivar _fountain: A counter for making each new pass issued unique. - """ - spent = attr.ib(default=attr.Factory(list)) +from .storage_common import ( + pass_factory, +) - _fountain = attr.ib(default=attr.Factory(count)) - def get(self, num_passes): - passes = list(islice(self._fountain, num_passes)) - self.spent.extend(passes) - return passes +def pass_counts(): + return integers(min_value=1, max_value=2 ** 8) class CallWithPassesTests(TestCase): @@ -100,9 +89,9 @@ class CallWithPassesTests(TestCase): result = object() self.assertThat( call_with_passes( - lambda passes: succeed(result), + lambda group: succeed(result), num_passes, - pass_factory().get, + partial(pass_factory().get, u"message"), ), succeeded(Is(result)), ) @@ -117,9 +106,9 @@ class CallWithPassesTests(TestCase): result = Exception() self.assertThat( call_with_passes( - lambda passes: fail(result), + lambda group: fail(result), num_passes, - pass_factory().get, + partial(pass_factory().get, u"message"), ), failed( AfterPreprocessing( @@ -130,27 +119,71 @@ class CallWithPassesTests(TestCase): ) @given(pass_counts()) - def test_passes(self, num_passes): + def test_passes_issued(self, num_passes): """ - ``call_with_passes`` calls the given method with a list of passes - containing ``num_passes`` created by the function passed for + ``call_with_passes`` calls the given method with an ``IPassGroup`` + provider containing ``num_passes`` created by the function passed for ``get_passes``. """ passes = pass_factory() self.assertThat( call_with_passes( - lambda passes: succeed(passes), + lambda group: succeed(group.passes), num_passes, - passes.get, + partial(passes.get, u"message"), ), succeeded( Equals( - passes.spent, + sorted(passes.issued), ), ), ) + @given(pass_counts()) + def test_passes_spent_on_success(self, num_passes): + """ + ``call_with_passes`` marks the passes it uses as spent if the operation + succeeds. + """ + passes = pass_factory() + + self.assertThat( + call_with_passes( + lambda group: None, + num_passes, + partial(passes.get, u"message"), + ), + succeeded(Always()), + ) + self.assertThat( + passes.issued, + Equals(passes.spent), + ) + + @given(pass_counts()) + def test_passes_returned_on_failure(self, num_passes): + """ + ``call_with_passes`` returns the passes it uses if the operation fails. + """ + passes = pass_factory() + + self.assertThat( + call_with_passes( + lambda group: fail(Exception("Anything")), + num_passes, + partial(passes.get, u"message"), + ), + failed(Always()), + ) + self.assertThat( + passes, + MatchesStructure( + issued=Equals(set(passes.returned)), + spent=Equals(set()), + ), + ) + @given(pass_counts()) def test_retry_on_rejected_passes(self, num_passes): """ @@ -160,7 +193,8 @@ class CallWithPassesTests(TestCase): """ passes = pass_factory() - def reject_even_pass_values(passes): + def reject_even_pass_values(group): + passes = group.passes good_passes = list(idx for (idx, p) in enumerate(passes) if p % 2) bad_passes = list(idx for (idx, p) in enumerate(passes) if idx not in good_passes) if len(good_passes) < num_passes: @@ -174,10 +208,26 @@ class CallWithPassesTests(TestCase): call_with_passes( reject_even_pass_values, num_passes, - passes.get, + partial(passes.get, u"message"), ), succeeded(Always()), ) + self.assertThat( + passes, + MatchesStructure( + returned=HasLength(0), + in_use=HasLength(0), + invalid=MatchesAll( + HasLength(num_passes), + AllMatch(even()), + ), + spent=MatchesAll( + HasLength(num_passes), + AllMatch(odd()), + ), + issued=Equals(passes.spent | set(passes.invalid.keys())), + ), + ) @given(pass_counts()) def test_pass_through_too_few_passes(self, num_passes): @@ -188,7 +238,8 @@ class CallWithPassesTests(TestCase): """ passes = pass_factory() - def reject_passes(passes): + def reject_passes(group): + passes = group.passes _ValidationResult( valid=range(len(passes)), signature_check_failed=[], @@ -198,7 +249,7 @@ class CallWithPassesTests(TestCase): call_with_passes( reject_passes, num_passes, - passes.get, + partial(passes.get, u"message"), ), failed( AfterPreprocessing( @@ -213,3 +264,13 @@ class CallWithPassesTests(TestCase): ), ), ) + + # The passes in the group that was rejected are also returned for + # later use. + self.assertThat( + passes, + MatchesStructure( + spent=HasLength(0), + returned=HasLength(num_passes), + ), + )