diff --git a/src/_zkapauthorizer/_storage_client.py b/src/_zkapauthorizer/_storage_client.py index 6559b732e6a1bcd67396dea3d561162f2ce31c5c..63e31d9061873fc88676efe0dae7cdccb8b95876 100644 --- a/src/_zkapauthorizer/_storage_client.py +++ b/src/_zkapauthorizer/_storage_client.py @@ -20,6 +20,10 @@ This is the client part of a storage access protocol. The server part is implemented in ``_storage_server.py``. """ +from functools import ( + partial, +) + import attr from zope.interface import ( @@ -28,12 +32,14 @@ from zope.interface import ( from twisted.internet.defer import ( inlineCallbacks, returnValue, + maybeDeferred, ) from allmydata.interfaces import ( IStorageServer, ) from .storage_common import ( + MorePassesRequired, pass_value_attribute, required_passes, allocate_buckets_message, @@ -64,6 +70,42 @@ class IncorrectStorageServerReference(Exception): ) +def call_with_passes(method, num_passes, get_passes): + """ + Call a method, passing the requested number of passes as the first + argument, and try again if the call fails with an error related to some of + the passes being rejected. + + :param method: A callable which accepts a list of encoded passes as its + only argument and returns a ``Deferred``. If the ``Deferred`` fires + with ``MorePassesRequired`` then the invalid passes will be discarded + and replacement passes will be requested for a new call of ``method``. + This will repeat until no passes remain, the method succeeds, or the + methods fails in a different way. + + :param int num_passes: The number of passes to pass to the call. + + :param (unicode -> int -> [bytes]) get_passes: A function for getting + passes. + + :return: Whatever ``method`` returns. + """ + def get_more_passes(reason): + reason.trap(MorePassesRequired) + new_passes = get_passes(len(reason.value.signature_check_failed)) + for idx, new_pass in zip(reason.value.signature_check_failed, new_passes): + passes[idx] = new_pass + return go(passes) + + def go(passes): + d = maybeDeferred(method, passes) + d.addErrback(get_more_passes) + return d + + passes = get_passes(num_passes) + return go(passes) + + @implementer(IStorageServer) @attr.s class ZKAPAuthorizerStorageClient(object): @@ -144,18 +186,24 @@ class ZKAPAuthorizerStorageClient(object): allocated_size, canary, ): - return self._rref.callRemote( - "allocate_buckets", - self._get_encoded_passes( - allocate_buckets_message(storage_index), - required_passes(self._pass_value, [allocated_size] * len(sharenums)), + # XXX _rref is a property and reading it does some stuff that needs to + # happen before we get passes. Read it eagerly here. Blech. + rref = self._rref + message = allocate_buckets_message(storage_index) + num_passes = required_passes(self._pass_value, [allocated_size] * len(sharenums)) + return call_with_passes( + lambda passes: rref.callRemote( + "allocate_buckets", + passes, + storage_index, + renew_secret, + cancel_secret, + sharenums, + allocated_size, + canary, ), - storage_index, - renew_secret, - cancel_secret, - sharenums, - allocated_size, - canary, + num_passes, + partial(self._get_encoded_passes, message), ) def get_buckets( diff --git a/src/_zkapauthorizer/storage_common.py b/src/_zkapauthorizer/storage_common.py index 1a52b4d26b855baefe3802b3203283fb0a7306f6..80707f226b892c96cfa5fefd278e68b9146dc7e1 100644 --- a/src/_zkapauthorizer/storage_common.py +++ b/src/_zkapauthorizer/storage_common.py @@ -30,7 +30,7 @@ from .validators import ( greater_than, ) -@attr.s +@attr.s(frozen=True) class MorePassesRequired(Exception): """ Storage operations fail with ``MorePassesRequired`` when they are not @@ -47,7 +47,7 @@ class MorePassesRequired(Exception): """ valid_count = attr.ib() required_count = attr.ib() - signature_check_failed = attr.ib() + signature_check_failed = attr.ib(converter=frozenset) def _message_maker(label): diff --git a/src/_zkapauthorizer/tests/test_storage_client.py b/src/_zkapauthorizer/tests/test_storage_client.py new file mode 100644 index 0000000000000000000000000000000000000000..aea80a5d5bfa7401707856b99936653dd14c8ac0 --- /dev/null +++ b/src/_zkapauthorizer/tests/test_storage_client.py @@ -0,0 +1,176 @@ +# Copyright 2020 PrivateStorage.io, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for ``_zkapauthorizer._storage_client``. +""" + +import attr + +from itertools import ( + count, + islice, +) + +from testtools import ( + TestCase, +) +from testtools.matchers import ( + Always, + Is, + Equals, + AfterPreprocessing, +) +from testtools.twistedsupport import ( + succeeded, + failed, +) + +from hypothesis import ( + given, +) +from hypothesis.strategies import ( + integers, +) + +from twisted.internet.defer import ( + succeed, + fail, +) + +from .._storage_client import ( + call_with_passes, +) +from .._storage_server import ( + _ValidationResult, +) + +def pass_counts(): + return integers(min_value=1, max_value=2 ** 8) + + +def pass_factory(): + return _PassFactory() + +@attr.s +class _PassFactory(object): + """ + A stateful pass issuer. + + :ivar list spent: All of the passes ever issued. + + :ivar _fountain: A counter for making each new pass issued unique. + """ + spent = attr.ib(default=attr.Factory(list)) + + _fountain = attr.ib(default=attr.Factory(count)) + + def get(self, num_passes): + passes = list(islice(self._fountain, num_passes)) + self.spent.extend(passes) + return passes + + +class CallWithPassesTests(TestCase): + """ + Tests for ``call_with_passes``. + """ + @given(pass_counts()) + def test_success_result(self, num_passes): + """ + ``call_with_passes`` returns a ``Deferred`` that fires with the same + success result as that of the ``Deferred`` returned by the method + passed in. + """ + result = object() + self.assertThat( + call_with_passes( + lambda passes: succeed(result), + num_passes, + pass_factory().get, + ), + succeeded(Is(result)), + ) + + @given(pass_counts()) + def test_failure_result(self, num_passes): + """ + ``call_with_passes`` returns a ``Deferred`` that fires with the same + failure result as that of the ``Deferred`` returned by the method + passed in if that failure is not a ``MorePassesRequired``. + """ + result = Exception() + self.assertThat( + call_with_passes( + lambda passes: fail(result), + num_passes, + pass_factory().get, + ), + failed( + AfterPreprocessing( + lambda f: f.value, + Is(result), + ), + ), + ) + + @given(pass_counts()) + def test_passes(self, num_passes): + """ + ``call_with_passes`` calls the given method with a list of passes + containing ``num_passes`` created by the function passed for + ``get_passes``. + """ + passes = pass_factory() + + self.assertThat( + call_with_passes( + lambda passes: succeed(passes), + num_passes, + passes.get, + ), + succeeded( + Equals( + passes.spent, + ), + ), + ) + + @given(pass_counts()) + def test_retry_on_rejected_passes(self, num_passes): + """ + ``call_with_passes`` tries calling the given method again with a new list + of passes, still of length ```num_passes``, but without the passes + which were rejected on the first try. + """ + passes = pass_factory() + + def reject_even_pass_values(passes): + good_passes = list(idx for (idx, p) in enumerate(passes) if p % 2) + bad_passes = list(idx for (idx, p) in enumerate(passes) if idx not in good_passes) + if len(good_passes) < num_passes: + _ValidationResult( + valid=good_passes, + signature_check_failed=bad_passes, + ).raise_for(num_passes) + return None + + self.assertThat( + call_with_passes( + reject_even_pass_values, + num_passes, + passes.get, + ), + succeeded(Always()), + ) diff --git a/src/_zkapauthorizer/tests/test_storage_protocol.py b/src/_zkapauthorizer/tests/test_storage_protocol.py index 612a8538f5704e5cf68a135e5f17be3cad3769d7..3ffca8ae76d7b3e8ec69b5742e271a53f4b6cdab 100644 --- a/src/_zkapauthorizer/tests/test_storage_protocol.py +++ b/src/_zkapauthorizer/tests/test_storage_protocol.py @@ -52,6 +52,7 @@ from hypothesis.strategies import ( lists, tuples, integers, + data as data_strategy, ) from twisted.python.runtime import ( @@ -117,6 +118,7 @@ from ..api import ( ) from ..storage_common import ( slot_testv_and_readv_and_writev_message, + allocate_buckets_message, get_implied_data_length, required_passes, ) @@ -165,6 +167,28 @@ class RequiredPassesTests(TestCase): ) +def get_passes(message, count, signing_key): + """ + :param unicode message: Request-binding message for PrivacyPass. + + :param int count: The number of passes to get. + + :param SigningKEy signing_key: The key to use to sign the passes. + + :return list[Pass]: ``count`` new random passes signed with the given key + and bound to the given message. + """ + return list( + Pass(*pass_.split(u" ")) + for pass_ + in make_passes( + signing_key, + message, + list(RandomToken.create() for n in range(count)), + ) + ) + + class ShareTests(TestCase): """ Tests for interaction with shares. @@ -183,17 +207,10 @@ class ShareTests(TestCase): self.signing_key = random_signing_key() self.spent_passes = 0 - def get_passes(message, count): + def counting_get_passes(message, count): self.spent_passes += count - return list( - Pass(*pass_.split(u" ")) - for pass_ - in make_passes( - self.signing_key, - message, - list(RandomToken.create() for n in range(count)), - ) - ) + return get_passes(message, count, self.signing_key) + self.server = ZKAPAuthorizerStorageServer( self.anonymous_storage_server, self.pass_value, @@ -203,7 +220,7 @@ class ShareTests(TestCase): self.client = ZKAPAuthorizerStorageClient( self.pass_value, get_rref=lambda: self.local_remote_server, - get_passes=get_passes, + get_passes=counting_get_passes, ) def test_get_version(self): @@ -222,8 +239,9 @@ class ShareTests(TestCase): cancel_secret=lease_cancel_secrets(), sharenums=sharenum_sets(), size=sizes(), + data=data_strategy(), ) - def test_rejected_passes_reported(self, storage_index, renew_secret, cancel_secret, sharenums, size): + def test_rejected_passes_reported(self, storage_index, renew_secret, cancel_secret, sharenums, size, data): """ Any passes rejected by the storage server are reported with a ``MorePassesRequired`` exception sent to the client. @@ -232,18 +250,57 @@ class ShareTests(TestCase): # up between iterations. cleanup_storage_server(self.anonymous_storage_server) - # Break our infinite pass factory by replacing the expected key with a - # new one. Now the passes are mis-signed as far as the server is - # concerned. The clunky way we control pass generation means it's - # hard to have anything but an all-or-nothing test. Perhaps some - # future refactoring will let us exercise a mix of passes with valid - # and invalid signatures. - self.signing_key = random_signing_key() - num_passes = required_passes(self.pass_value, [size] * len(sharenums)) + # Pick some passes to mess with. + bad_pass_indexes = data.draw( + lists( + integers( + min_value=0, + max_value=num_passes - 1, + ), + min_size=1, + max_size=num_passes, + unique=True, + ), + ) + + # Make some passes with a key untrusted by the server. + bad_passes = get_passes( + allocate_buckets_message(storage_index), + len(bad_pass_indexes), + random_signing_key(), + ) + + # Make some passes with a key trusted by the server. + good_passes = get_passes( + allocate_buckets_message(storage_index), + num_passes - len(bad_passes), + self.signing_key, + ) + + all_passes = [] + for i in range(num_passes): + if i in bad_pass_indexes: + all_passes.append(bad_passes.pop()) + else: + all_passes.append(good_passes.pop()) + + # Sanity checks + self.assertThat(bad_passes, Equals([])) + self.assertThat(good_passes, Equals([])) + self.assertThat(all_passes, HasLength(num_passes)) + self.assertThat( - self.client.allocate_buckets( + # Bypass the client handling of MorePassesRequired so we can see + # it. + self.local_remote_server.callRemote( + "allocate_buckets", + list( + pass_.pass_text.encode("ascii") + for pass_ + in all_passes + ), storage_index, renew_secret, cancel_secret, @@ -256,9 +313,9 @@ class ShareTests(TestCase): lambda f: f.value, Equals( MorePassesRequired( - valid_count=0, + valid_count=num_passes - len(bad_pass_indexes), required_count=num_passes, - signature_check_failed=range(num_passes), + signature_check_failed=bad_pass_indexes, ), ), ),