From 4bc1ec10b9746f39ac1e7deb162f70da7248fa54 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Tue, 8 Oct 2019 20:46:48 -0400
Subject: [PATCH] Use ``passes`` to authorized ``allocate_buckets``.

If enough valid passes are supplied for the amount of storage requested then
allow the operation to succeed.  Otherwise, fail with an error.
---
 src/_zkapauthorizer/_storage_client.py        |  7 +-
 src/_zkapauthorizer/_storage_server.py        | 37 ++++------
 src/_zkapauthorizer/storage_common.py         | 24 +++++++
 src/_zkapauthorizer/tests/privacypass.py      | 55 +++++++++++++++
 .../tests/test_storage_protocol.py            | 29 ++++----
 .../tests/test_storage_server.py              | 67 +++----------------
 6 files changed, 123 insertions(+), 96 deletions(-)
 create mode 100644 src/_zkapauthorizer/tests/privacypass.py

diff --git a/src/_zkapauthorizer/_storage_client.py b/src/_zkapauthorizer/_storage_client.py
index 364aeac..7eebdc5 100644
--- a/src/_zkapauthorizer/_storage_client.py
+++ b/src/_zkapauthorizer/_storage_client.py
@@ -31,6 +31,8 @@ from allmydata.interfaces import (
 )
 
 from .storage_common import (
+    BYTES_PER_PASS,
+    required_passes,
     allocate_buckets_message,
     add_lease_message,
     renew_lease_message,
@@ -99,7 +101,10 @@ class ZKAPAuthorizerStorageClient(object):
     ):
         return self._rref.callRemote(
             "allocate_buckets",
-            self._get_encoded_passes(allocate_buckets_message(storage_index), 1),
+            self._get_encoded_passes(
+                allocate_buckets_message(storage_index),
+                required_passes(BYTES_PER_PASS, sharenums, allocated_size),
+            ),
             storage_index,
             renew_secret,
             cancel_secret,
diff --git a/src/_zkapauthorizer/_storage_server.py b/src/_zkapauthorizer/_storage_server.py
index ae8297b..6d03290 100644
--- a/src/_zkapauthorizer/_storage_server.py
+++ b/src/_zkapauthorizer/_storage_server.py
@@ -24,10 +24,6 @@ from __future__ import (
     absolute_import,
 )
 
-from math import (
-    ceil,
-)
-
 import attr
 from attr.validators import (
     provides,
@@ -56,6 +52,8 @@ from .foolscap import (
     RITokenAuthorizedStorageServer,
 )
 from .storage_common import (
+    BYTES_PER_PASS,
+    required_passes,
     allocate_buckets_message,
     add_lease_message,
     renew_lease_message,
@@ -67,6 +65,15 @@ class MorePassesRequired(Exception):
         self.valid_count = valid_count
         self.required_count = required_count
 
+    def __repr__(self):
+        return "MorePassedRequired(valid_count={}, required_count={})".format(
+            self.valid_count,
+            self.required_count,
+        )
+
+    def __str__(self):
+        return repr(self)
+
 
 @implementer_only(RITokenAuthorizedStorageServer, IReferenceable, IRemotelyCallable)
 # It would be great to use `frozen=True` (value-based hashing) instead of
@@ -78,25 +85,9 @@ class ZKAPAuthorizerStorageServer(Referenceable):
     A class which wraps an ``RIStorageServer`` to insert pass validity checks
     before allowing certain functionality.
     """
-    # The number of bytes we're willing to store for a lease period for each
-    # pass submitted.
-    _BYTES_PER_PASS = 128 * 1024
-
     _original = attr.ib(validator=provides(RIStorageServer))
     _signing_key = attr.ib(validator=instance_of(SigningKey))
 
-    def _required_passes(self, stored_bytes):
-        """
-        Calculate the number of passes that are required to store ``stored_bytes``
-        for one lease period.
-
-        :param int stored_bytes: A number of bytes of storage for which to
-            calculate a price in passes.
-
-        :return int: The number of passes.
-        """
-        return int(ceil(stored_bytes / self._BYTES_PER_PASS))
-
     def _is_invalid_pass(self, message, pass_):
         """
         Check the validity of a single pass.
@@ -152,11 +143,11 @@ class ZKAPAuthorizerStorageServer(Referenceable):
             allocate_buckets_message(storage_index),
             passes,
         )
-        required_passes = self._required_passes(len(sharenums) * allocated_size)
-        if len(valid_passes) < required_passes:
+        required_pass_count = required_passes(BYTES_PER_PASS, sharenums, allocated_size)
+        if len(valid_passes) < required_pass_count:
             raise MorePassesRequired(
                 len(valid_passes),
-                required_passes,
+                required_pass_count,
             )
 
         return self._original.remote_allocate_buckets(
diff --git a/src/_zkapauthorizer/storage_common.py b/src/_zkapauthorizer/storage_common.py
index c22bbc2..955eb59 100644
--- a/src/_zkapauthorizer/storage_common.py
+++ b/src/_zkapauthorizer/storage_common.py
@@ -3,6 +3,10 @@ from base64 import (
     b64encode,
 )
 
+from math import (
+    ceil,
+)
+
 def _message_maker(label):
     def make_message(storage_index):
         return u"{label} {storage_index}".format(
@@ -15,3 +19,23 @@ allocate_buckets_message = _message_maker(u"allocate_buckets")
 add_lease_message = _message_maker(u"add_lease")
 renew_lease_message = _message_maker(u"renew_lease")
 slot_testv_and_readv_and_writev_message = _message_maker(u"slot_testv_and_readv_and_writev")
+
+# The number of bytes we're willing to store for a lease period for each pass
+# submitted.
+BYTES_PER_PASS = 128 * 1024
+
+def required_passes(bytes_per_pass, share_nums, share_size):
+    """
+    Calculate the number of passes that are required to store ``stored_bytes``
+    for one lease period.
+
+    :param int stored_bytes: A number of bytes of storage for which to
+        calculate a price in passes.
+
+    :return int: The number of passes.
+    """
+    return int(
+        ceil(
+            (len(share_nums) * share_size) / bytes_per_pass,
+        ),
+    )
diff --git a/src/_zkapauthorizer/tests/privacypass.py b/src/_zkapauthorizer/tests/privacypass.py
new file mode 100644
index 0000000..9b46fe3
--- /dev/null
+++ b/src/_zkapauthorizer/tests/privacypass.py
@@ -0,0 +1,55 @@
+from __future__ import (
+    absolute_import,
+)
+
+from privacypass import (
+    BatchDLEQProof,
+    PublicKey,
+)
+
+def make_passes(signing_key, for_message, random_tokens):
+    blinded_tokens = list(
+        token.blind()
+        for token
+        in random_tokens
+    )
+    signatures = list(
+        signing_key.sign(blinded_token)
+        for blinded_token
+        in blinded_tokens
+    )
+    proof = BatchDLEQProof.create(
+        signing_key,
+        blinded_tokens,
+        signatures,
+    )
+    unblinded_signatures = proof.invalid_or_unblind(
+        random_tokens,
+        blinded_tokens,
+        signatures,
+        PublicKey.from_signing_key(signing_key),
+    )
+    preimages = list(
+        unblinded_signature.preimage()
+        for unblinded_signature
+        in unblinded_signatures
+    )
+    verification_keys = list(
+        unblinded_signature.derive_verification_key_sha512()
+        for unblinded_signature
+        in unblinded_signatures
+    )
+    message_signatures = list(
+        verification_key.sign_sha512(for_message.encode("utf-8"))
+        for verification_key
+        in verification_keys
+    )
+    passes = list(
+        u"{} {}".format(
+            preimage.encode_base64().decode("ascii"),
+            signature.encode_base64().decode("ascii"),
+        ).encode("ascii")
+        for (preimage, signature)
+        in zip(preimages, message_signatures)
+    )
+    return passes
diff --git a/src/_zkapauthorizer/tests/test_storage_protocol.py b/src/_zkapauthorizer/tests/test_storage_protocol.py
index a32fa40..98320fd 100644
--- a/src/_zkapauthorizer/tests/test_storage_protocol.py
+++ b/src/_zkapauthorizer/tests/test_storage_protocol.py
@@ -65,9 +65,13 @@ from foolscap.referenceable import (
 )
 
 from privacypass import (
+    RandomToken,
     random_signing_key,
 )
 
+from .privacypass import (
+    make_passes,
+)
 from .strategies import (
     storage_indexes,
     lease_renew_secrets,
@@ -90,15 +94,12 @@ from ..api import (
     ZKAPAuthorizerStorageServer,
     ZKAPAuthorizerStorageClient,
 )
-from ..foolscap import (
-    TOKEN_LENGTH,
+from ..storage_common import (
+    slot_testv_and_readv_and_writev_message,
 )
 from ..model import (
     Pass,
 )
-from ..storage_common import (
-    slot_testv_and_readv_and_writev_message,
-)
 
 @attr.s
 class LocalRemote(object):
@@ -151,15 +152,15 @@ class ShareTests(TestCase):
         self.signing_key = random_signing_key()
 
         def get_passes(message, count):
-            if not isinstance(message, bytes):
-                raise TypeError("message must be bytes")
-            try:
-                message.decode("utf-8")
-            except UnicodeDecodeError:
-                raise TypeError("message must be valid utf-8")
-
-            return [Pass(u"x" * TOKEN_LENGTH)] * count
-
+            return list(
+                Pass(pass_.decode("ascii"))
+                for pass_
+                in make_passes(
+                    self.signing_key,
+                    message,
+                    list(RandomToken.create() for n in range(count)),
+                )
+            )
         self.server = ZKAPAuthorizerStorageServer(
             self.anonymous_storage_server,
             self.signing_key,
diff --git a/src/_zkapauthorizer/tests/test_storage_server.py b/src/_zkapauthorizer/tests/test_storage_server.py
index f2cc0b8..32ebcc4 100644
--- a/src/_zkapauthorizer/tests/test_storage_server.py
+++ b/src/_zkapauthorizer/tests/test_storage_server.py
@@ -22,12 +22,16 @@ from hypothesis.strategies import (
     lists,
 )
 from privacypass import (
-    BatchDLEQProof,
-    PublicKey,
     RandomToken,
     random_signing_key,
 )
+from foolscap.referenceable import (
+    LocalReferenceable,
+)
 
+from .privacypass import (
+    make_passes,
+)
 from .strategies import (
     zkaps,
 )
@@ -39,57 +43,10 @@ from ..api import (
     MorePassesRequired,
 )
 from ..storage_common import (
+    BYTES_PER_PASS,
     allocate_buckets_message,
 )
 
-def make_passes(signing_key, for_message, random_tokens):
-    blinded_tokens = list(
-        token.blind()
-        for token
-        in random_tokens
-    )
-    signatures = list(
-        signing_key.sign(blinded_token)
-        for blinded_token
-        in blinded_tokens
-    )
-    proof = BatchDLEQProof.create(
-        signing_key,
-        blinded_tokens,
-        signatures,
-    )
-    unblinded_signatures = proof.invalid_or_unblind(
-        random_tokens,
-        blinded_tokens,
-        signatures,
-        PublicKey.from_signing_key(signing_key),
-    )
-    preimages = list(
-        unblinded_signature.preimage()
-        for unblinded_signature
-        in unblinded_signatures
-    )
-    verification_keys = list(
-        unblinded_signature.derive_verification_key_sha512()
-        for unblinded_signature
-        in unblinded_signatures
-    )
-    message_signatures = list(
-        verification_key.sign_sha512(for_message.encode("utf-8"))
-        for verification_key
-        in verification_keys
-    )
-    passes = list(
-        u"{} {}".format(
-            preimage.encode_base64().decode("ascii"),
-            signature.encode_base64().decode("ascii"),
-        ).encode("ascii")
-        for (preimage, signature)
-        in zip(preimages, message_signatures)
-    )
-    return passes
-
-
 
 class PassValidationTests(TestCase):
     """
@@ -135,9 +92,8 @@ class PassValidationTests(TestCase):
         stored.
         """
         required_passes = 2
-        bytes_per_pass = self.storage_server._BYTES_PER_PASS
         share_nums = {3, 7}
-        allocated_size = int((required_passes * bytes_per_pass) / len(share_nums))
+        allocated_size = int((required_passes * BYTES_PER_PASS) / len(share_nums))
         storage_index = b"0123456789"
         renew_secret = b"x" * 32
         cancel_secret = b"y" * 32
@@ -155,7 +111,7 @@ class PassValidationTests(TestCase):
              cancel_secret,
              share_nums,
              allocated_size,
-             FakeRemoteReference(),
+             LocalReferenceable(None),
             ),
             {},
         )
@@ -163,8 +119,3 @@ class PassValidationTests(TestCase):
             allocate_buckets,
             raises(MorePassesRequired),
         )
-
-
-class FakeRemoteReference(object):
-    def notifyOnDisconnect(self, callback, *args, **kwargs):
-        pass
-- 
GitLab