From 4b8547726b3afb313705b86f2c32ce56a51a57af Mon Sep 17 00:00:00 2001
From: Tom Prince <tom.prince@private.storage>
Date: Tue, 7 Dec 2021 13:44:20 -0700
Subject: [PATCH] Implement a minimal interface for reporting spent ZKAPs to a
 spending service.

This includes a minimal in-memory implemenation which can be used by tests, and
adds the appropriate calls to the new interface to the storage server
implemenation.
---
 setup.cfg                                     |  1 +
 src/_zkapauthorizer/_plugin.py                | 14 +++-
 src/_zkapauthorizer/_storage_server.py        | 48 +++++++++----
 src/_zkapauthorizer/server/__init__.py        | 13 ++++
 src/_zkapauthorizer/server/spending.py        | 69 +++++++++++++++++++
 src/_zkapauthorizer/tests/matchers.py         | 22 ++++++
 .../tests/test_storage_protocol.py            | 48 ++++++++++++-
 .../tests/test_storage_server.py              | 32 +++++++--
 8 files changed, 224 insertions(+), 23 deletions(-)
 create mode 100644 src/_zkapauthorizer/server/__init__.py
 create mode 100644 src/_zkapauthorizer/server/spending.py

diff --git a/setup.cfg b/setup.cfg
index 04dea5c..c1c6250 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,6 +27,7 @@ package_dir =
 # the plugins package we want to ship.
 packages =
     _zkapauthorizer
+    _zkapauthorizer.server
     _zkapauthorizer.tests
     twisted.plugins
 
diff --git a/src/_zkapauthorizer/_plugin.py b/src/_zkapauthorizer/_plugin.py
index 503d43a..f6105e6 100644
--- a/src/_zkapauthorizer/_plugin.py
+++ b/src/_zkapauthorizer/_plugin.py
@@ -33,7 +33,7 @@ import attr
 from allmydata.client import _Client
 from allmydata.interfaces import IAnnounceableStorageServer, IFoolscapStoragePlugin
 from allmydata.node import MissingConfigEntry
-from challenge_bypass_ristretto import SigningKey
+from challenge_bypass_ristretto import PublicKey, SigningKey
 from eliot import start_action
 from prometheus_client import CollectorRegistry, write_to_textfile
 from twisted.internet import task
@@ -52,6 +52,7 @@ from .lease_maintenance import (
 )
 from .model import VoucherStore
 from .resource import from_configuration as resource_from_configuration
+from .server.spending import get_spender
 from .spending import SpendingController
 from .storage_common import BYTES_PER_PASS, get_configured_pass_value
 
@@ -129,13 +130,22 @@ class ZKAPAuthorizer(object):
                 kwargs.pop(u"ristretto-signing-key-path"),
             ),
         )
+        public_key = PublicKey.from_signing_key(signing_key)
         announcement = {
             u"ristretto-issuer-root-url": root_url,
+            u"ristretto-public-keys": [public_key.encode_base64()],
         }
+        anonymous_storage_server = get_anonymous_storage_server()
+        spender = get_spender(
+            config=kwargs,
+            reactor=reactor,
+            registry=registry,
+        )
         storage_server = ZKAPAuthorizerStorageServer(
-            get_anonymous_storage_server(),
+            anonymous_storage_server,
             pass_value=pass_value,
             signing_key=signing_key,
+            spender=spender,
             registry=registry,
             **kwargs
         )
diff --git a/src/_zkapauthorizer/_storage_server.py b/src/_zkapauthorizer/_storage_server.py
index b847499..468d683 100644
--- a/src/_zkapauthorizer/_storage_server.py
+++ b/src/_zkapauthorizer/_storage_server.py
@@ -38,7 +38,12 @@ from allmydata.storage.server import StorageServer
 from allmydata.storage.shares import get_share_file
 from allmydata.util.base32 import b2a
 from attr.validators import instance_of, provides
-from challenge_bypass_ristretto import SigningKey, TokenPreimage, VerificationSignature
+from challenge_bypass_ristretto import (
+    PublicKey,
+    SigningKey,
+    TokenPreimage,
+    VerificationSignature,
+)
 from eliot import log_call, start_action
 from foolscap.api import Referenceable
 from prometheus_client import CollectorRegistry, Histogram
@@ -50,6 +55,7 @@ from zope.interface import implementer
 
 from .foolscap import RIPrivacyPassAuthorizedStorageServer, ShareStat
 from .model import Pass
+from .server.spending import ISpender
 from .storage_common import (
     MorePassesRequired,
     add_lease_message,
@@ -89,8 +95,7 @@ class _ValidationResult(object):
     """
     The result of validating a list of passes.
 
-    :ivar list[int] valid: A list of indexes (into the validated list) of which
-        are acceptable.
+    :ivar list[bytes] valid: A list of valid token preimages.
 
     :ivar list[int] signature_check_failed: A list of indexes (into the
         validated list) of passes which did not have a correct signature.
@@ -105,19 +110,16 @@ class _ValidationResult(object):
         Cryptographically check the validity of a single pass.
 
         :param unicode message: The shared message for pass validation.
-        :param bytes pass_: The encoded pass to validate.
+        :param Pass pass_: The pass to validate.
 
         :return bool: ``False`` (invalid) if the pass includes a valid
             signature, ``True`` (valid) otherwise.
         """
         assert isinstance(message, unicode), "message %r not unicode" % (message,)
-        assert isinstance(pass_, bytes), "pass %r not bytes" % (pass_,)
+        assert isinstance(pass_, Pass), "pass %r not a Pass" % (pass_,)
         try:
-            parsed_pass = Pass.from_bytes(pass_)
-            preimage = TokenPreimage.decode_base64(parsed_pass.preimage)
-            proposed_signature = VerificationSignature.decode_base64(
-                parsed_pass.signature
-            )
+            preimage = TokenPreimage.decode_base64(pass_.preimage)
+            proposed_signature = VerificationSignature.decode_base64(pass_.signature)
             unblinded_token = signing_key.rederive_unblinded_token(preimage)
             verification_key = unblinded_token.derive_verification_key_sha512()
             invalid_pass = verification_key.invalid_sha512(
@@ -143,10 +145,11 @@ class _ValidationResult(object):
         valid = []
         signature_check_failed = []
         for idx, pass_ in enumerate(passes):
+            pass_ = Pass.from_bytes(pass_)
             if cls._is_invalid_pass(message, pass_, signing_key):
                 signature_check_failed.append(idx)
             else:
-                valid.append(idx)
+                valid.append(pass_.preimage)
         return cls(
             valid=valid,
             signature_check_failed=signature_check_failed,
@@ -194,6 +197,7 @@ class ZKAPAuthorizerStorageServer(Referenceable):
     _original = attr.ib(validator=provides(RIStorageServer))
     _pass_value = pass_value_attribute()
     _signing_key = attr.ib(validator=instance_of(SigningKey))
+    _spender = attr.ib(validator=provides(ISpender))
     _registry = attr.ib(
         default=attr.Factory(CollectorRegistry),
         validator=attr.validators.instance_of(CollectorRegistry),
@@ -202,8 +206,16 @@ class ZKAPAuthorizerStorageServer(Referenceable):
         validator=provides(IReactorTime),
         default=attr.Factory(partial(namedAny, "twisted.internet.reactor")),
     )
+    _public_key = attr.ib(init=False)
     _metric_spending_successes = attr.ib(init=False)
 
+    @_public_key.default
+    def _get_public_key(self):
+        # attrs evaluates defaults (whether specified inline or via decorator)
+        # in the order the attributes were defined in the class definition,
+        # so that `self._signing_key` will be assigned when this runs.
+        return PublicKey.from_signing_key(self._signing_key)
+
     def _get_spending_histogram_buckets(self):
         """
         Create the upper bounds for the ZKAP spending histogram.
@@ -327,7 +339,10 @@ class ZKAPAuthorizerStorageServer(Referenceable):
                 canary,
                 disconnect_marker,
             )
-
+        self._spender.mark_as_spent(
+            self._public_key,
+            validation.valid[:spent_passes],
+        )
         return alreadygot, bucketwriters
 
     def remote_get_buckets(self, storage_index):
@@ -354,6 +369,10 @@ class ZKAPAuthorizerStorageServer(Referenceable):
             self._original,
         )
         result = self._original.remote_add_lease(storage_index, *a, **kw)
+        self._spender.mark_as_spent(
+            self._public_key,
+            validation.valid,
+        )
         self._metric_spending_successes.observe(len(validation.valid))
         return result
 
@@ -484,6 +503,11 @@ class ZKAPAuthorizerStorageServer(Referenceable):
         # somewhat.
         add_leases_for_writev(self._original, storage_index, secrets, tw_vectors, now)
 
+        self._spender.mark_as_spent(
+            self._public_key,
+            validation.valid,
+        )
+
         # The operation has fully succeeded.
         self._metric_spending_successes.observe(required_new_passes)
 
diff --git a/src/_zkapauthorizer/server/__init__.py b/src/_zkapauthorizer/server/__init__.py
new file mode 100644
index 0000000..b00cbfe
--- /dev/null
+++ b/src/_zkapauthorizer/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 PrivateStorage.io, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/_zkapauthorizer/server/spending.py b/src/_zkapauthorizer/server/spending.py
new file mode 100644
index 0000000..f93f156
--- /dev/null
+++ b/src/_zkapauthorizer/server/spending.py
@@ -0,0 +1,69 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+try:
+    from typing import Any
+except ImportError:
+    pass
+
+import attr
+from challenge_bypass_ristretto import PublicKey
+from prometheus_client import CollectorRegistry
+from twisted.internet.interfaces import IReactorTime
+from zope.interface import Interface, implementer
+
+
+class ISpender(Interface):
+    """
+    An ``ISpender`` can records spent ZKAPs and reports double spends.
+    """
+
+    def mark_as_spent(public_key, passes):
+        # type: (PublicKey, list[bytes]) -> None
+        """
+        Record the given ZKAPs (associated to the given public key as having
+        been spent.
+
+        This does *not* report errors and should only be used in cases when
+        recording spending that has already happened. This can be because
+        we could not contact the spending service when they were spent, or
+        because we can't yet check before making changes to the node.
+        """
+
+
+@attr.s
+class _SpendingData(object):
+    spent_tokens = attr.ib(init=False, factory=dict)
+
+    def reset(self):
+        self.spent_tokens.clear()
+
+
+@implementer(ISpender)
+@attr.s
+class RecordingSpender(object):
+    """
+    An in-memory :py:`ISpender` implementation that exposes the spent tokens
+    for testing purposes.
+    """
+
+    _recorder = attr.ib(validator=attr.validators.instance_of(_SpendingData))
+
+    @classmethod
+    def make(cls):
+        # type: () -> (_SpendingData, ISpender)
+        recorder = _SpendingData()
+        return recorder, cls(recorder)
+
+    def mark_as_spent(self, public_key, passes):
+        self._recorder.spent_tokens.setdefault(public_key.encode_base64(), []).extend(
+            passes
+        )
+
+
+def get_spender(config, reactor, registry):
+    # type: (dict[str, Any], IReactorTime, CollectorRegistry) -> ISpender
+    """
+    Return an :py:`ISpender` to be used with the given storage server configuration.
+    """
+    recorder, spender = RecordingSpender.make()
+    return spender
diff --git a/src/_zkapauthorizer/tests/matchers.py b/src/_zkapauthorizer/tests/matchers.py
index bf5ab30..3a3399c 100644
--- a/src/_zkapauthorizer/tests/matchers.py
+++ b/src/_zkapauthorizer/tests/matchers.py
@@ -39,12 +39,16 @@ from testtools.matchers import (
     Matcher,
     MatchesAll,
     MatchesAny,
+    MatchesDict,
+    MatchesSetwise,
     MatchesStructure,
     Mismatch,
 )
 from testtools.twistedsupport import succeeded
 from treq import content
 
+from ..model import Pass
+from ..server.spending import _SpendingData
 from ._exception import raises
 
 
@@ -206,3 +210,21 @@ def matches_response(
             succeeded(body_matcher),
         ),
     )
+
+
+def matches_spent_passes(public_key_hash, spent_passes):
+    # type: (bytes, list[Pass]) -> Matcher[_SpendingData]
+    """
+    Returns a matcher for _SpendingData that checks whether the
+    spent pass match the given public key and passes.
+    """
+    return AfterPreprocessing(
+        lambda spending_recorder: spending_recorder.spent_tokens,
+        MatchesDict(
+            {
+                public_key_hash: MatchesSetwise(
+                    *[Equals(pass_.preimage) for pass_ in spent_passes]
+                )
+            }
+        ),
+    )
diff --git a/src/_zkapauthorizer/tests/test_storage_protocol.py b/src/_zkapauthorizer/tests/test_storage_protocol.py
index bc11261..a5b6429 100644
--- a/src/_zkapauthorizer/tests/test_storage_protocol.py
+++ b/src/_zkapauthorizer/tests/test_storage_protocol.py
@@ -20,7 +20,7 @@ from __future__ import absolute_import
 
 from allmydata.storage.common import storage_index_to_dir
 from allmydata.storage.shares import get_share_file
-from challenge_bypass_ristretto import random_signing_key
+from challenge_bypass_ristretto import PublicKey, random_signing_key
 from fixtures import MonkeyPatch
 from foolscap.referenceable import LocalReferenceable
 from hypothesis import assume, given
@@ -52,6 +52,7 @@ from ..api import (
     ZKAPAuthorizerStorageServer,
 )
 from ..foolscap import ShareStat
+from ..server.spending import RecordingSpender
 from ..storage_common import (
     allocate_buckets_message,
     get_implied_data_length,
@@ -60,7 +61,7 @@ from ..storage_common import (
 from .common import skipIf
 from .fixtures import AnonymousStorageServer
 from .foolscap import LocalRemote
-from .matchers import matches_version_dictionary
+from .matchers import matches_spent_passes, matches_version_dictionary
 from .storage_common import (
     LEASE_INTERVAL,
     cleanup_storage_server,
@@ -146,6 +147,9 @@ class ShareTests(TestCase):
         super(ShareTests, self).setUp()
         self.canary = LocalReferenceable(None)
         self.signing_key = random_signing_key()
+        self.public_key_hash = PublicKey.from_signing_key(
+            self.signing_key
+        ).encode_base64()
         self.pass_factory = pass_factory(
             get_passes=privacypass_passes(self.signing_key)
         )
@@ -155,10 +159,12 @@ class ShareTests(TestCase):
             AnonymousStorageServer(self.clock),
         ).storage_server
 
+        self.spending_recorder, spender = RecordingSpender.make()
         self.server = ZKAPAuthorizerStorageServer(
             self.anonymous_storage_server,
             self.pass_value,
             self.signing_key,
+            spender,
             clock=self.clock,
         )
         self.local_remote_server = LocalRemote(self.server)
@@ -182,6 +188,9 @@ class ShareTests(TestCase):
         # Reset the state of any passes in our pass factory.
         self.pass_factory._clear()
 
+        # Reset any record of spent tokens.
+        self.spending_recorder.reset()
+
         # And clean out any shares that might confuse things.
         cleanup_storage_server(self.anonymous_storage_server)
 
@@ -313,6 +322,10 @@ class ShareTests(TestCase):
             Equals(sharenums),
             u"fresh server refused to allocate all requested buckets",
         )
+        self.expectThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, self.pass_factory.spent),
+        )
 
         for sharenum, bucket in allocated.items():
             bucket.remote_write(0, bytes_for_share(sharenum, size))
@@ -448,6 +461,12 @@ class ShareTests(TestCase):
             ),
         )
 
+        # The spent passes have been reported to the spending service.
+        self.assertThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, self.pass_factory.spent),
+        )
+
         expected_leases = {}
         # Chop off the non-integer part of the expected values because share
         # files only keep integer precision.
@@ -503,6 +522,13 @@ class ShareTests(TestCase):
             ),
             succeeded(Always()),
         )
+
+        # The spent passes have been reported to the spending service.
+        self.assertThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, self.pass_factory.spent),
+        )
+
         leases = list(self.anonymous_storage_server.get_leases(storage_index))
         self.assertThat(leases, HasLength(2))
 
@@ -768,6 +794,12 @@ class ShareTests(TestCase):
             u"Server rejected a write to a new mutable slot",
         )
 
+        # The spent passes have been reported to the spending service.
+        self.assertThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, self.pass_factory.spent),
+        )
+
         expected = [
             {
                 sharenum: ShareStat(
@@ -903,6 +935,12 @@ class ShareTests(TestCase):
             Equals(after_passes),
         )
 
+        # The spent passes have been reported to the spending service.
+        self.assertThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, self.pass_factory.spent),
+        )
+
         # And the lease we paid for on every share is present.
         self.assertThat(
             dict(
@@ -1029,6 +1067,12 @@ class ShareTests(TestCase):
         finally:
             patch.cleanUp()
 
+        # The spent passes have been reported to the spending service.
+        self.assertThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, self.pass_factory.spent),
+        )
+
         # Not only should the write above succeed but the lease should now be
         # marked as expiring one additional lease period into the future.
         self.assertThat(
diff --git a/src/_zkapauthorizer/tests/test_storage_server.py b/src/_zkapauthorizer/tests/test_storage_server.py
index ba28632..aba0ad0 100644
--- a/src/_zkapauthorizer/tests/test_storage_server.py
+++ b/src/_zkapauthorizer/tests/test_storage_server.py
@@ -22,7 +22,7 @@ from random import shuffle
 from time import time
 
 from allmydata.storage.mutable import MutableShareFile
-from challenge_bypass_ristretto import random_signing_key
+from challenge_bypass_ristretto import PublicKey, random_signing_key
 from foolscap.referenceable import LocalReferenceable
 from hypothesis import given, note
 from hypothesis.strategies import integers, just, lists, one_of, tuples
@@ -33,6 +33,7 @@ from twisted.python.runtime import platform
 
 from .._storage_server import _ValidationResult
 from ..api import MorePassesRequired, ZKAPAuthorizerStorageServer
+from ..server.spending import RecordingSpender
 from ..storage_common import (
     add_lease_message,
     allocate_buckets_message,
@@ -44,7 +45,7 @@ from ..storage_common import (
 )
 from .common import skipIf
 from .fixtures import AnonymousStorageServer
-from .matchers import raises
+from .matchers import matches_spent_passes, raises
 from .storage_common import cleanup_storage_server, get_passes, write_toy_shares
 from .strategies import (
     lease_cancel_secrets,
@@ -97,11 +98,9 @@ class ValidationResultTests(TestCase):
             ),
             Equals(
                 _ValidationResult(
-                    valid=list(
-                        idx
-                        for (idx, pass_) in enumerate(all_passes)
-                        if pass_ in valid_passes
-                    ),
+                    valid=[
+                        pass_.preimage for pass_ in all_passes if pass_ in valid_passes
+                    ],
                     signature_check_failed=list(
                         idx
                         for (idx, pass_) in enumerate(all_passes)
@@ -191,6 +190,7 @@ class PassValidationTests(TestCase):
     def setUp(self):
         super(PassValidationTests, self).setUp()
         self.clock = Clock()
+        self.spending_recorder, spender = RecordingSpender.make()
         # anonymous_storage_server uses time.time() so get our Clock close to
         # the same time so we can do lease expiration calculations more
         # easily.
@@ -199,10 +199,14 @@ class PassValidationTests(TestCase):
             AnonymousStorageServer(self.clock),
         ).storage_server
         self.signing_key = random_signing_key()
+        self.public_key_hash = PublicKey.from_signing_key(
+            self.signing_key
+        ).encode_base64()
         self.storage_server = ZKAPAuthorizerStorageServer(
             self.anonymous_storage_server,
             self.pass_value,
             self.signing_key,
+            spender,
             clock=self.clock,
         )
 
@@ -220,6 +224,7 @@ class PassValidationTests(TestCase):
         # way that allows us to just move everything from `setUp` into this
         # method.
         cleanup_storage_server(self.anonymous_storage_server)
+        self.spending_recorder.reset()
 
         # Reset all of the metrics, too, so the individual tests have a
         # simpler job (can compare values relative to 0).
@@ -256,6 +261,7 @@ class PassValidationTests(TestCase):
             ),
             {},
         )
+        self.expectThat(self.spending_recorder.spent_tokens, Equals({}))
         self.assertThat(
             allocate_buckets,
             raises(MorePassesRequired),
@@ -295,6 +301,7 @@ class PassValidationTests(TestCase):
         try:
             result = mutable_write()
         except MorePassesRequired as e:
+            self.expectThat(self.spending_recorder.spent_tokens, Equals({}))
             self.assertThat(
                 e,
                 Equals(
@@ -361,6 +368,11 @@ class PassValidationTests(TestCase):
             "Server denied initial write.",
         )
 
+        self.assertThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, valid_passes),
+        )
+
         # Pick any share to make larger.
         sharenum = next(iter(tw_vectors))
         _, data_vector, new_length = tw_vectors[sharenum]
@@ -400,6 +412,11 @@ class PassValidationTests(TestCase):
         else:
             self.fail("expected MorePassesRequired, got {}".format(result))
 
+        self.assertThat(
+            self.spending_recorder,
+            matches_spent_passes(self.public_key_hash, valid_passes),
+        )
+
     @given(
         storage_index=storage_indexes(),
         secrets=tuples(
@@ -507,6 +524,7 @@ class PassValidationTests(TestCase):
             )
         else:
             self.fail("Expected MorePassesRequired, got {}".format(result))
+        self.assertThat(self.spending_recorder.spent_tokens, Equals({}))
 
     @given(
         slot=storage_indexes(),
-- 
GitLab