diff --git a/src/_zkapauthorizer/_storage_server.py b/src/_zkapauthorizer/_storage_server.py index 7aa17c840a705ffb15a656a8f85befda42836781..73753fe6c4d710af5e633c6ee683c3f195a034f0 100644 --- a/src/_zkapauthorizer/_storage_server.py +++ b/src/_zkapauthorizer/_storage_server.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright 2019 PrivateStorage.io, LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -101,6 +102,7 @@ from .storage_common import ( SLOT_HEADER_SIZE = 468 LEASE_TRAILER_SIZE = 4 +@attr.s class MorePassesRequired(Exception): """ Storage operations fail with ``MorePassesRequired`` when they are not @@ -111,19 +113,88 @@ class MorePassesRequired(Exception): ivar int required_count: The number of valid passes which must be presented for the operation to be authorized. + + :ivar list[int] signature_check_failed: Indices into the supplied list of + passes indicating passes which failed the signature check. + """ + valid_count = attr.ib() + required_count = attr.ib() + signature_check_failed = attr.ib() + + +@attr.s +class _ValidationResult(object): + """ + The result of validating a list of passes. + + :ivar list[int] valid: A list of indexes (into the validated list) of which + are acceptable. + + :ivar list[int] signature_check_failed: A list of indexes (into the + validated list) of passes which did not have a correct signature. """ - def __init__(self, valid_count, required_count): - self.valid_count = valid_count - self.required_count = required_count + valid = attr.ib() + signature_check_failed = attr.ib() + + @classmethod + def _is_invalid_pass(cls, message, pass_, signing_key): + """ + Cryptographically check the validity of a single pass. + + :param unicode message: The shared message for pass validation. + :param bytes pass_: The encoded pass to validate. + + :return bool: ``False`` (invalid) if the pass includes a valid + signature, ``True`` (valid) otherwise. + """ + assert isinstance(message, unicode), "message %r not unicode" % (message,) + assert isinstance(pass_, bytes), "pass %r not bytes" % (pass_,) + try: + preimage_base64, signature_base64 = pass_.split(b" ") + preimage = TokenPreimage.decode_base64(preimage_base64) + proposed_signature = VerificationSignature.decode_base64(signature_base64) + unblinded_token = signing_key.rederive_unblinded_token(preimage) + verification_key = unblinded_token.derive_verification_key_sha512() + invalid_pass = verification_key.invalid_sha512(proposed_signature, message.encode("utf-8")) + return invalid_pass + except Exception: + # It would be pretty nice to log something here, sometimes, I guess? + return True + + @classmethod + def validate_passes(cls, message, passes, signing_key): + """ + Check all of the given passes for validity. + + :param unicode message: The shared message for pass validation. + :param list[bytes] passes: The encoded passes to validate. + :param SigningKey signing_key: The signing key to use to check the passes. - def __repr__(self): - return "MorePassedRequired(valid_count={}, required_count={})".format( - self.valid_count, - self.required_count, + :return: An instance of this class describing the validation result + for all passes given. + """ + valid = [] + signature_check_failed = [] + for idx, pass_ in enumerate(passes): + if cls._is_invalid_pass(message, pass_, signing_key): + signature_check_failed.append(idx) + else: + valid.append(idx) + return cls( + valid=valid, + signature_check_failed=signature_check_failed, ) - def __str__(self): - return repr(self) + def raise_for(self, required_pass_count): + """ + :raise MorePassesRequired: Always raised with fields populated from this + instance and the given ``required_pass_count``. + """ + raise MorePassesRequired( + len(self.valid), + required_pass_count, + self.signature_check_failed, + ) class LeaseRenewalRequired(Exception): @@ -161,48 +232,6 @@ class ZKAPAuthorizerStorageServer(Referenceable): default=attr.Factory(partial(namedAny, "twisted.internet.reactor")), ) - def _is_invalid_pass(self, message, pass_): - """ - Cryptographically check the validity of a single pass. - - :param unicode message: The shared message for pass validation. - :param bytes pass_: The encoded pass to validate. - - :return bool: ``False`` (invalid) if the pass includes a valid - signature, ``True`` (valid) otherwise. - """ - assert isinstance(message, unicode), "message %r not unicode" % (message,) - assert isinstance(pass_, bytes), "pass %r not bytes" % (pass_,) - try: - preimage_base64, signature_base64 = pass_.split(b" ") - preimage = TokenPreimage.decode_base64(preimage_base64) - proposed_signature = VerificationSignature.decode_base64(signature_base64) - unblinded_token = self._signing_key.rederive_unblinded_token(preimage) - verification_key = unblinded_token.derive_verification_key_sha512() - invalid_pass = verification_key.invalid_sha512(proposed_signature, message.encode("utf-8")) - return invalid_pass - except Exception: - # It would be pretty nice to log something here, sometimes, I guess? - return True - - def _validate_passes(self, message, passes): - """ - Check all of the given passes for validity. - - :param unicode message: The shared message for pass validation. - :param list[bytes] passes: The encoded passes to validate. - - :return list[bytes]: The passes which are found to be valid. - """ - result = list( - pass_ - for pass_ - in passes - if not self._is_invalid_pass(message, pass_) - ) - # print("{}: {} passes, {} valid".format(message, len(passes), len(result))) - return result - def remote_get_version(self): """ Pass-through without pass check to allow clients to learn about our @@ -215,13 +244,14 @@ class ZKAPAuthorizerStorageServer(Referenceable): Pass-through after a pass check to ensure that clients can only allocate storage for immutable shares if they present valid passes. """ - valid_passes = self._validate_passes( + validation = _ValidationResult.validate_passes( allocate_buckets_message(storage_index), passes, + self._signing_key, ) check_pass_quantity_for_write( self._pass_value, - len(valid_passes), + validation, sharenums, allocated_size, ) @@ -247,12 +277,15 @@ class ZKAPAuthorizerStorageServer(Referenceable): Pass-through after a pass check to ensure clients can only extend the duration of share storage if they present valid passes. """ - # print("server add_lease({}, {!r})".format(len(passes), storage_index)) - valid_passes = self._validate_passes(add_lease_message(storage_index), passes) + validation = _ValidationResult.validate_passes( + add_lease_message(storage_index), + passes, + self._signing_key, + ) check_pass_quantity_for_lease( self._pass_value, storage_index, - valid_passes, + validation, self._original, ) return self._original.remote_add_lease(storage_index, *a, **kw) @@ -262,7 +295,11 @@ class ZKAPAuthorizerStorageServer(Referenceable): Pass-through after a pass check to ensure clients can only extend the duration of share storage if they present valid passes. """ - valid_passes = self._validate_passes(renew_lease_message(storage_index), passes) + valid_passes = _ValidationResult.validate_passes( + renew_lease_message(storage_index), + passes, + self._signing_key, + ) check_pass_quantity_for_lease( self._pass_value, storage_index, @@ -315,9 +352,10 @@ class ZKAPAuthorizerStorageServer(Referenceable): # necessary lease as part of the same operation. This must be # supported because there is no separate protocol action to # *create* a slot. Clients just begin writing to it. - valid_passes = self._validate_passes( + validation = _ValidationResult.validate_passes( slot_testv_and_readv_and_writev_message(storage_index), passes, + self._signing_key, ) if has_active_lease(self._original, storage_index, self._clock.seconds()): # Some of the storage is paid for already. @@ -337,8 +375,8 @@ class ZKAPAuthorizerStorageServer(Referenceable): current_sizes, tw_vectors, ) - if required_new_passes > len(valid_passes): - raise MorePassesRequired(len(valid_passes), required_new_passes) + if required_new_passes > len(validation.valid): + validation.raise_for(required_new_passes) # Skip over the remotely exposed method and jump to the underlying # implementation which accepts one additional parameter that we know @@ -382,12 +420,15 @@ def has_active_lease(storage_server, storage_index, now): ) -def check_pass_quantity(pass_value, valid_count, share_sizes): +def check_pass_quantity(pass_value, validation, share_sizes): """ Check that the given number of passes is sufficient to cover leases for one period for shares of the given sizes. - :param int valid_count: The number of passes. + :param int pass_value: The value of a single pass in bytes × lease periods. + + :param _ValidationResult validation: The validating results for a list of passes. + :param list[int] share_sizes: The sizes of the shares for which the lease is being created. @@ -397,17 +438,23 @@ def check_pass_quantity(pass_value, valid_count, share_sizes): :return: ``None`` if the given number of passes is sufficient. """ required_pass_count = required_passes(pass_value, share_sizes) - if valid_count < required_pass_count: - raise MorePassesRequired( - valid_count, - required_pass_count, - ) + if len(validation.valid) < required_pass_count: + validation.raise_for(required_pass_count) -def check_pass_quantity_for_lease(pass_value, storage_index, valid_passes, storage_server): +def check_pass_quantity_for_lease(pass_value, storage_index, validation, storage_server): """ Check that the given number of passes is sufficient to add or renew a lease for one period for the given storage index. + + :param int pass_value: The value of a single pass in bytes × lease periods. + + :param _ValidationResult validation: The validating results for a list of passes. + + :raise MorePassesRequired: If the given number of passes is too few for + the share sizes at the given storage index. + + :return: ``None`` if the given number of passes is sufficient. """ allocated_sizes = dict( get_share_sizes( @@ -416,16 +463,20 @@ def check_pass_quantity_for_lease(pass_value, storage_index, valid_passes, stora list(get_all_share_numbers(storage_server, storage_index)), ), ).values() - check_pass_quantity(pass_value, len(valid_passes), allocated_sizes) + check_pass_quantity(pass_value, validation, allocated_sizes) -def check_pass_quantity_for_write(pass_value, valid_count, sharenums, allocated_size): +def check_pass_quantity_for_write(pass_value, validation, sharenums, allocated_size): """ Determine if the given number of valid passes is sufficient for an attempted write. - :param int valid_count: The number of valid passes to consider. + :param int pass_value: The value of a single pass in bytes × lease periods. + + :param _ValidationResult validation: The validating results for a list of passes. + :param set[int] sharenums: The shares being written to. + :param int allocated_size: The size of each share. :raise MorePassedRequired: If the number of valid passes given is too @@ -433,7 +484,7 @@ def check_pass_quantity_for_write(pass_value, valid_count, sharenums, allocated_ :return: ``None`` if the number of valid passes given is sufficient. """ - check_pass_quantity(pass_value, valid_count, [allocated_size] * len(sharenums)) + check_pass_quantity(pass_value, validation, [allocated_size] * len(sharenums)) def get_all_share_paths(storage_server, storage_index): diff --git a/src/_zkapauthorizer/tests/test_storage_server.py b/src/_zkapauthorizer/tests/test_storage_server.py index 88ae5a1f1294bc0679787942f7432aa7e08d2291..1eddf1c2e2c173eda5ad4209c5a5397cf146dccb 100644 --- a/src/_zkapauthorizer/tests/test_storage_server.py +++ b/src/_zkapauthorizer/tests/test_storage_server.py @@ -33,9 +33,6 @@ from testtools import ( ) from testtools.matchers import ( Equals, - AfterPreprocessing, - MatchesStructure, - raises, ) from hypothesis import ( given, @@ -70,6 +67,9 @@ from .common import ( from .privacypass import ( make_passes, ) +from .matchers import ( + raises, +) from .strategies import ( zkaps, sizes, @@ -101,35 +101,24 @@ from ..storage_common import ( get_required_new_passes_for_mutable_write, summarize, ) +from .._storage_server import ( + _ValidationResult, +) -class PassValidationTests(TestCase): + +class ValidationResultTests(TestCase): """ - Tests for pass validation performed by ``ZKAPAuthorizerStorageServer``. + Tests for ``_ValidationResult``. """ - pass_value = 128 * 1024 - - @skipIf(platform.isWindows(), "Storage server is not supported on Windows") def setUp(self): - super(PassValidationTests, self).setUp() - self.clock = Clock() - # anonymous_storage_server uses time.time() so get our Clock close to - # the same time so we can do lease expiration calculations more - # easily. - self.clock.advance(time()) - self.anonymous_storage_server = self.useFixture(AnonymousStorageServer()).storage_server + super(ValidationResultTests, self).setUp() self.signing_key = random_signing_key() - self.storage_server = ZKAPAuthorizerStorageServer( - self.anonymous_storage_server, - self.pass_value, - self.signing_key, - self.clock, - ) @given(integers(min_value=0, max_value=64), lists(zkaps(), max_size=64)) def test_validation_result(self, valid_count, invalid_passes): """ - ``_get_valid_passes`` returns the number of cryptographically valid passes - in the list passed to it. + ``validate_passes`` returns a ``_ValidationResult`` instance which + describes the valid and invalid passes. """ message = u"hello world" valid_passes = make_passes( @@ -145,13 +134,53 @@ class PassValidationTests(TestCase): shuffle(all_passes) self.assertThat( - self.storage_server._validate_passes(message, all_passes), - AfterPreprocessing( - set, - Equals(set(valid_passes)), + _ValidationResult.validate_passes( + message, + all_passes, + self.signing_key, + ), + Equals( + _ValidationResult( + valid=list( + idx + for (idx, pass_) + in enumerate(all_passes) + if pass_ in valid_passes + ), + signature_check_failed=list( + idx + for (idx, pass_) + in enumerate(all_passes) + if pass_ not in valid_passes + ), + ), ), ) + +class PassValidationTests(TestCase): + """ + Tests for pass validation performed by ``ZKAPAuthorizerStorageServer``. + """ + pass_value = 128 * 1024 + + @skipIf(platform.isWindows(), "Storage server is not supported on Windows") + def setUp(self): + super(PassValidationTests, self).setUp() + self.clock = Clock() + # anonymous_storage_server uses time.time() so get our Clock close to + # the same time so we can do lease expiration calculations more + # easily. + self.clock.advance(time()) + self.anonymous_storage_server = self.useFixture(AnonymousStorageServer()).storage_server + self.signing_key = random_signing_key() + self.storage_server = ZKAPAuthorizerStorageServer( + self.anonymous_storage_server, + self.pass_value, + self.signing_key, + self.clock, + ) + def test_allocate_buckets_fails_without_enough_passes(self): """ ``remote_allocate_buckets`` fails with ``MorePassesRequired`` if it is @@ -231,8 +260,14 @@ class PassValidationTests(TestCase): result = mutable_write() except MorePassesRequired as e: self.assertThat( - e.required_count, - Equals(1), + e, + Equals( + MorePassesRequired( + valid_count=0, + required_count=1, + signature_check_failed=[], + ), + ), ) else: self.fail("expected MorePassesRequired, got {}".format(result)) @@ -329,9 +364,12 @@ class PassValidationTests(TestCase): except MorePassesRequired as e: self.assertThat( e, - MatchesStructure( - valid_count=Equals(0), - required_count=Equals(1), + Equals( + MorePassesRequired( + valid_count=0, + required_count=1, + signature_check_failed=[], + ), ), ) else: @@ -423,9 +461,12 @@ class PassValidationTests(TestCase): except MorePassesRequired as e: self.assertThat( e, - MatchesStructure( - valid_count=Equals(len(passes)), - required_count=Equals(required_count), + Equals( + MorePassesRequired( + valid_count=len(passes), + required_count=required_count, + signature_check_failed=[], + ), ), ) else: