diff --git a/src/_zkapauthorizer/_storage_client.py b/src/_zkapauthorizer/_storage_client.py index ad2254f142101574d27bb1712e3f6af7bbf82b04..d62370ab4677b705dd55162feb88aac328063c93 100644 --- a/src/_zkapauthorizer/_storage_client.py +++ b/src/_zkapauthorizer/_storage_client.py @@ -84,10 +84,9 @@ class IncorrectStorageServerReference(Exception): ) -def replace_invalid_passes_with_new_passes(passes, more_passes_required): +def invalidate_rejected_passes(passes, more_passes_required): """ - Replace all rejected passes in the given pass group with new ones. Mark - any rejected passes as rejected. + Return a new ``IPassGroup`` with all rejected passes removed from it. :param IPassGroup passes: A group of passes, some of which may have been rejected. @@ -115,7 +114,14 @@ def replace_invalid_passes_with_new_passes(passes, more_passes_required): SIGNATURE_CHECK_FAILED.log(count=num_failed) rejected_passes, okay_passes = passes.split(more_passes_required.signature_check_failed) rejected_passes.mark_invalid(u"signature check failed") - return okay_passes.expand(len(more_passes_required.signature_check_failed)) + + # It would be great to just expand okay_passes right here. However, if + # that fails (eg because we don't have enough tokens remaining) then the + # caller will have a hard time figuring out which okay passes remain that + # it needs to reset. :/ So, instead, pass back the complete okay set. The + # caller can figure out by how much to expand it by considering its size + # and the original number of passes it requested. + return okay_passes @inline_callbacks @@ -142,28 +148,33 @@ def call_with_passes(method, num_passes, get_passes): that trigger a retry). """ with CALL_WITH_PASSES(count=num_passes): - passes = get_passes(num_passes) + pass_group = get_passes(num_passes) try: # Try and repeat as necessary. while True: try: - result = yield method(passes) + result = yield method(pass_group) except MorePassesRequired as e: - updated_passes = replace_invalid_passes_with_new_passes( - passes, + okay_pass_group = invalidate_rejected_passes( + pass_group, e, ) - if updated_passes is None: + if okay_pass_group is None: raise else: - passes = updated_passes + # Update the local in case we end up going to the + # except suite below. + pass_group = okay_pass_group + # Add the necessary number of new passes. This might + # fail if we don't have enough tokens. + pass_group = pass_group.expand(num_passes - len(pass_group.passes)) else: # Commit the spend of the passes when the operation finally succeeds. - passes.mark_spent() + pass_group.mark_spent() break except: # Something went wrong that we can't address with a retry. - passes.reset() + pass_group.reset() raise # Give the operation's result to the caller. diff --git a/src/_zkapauthorizer/tests/storage_common.py b/src/_zkapauthorizer/tests/storage_common.py index 32ef040d24e652320e3b4eaee69c5f13b2c99b91..b8dedf468df67f82a47c29868975631fcc147518 100644 --- a/src/_zkapauthorizer/tests/storage_common.py +++ b/src/_zkapauthorizer/tests/storage_common.py @@ -28,7 +28,6 @@ from struct import ( ) from itertools import ( - count, islice, ) @@ -56,6 +55,7 @@ from .privacypass import ( ) from ..model import ( + NotEnoughTokens, Pass, ) @@ -167,16 +167,19 @@ def whitebox_write_sparse_share(sharepath, version, size, leases, now): ) -def integer_passes(): +def integer_passes(limit): """ :return: Return a function which can be used to get a number of passes. The function accepts a unicode request-binding message and an integer number of passes. It returns a list of integers which serve as passes. Successive calls to the function return unique pass values. """ - counter = count(0) + counter = iter(range(limit)) def get_passes(message, num_passes): - return list(islice(counter, num_passes)) + result = list(islice(counter, num_passes)) + if len(result) < num_passes: + raise NotEnoughTokens() + return result return get_passes @@ -216,15 +219,13 @@ def privacypass_passes(signing_key): return partial(get_passes, signing_key=signing_key) -def pass_factory(get_passes=None): +def pass_factory(get_passes): """ Get a new factory for passes. :param (unicode -> int -> [pass]) get_passes: A function the factory can use to get new passes. """ - if get_passes is None: - get_passes = integer_passes() return _PassFactory(get_passes=get_passes) diff --git a/src/_zkapauthorizer/tests/test_storage_client.py b/src/_zkapauthorizer/tests/test_storage_client.py index 5fd6b6d4c343788e05e584ab7fe7078080b73122..39f0e6c544cf107a529535ce9b63cba3713fc43a 100644 --- a/src/_zkapauthorizer/tests/test_storage_client.py +++ b/src/_zkapauthorizer/tests/test_storage_client.py @@ -16,6 +16,10 @@ Tests for ``_zkapauthorizer._storage_client``. """ +from __future__ import ( + division, +) + from functools import ( partial, ) @@ -32,6 +36,7 @@ from testtools.matchers import ( HasLength, MatchesAll, AllMatch, + IsInstance, ) from testtools.twistedsupport import ( succeeded, @@ -59,7 +64,9 @@ from .strategies import ( from ..api import ( MorePassesRequired, ) - +from ..model import ( + NotEnoughTokens, +) from .._storage_client import ( call_with_passes, ) @@ -69,6 +76,7 @@ from .._storage_server import ( from .storage_common import ( pass_factory, + integer_passes, ) @@ -88,7 +96,7 @@ class CallWithPassesTests(TestCase): call_with_passes( lambda group: succeed(result), num_passes, - partial(pass_factory().get, u"message"), + partial(pass_factory(integer_passes(num_passes)).get, u"message"), ), succeeded(Is(result)), ) @@ -105,7 +113,7 @@ class CallWithPassesTests(TestCase): call_with_passes( lambda group: fail(result), num_passes, - partial(pass_factory().get, u"message"), + partial(pass_factory(integer_passes(num_passes)).get, u"message"), ), failed( AfterPreprocessing( @@ -122,7 +130,7 @@ class CallWithPassesTests(TestCase): provider containing ``num_passes`` created by the function passed for ``get_passes``. """ - passes = pass_factory() + passes = pass_factory(integer_passes(num_passes)) self.assertThat( call_with_passes( @@ -143,7 +151,7 @@ class CallWithPassesTests(TestCase): ``call_with_passes`` marks the passes it uses as spent if the operation succeeds. """ - passes = pass_factory() + passes = pass_factory(integer_passes(num_passes)) self.assertThat( call_with_passes( @@ -163,7 +171,7 @@ class CallWithPassesTests(TestCase): """ ``call_with_passes`` returns the passes it uses if the operation fails. """ - passes = pass_factory() + passes = pass_factory(integer_passes(num_passes)) self.assertThat( call_with_passes( @@ -188,7 +196,7 @@ class CallWithPassesTests(TestCase): of passes, still of length ```num_passes``, but without the passes which were rejected on the first try. """ - passes = pass_factory() + passes = pass_factory(integer_passes(num_passes * 2)) def reject_even_pass_values(group): passes = group.passes @@ -233,7 +241,7 @@ class CallWithPassesTests(TestCase): no passes have been marked as invalid. This happens if all passes given were valid but too fewer were given. """ - passes = pass_factory() + passes = pass_factory(integer_passes(num_passes)) def reject_passes(group): passes = group.passes @@ -271,3 +279,66 @@ class CallWithPassesTests(TestCase): returned=HasLength(num_passes), ), ) + + @given(pass_counts(), pass_counts()) + def test_not_enough_tokens_for_retry(self, num_passes, extras): + """ + When there are not enough tokens to successfully complete a retry with the + required number of passes, ``call_with_passes`` marks all passes + reported as invalid during its efforts as such and resets all other + passes it acquired. + """ + passes = pass_factory(integer_passes(num_passes + extras)) + rejected = [] + accepted = [] + + def reject_half_passes(group): + num = len(group.passes) + # Floor division will always short-change valid here, even for a + # group size of 1. Therefore there will always be some passes + # marked as invalid. + accept_indexes = range(num // 2) + reject_indexes = range(num // 2, num) + # Only keep this iteration's accepted passes. We'll want to see + # that the final iteration's passes are all returned. Passes from + # earlier iterations don't matter. + accepted[:] = list(group.passes[i] for i in accept_indexes) + # On the other hand, keep *all* rejected passes. They should all + # be marked as invalid and we want to make sure that's the case, + # no matter which iteration rejected them. + rejected.extend(group.passes[i] for i in reject_indexes) + _ValidationResult( + valid=accept_indexes, + signature_check_failed=reject_indexes, + ).raise_for(num) + + self.assertThat( + call_with_passes( + # Since half of every group is rejected, we'll eventually run + # out of passes no matter how many we start with. + reject_half_passes, + num_passes, + partial(passes.get, u"message"), + ), + failed( + AfterPreprocessing( + lambda f: f.value, + IsInstance(NotEnoughTokens), + ), + ), + ) + self.assertThat( + passes, + MatchesStructure( + # Whatever is left in the group when we run out of tokens must + # be returned. + returned=Equals(accepted), + in_use=HasLength(0), + invalid=AfterPreprocessing( + lambda invalid: invalid.keys(), + Equals(rejected), + ), + spent=HasLength(0), + issued=Equals(set(accepted + rejected)), + ), + )