diff --git a/src/_zkapauthorizer/_storage_server.py b/src/_zkapauthorizer/_storage_server.py index e40a00d1ea6731a55cea4916048543e4dd843d99..ae8297be304718e74931723425d7848ec8412fb9 100644 --- a/src/_zkapauthorizer/_storage_server.py +++ b/src/_zkapauthorizer/_storage_server.py @@ -24,6 +24,10 @@ from __future__ import ( absolute_import, ) +from math import ( + ceil, +) + import attr from attr.validators import ( provides, @@ -58,6 +62,12 @@ from .storage_common import ( slot_testv_and_readv_and_writev_message, ) +class MorePassesRequired(Exception): + def __init__(self, valid_count, required_count): + self.valid_count = valid_count + self.required_count = required_count + + @implementer_only(RITokenAuthorizedStorageServer, IReferenceable, IRemotelyCallable) # It would be great to use `frozen=True` (value-based hashing) instead of # `cmp=False` (identity based hashing) but Referenceable wants to set some @@ -68,9 +78,25 @@ class ZKAPAuthorizerStorageServer(Referenceable): A class which wraps an ``RIStorageServer`` to insert pass validity checks before allowing certain functionality. """ + # The number of bytes we're willing to store for a lease period for each + # pass submitted. + _BYTES_PER_PASS = 128 * 1024 + _original = attr.ib(validator=provides(RIStorageServer)) _signing_key = attr.ib(validator=instance_of(SigningKey)) + def _required_passes(self, stored_bytes): + """ + Calculate the number of passes that are required to store ``stored_bytes`` + for one lease period. + + :param int stored_bytes: A number of bytes of storage for which to + calculate a price in passes. + + :return int: The number of passes. + """ + return int(ceil(stored_bytes / self._BYTES_PER_PASS)) + def _is_invalid_pass(self, message, pass_): """ Check the validity of a single pass. @@ -117,13 +143,30 @@ class ZKAPAuthorizerStorageServer(Referenceable): """ return self._original.remote_get_version() - def remote_allocate_buckets(self, passes, storage_index, *a, **kw): + def remote_allocate_buckets(self, passes, storage_index, renew_secret, cancel_secret, sharenums, allocated_size, canary): """ Pass-through after a pass check to ensure that clients can only allocate storage for immutable shares if they present valid passes. """ - self._validate_passes(allocate_buckets_message(storage_index), passes) - return self._original.remote_allocate_buckets(storage_index, *a, **kw) + valid_passes = self._validate_passes( + allocate_buckets_message(storage_index), + passes, + ) + required_passes = self._required_passes(len(sharenums) * allocated_size) + if len(valid_passes) < required_passes: + raise MorePassesRequired( + len(valid_passes), + required_passes, + ) + + return self._original.remote_allocate_buckets( + storage_index, + renew_secret, + cancel_secret, + sharenums, + allocated_size, + canary, + ) def remote_get_buckets(self, storage_index): """ diff --git a/src/_zkapauthorizer/api.py b/src/_zkapauthorizer/api.py index 81f47520ce66ffadb55a41fb3885d1cd50a7947c..8b89611ba10e3ee3833f2d5c7c45d1d8365ee320 100644 --- a/src/_zkapauthorizer/api.py +++ b/src/_zkapauthorizer/api.py @@ -13,12 +13,14 @@ # limitations under the License. __all__ = [ + "MorePassesRequired", "ZKAPAuthorizerStorageServer", "ZKAPAuthorizerStorageClient", "ZKAPAuthorizer", ] from ._storage_server import ( + MorePassesRequired, ZKAPAuthorizerStorageServer, ) from ._storage_client import ( diff --git a/src/_zkapauthorizer/tests/test_storage_server.py b/src/_zkapauthorizer/tests/test_storage_server.py index 7be247f34efa305ee22e4d7b5120b7679592828a..f2cc0b85e133587b69f79816afd50e8a3e56e9ce 100644 --- a/src/_zkapauthorizer/tests/test_storage_server.py +++ b/src/_zkapauthorizer/tests/test_storage_server.py @@ -1,5 +1,6 @@ from __future__ import ( absolute_import, + division, ) from random import ( @@ -11,6 +12,7 @@ from testtools import ( from testtools.matchers import ( Equals, AfterPreprocessing, + raises, ) from hypothesis import ( given, @@ -34,8 +36,11 @@ from .fixtures import ( ) from ..api import ( ZKAPAuthorizerStorageServer, + MorePassesRequired, +) +from ..storage_common import ( + allocate_buckets_message, ) - def make_passes(signing_key, for_message, random_tokens): blinded_tokens = list( @@ -121,3 +126,45 @@ class PassValidationTests(TestCase): Equals(set(valid_passes)), ), ) + + + def test_allocate_buckets_fails_without_enough_passes(self): + """ + ``remote_allocate_buckets`` fails with ``MorePassesRequired`` if it is + passed fewer passes than it requires for the amount of data to be + stored. + """ + required_passes = 2 + bytes_per_pass = self.storage_server._BYTES_PER_PASS + share_nums = {3, 7} + allocated_size = int((required_passes * bytes_per_pass) / len(share_nums)) + storage_index = b"0123456789" + renew_secret = b"x" * 32 + cancel_secret = b"y" * 32 + valid_passes = make_passes( + self.signing_key, + allocate_buckets_message(storage_index), + list(RandomToken.create() for i in range(required_passes - 1)), + ) + + allocate_buckets = lambda: self.storage_server.doRemoteCall( + "allocate_buckets", + (valid_passes, + storage_index, + renew_secret, + cancel_secret, + share_nums, + allocated_size, + FakeRemoteReference(), + ), + {}, + ) + self.assertThat( + allocate_buckets, + raises(MorePassesRequired), + ) + + +class FakeRemoteReference(object): + def notifyOnDisconnect(self, callback, *args, **kwargs): + pass