diff --git a/src/_zkapauthorizer/_plugin.py b/src/_zkapauthorizer/_plugin.py index 4e19a3973db20994f044f5cf6b199a87479a6723..8f9b629077e8610095c6938849ac3c457a40a624 100644 --- a/src/_zkapauthorizer/_plugin.py +++ b/src/_zkapauthorizer/_plugin.py @@ -64,10 +64,6 @@ from .api import ( ZKAPAuthorizerStorageClient, ) -from .eliot import ( - GET_PASSES, -) - from .model import ( VoucherStore, ) @@ -82,6 +78,10 @@ from .storage_common import ( from .controller import ( get_redeemer, ) +from .spending import ( + SpendingController, +) + from .lease_maintenance import ( SERVICE_NAME, lease_maintenance_service, @@ -173,20 +173,15 @@ class ZKAPAuthorizer(object): """ from twisted.internet import reactor redeemer = self._get_redeemer(node_config, announcement, reactor) - extract_unblinded_tokens = self._get_store(node_config).extract_unblinded_tokens - def get_passes(message, count): - unblinded_tokens = extract_unblinded_tokens(count) - passes = redeemer.tokens_to_passes(message, unblinded_tokens) - GET_PASSES.log( - message=message, - count=count, - ) - return passes - + store = self._get_store(node_config) + controller = SpendingController.for_store( + tokens_to_passes=redeemer.tokens_to_passes, + store=store, + ) return ZKAPAuthorizerStorageClient( get_configured_pass_value(node_config), get_rref, - get_passes, + controller.get, ) diff --git a/src/_zkapauthorizer/_storage_client.py b/src/_zkapauthorizer/_storage_client.py index 2c5a7af7393eb6677b86e99222f5d5131d3e94f8..04a465ddbd7607c9b8c274773311da87186bdf5a 100644 --- a/src/_zkapauthorizer/_storage_client.py +++ b/src/_zkapauthorizer/_storage_client.py @@ -36,13 +36,12 @@ from zope.interface import ( ) from eliot.twisted import ( - DeferredContext, + inline_callbacks, ) from twisted.internet.defer import ( inlineCallbacks, returnValue, - maybeDeferred, ) from allmydata.interfaces import ( IStorageServer, @@ -65,6 +64,7 @@ from .storage_common import ( get_required_new_passes_for_mutable_write, ) + class IncorrectStorageServerReference(Exception): """ A Foolscap remote object which should reference a ZKAPAuthorizer storage @@ -84,55 +84,101 @@ class IncorrectStorageServerReference(Exception): ) +def invalidate_rejected_passes(passes, more_passes_required): + """ + Return a new ``IPassGroup`` with all rejected passes removed from it. + + :param IPassGroup passes: A group of passes, some of which may have been + rejected. + + :param MorePassesRequired more_passes_required: An exception possibly + detailing the rejection of some passes from the group. + + :return: ``None`` if no passes in the group were rejected and so there is + nothing to replace. Otherwise, a new ``IPassGroup`` created from + ``passes`` but with rejected passes replaced with new ones. + """ + num_failed = len(more_passes_required.signature_check_failed) + if num_failed == 0: + # If no signature checks failed then the call just didn't supply + # enough passes. The exception tells us how many passes we should + # spend so we could try again with that number of passes but for + # now we'll just let the exception propagate. The client should + # always figure out the number of passes right on the first try so + # this case is somewhat suspicious. Err on the side of lack of + # service instead of burning extra passes. + # + # We *could* just `raise` here and only be called from an `except` + # suite... but let's not be so vulgar. + return None + SIGNATURE_CHECK_FAILED.log(count=num_failed) + rejected_passes, okay_passes = passes.split(more_passes_required.signature_check_failed) + rejected_passes.mark_invalid(u"signature check failed") + + # It would be great to just expand okay_passes right here. However, if + # that fails (eg because we don't have enough tokens remaining) then the + # caller will have a hard time figuring out which okay passes remain that + # it needs to reset. :/ So, instead, pass back the complete okay set. The + # caller can figure out by how much to expand it by considering its size + # and the original number of passes it requested. + return okay_passes + + +@inline_callbacks def call_with_passes(method, num_passes, get_passes): """ Call a method, passing the requested number of passes as the first argument, and try again if the call fails with an error related to some of the passes being rejected. - :param method: A callable which accepts a list of encoded passes as its - only argument and returns a ``Deferred``. If the ``Deferred`` fires - with ``MorePassesRequired`` then the invalid passes will be discarded - and replacement passes will be requested for a new call of ``method``. - This will repeat until no passes remain, the method succeeds, or the - methods fails in a different way. + :param (IPassGroup -> Deferred) method: An operation to call with some passes. + If the returned ``Deferred`` fires with ``MorePassesRequired`` then + the invalid passes will be discarded and replacement passes will be + requested for a new call of ``method``. This will repeat until no + passes remain, the method succeeds, or the methods fails in a + different way. :param int num_passes: The number of passes to pass to the call. - :param (unicode -> int -> [bytes]) get_passes: A function for getting + :param (int -> IPassGroup) get_passes: A function for getting passes. - :return: Whatever ``method`` returns. + :return: A ``Deferred`` that fires with whatever the ``Deferred`` returned + by ``method`` fires with (apart from ``MorePassesRequired`` failures + that trigger a retry). """ - def get_more_passes(reason): - reason.trap(MorePassesRequired) - num_failed = len(reason.value.signature_check_failed) - if num_failed == 0: - # If no signature checks failed then the call just didn't supply - # enough passes. The exception tells us how many passes we should - # spend so we could try again with that number of passes but for - # now we'll just let the exception propagate. The client should - # always figure out the number of passes right on the first try so - # this case is somewhat suspicious. Err on the side of lack of - # service instead of burning extra passes. - return reason - SIGNATURE_CHECK_FAILED.log(count=num_failed) - new_passes = get_passes(num_failed) - for idx, new_pass in zip(reason.value.signature_check_failed, new_passes): - passes[idx] = new_pass - return go(passes) - - def go(passes): - # Capture the Eliot context for the errback. - d = DeferredContext(maybeDeferred(method, passes)) - d.addErrback(get_more_passes) - # Return the underlying Deferred without finishing the action. - return d.result - - with CALL_WITH_PASSES(count=num_passes).context(): - passes = get_passes(num_passes) - # Finish the Eliot action when this is done. - return DeferredContext(go(passes)).addActionFinish() + with CALL_WITH_PASSES(count=num_passes): + pass_group = get_passes(num_passes) + try: + # Try and repeat as necessary. + while True: + try: + result = yield method(pass_group) + except MorePassesRequired as e: + okay_pass_group = invalidate_rejected_passes( + pass_group, + e, + ) + if okay_pass_group is None: + raise + else: + # Update the local in case we end up going to the + # except suite below. + pass_group = okay_pass_group + # Add the necessary number of new passes. This might + # fail if we don't have enough tokens. + pass_group = pass_group.expand(num_passes - len(pass_group.passes)) + else: + # Commit the spend of the passes when the operation finally succeeds. + pass_group.mark_spent() + break + except: + # Something went wrong that we can't address with a retry. + pass_group.reset() + raise + + # Give the operation's result to the caller. + returnValue(result) def with_rref(f): @@ -149,6 +195,19 @@ def with_rref(f): return g +def _encode_passes(group): + """ + :param IPassGroup group: A group of passes to encode. + + :return list[bytes]: The encoded form of the passes in the given group. + """ + return list( + t.pass_text.encode("ascii") + for t + in group.passes + ) + + @implementer(IStorageServer) @attr.s class ZKAPAuthorizerStorageClient(object): @@ -168,11 +227,11 @@ class ZKAPAuthorizerStorageClient(object): valid ``RemoteReference`` corresponding to the server-side object for this scheme. - :ivar _get_passes: A two-argument callable which retrieves some passes - which can be used to authorize an operation. The first argument is a - bytes (valid utf-8) message binding the passes to the request for - which they will be used. The second is an integer giving the number - of passes to request. + :ivar (bytes -> int -> IPassGroup) _get_passes: A callable to use to + retrieve passes which can be used to authorize an operation. The + first argument is utf-8 encoded message binding the passes to the + request for which they will be used. The second gives the number of + passes to request. """ _expected_remote_interface_name = ( "RIPrivacyPassAuthorizedStorageServer.tahoe.privatestorage.io" @@ -200,20 +259,6 @@ class ZKAPAuthorizerStorageClient(object): ) return rref - def _get_encoded_passes(self, message, count): - """ - :param unicode message: The message to which to bind the passes. - - :return: A list of passes from ``_get_passes`` encoded into their - ``bytes`` representation. - """ - assert isinstance(message, unicode) - return list( - t.pass_text.encode("ascii") - for t - in self._get_passes(message.encode("utf-8"), count) - ) - @with_rref def get_version(self, rref): return rref.callRemote( @@ -231,12 +276,11 @@ class ZKAPAuthorizerStorageClient(object): allocated_size, canary, ): - message = allocate_buckets_message(storage_index) num_passes = required_passes(self._pass_value, [allocated_size] * len(sharenums)) return call_with_passes( lambda passes: rref.callRemote( "allocate_buckets", - passes, + _encode_passes(passes), storage_index, renew_secret, cancel_secret, @@ -245,7 +289,7 @@ class ZKAPAuthorizerStorageClient(object): canary, ), num_passes, - partial(self._get_encoded_passes, message), + partial(self._get_passes, allocate_buckets_message(storage_index).encode("utf-8")), ) @with_rref @@ -278,13 +322,13 @@ class ZKAPAuthorizerStorageClient(object): result = yield call_with_passes( lambda passes: rref.callRemote( "add_lease", - passes, + _encode_passes(passes), storage_index, renew_secret, cancel_secret, ), num_passes, - partial(self._get_encoded_passes, add_lease_message(storage_index)), + partial(self._get_passes, add_lease_message(storage_index).encode("utf-8")), ) returnValue(result) @@ -306,12 +350,12 @@ class ZKAPAuthorizerStorageClient(object): result = yield call_with_passes( lambda passes: rref.callRemote( "renew_lease", - passes, + _encode_passes(passes), storage_index, renew_secret, ), num_passes, - partial(self._get_encoded_passes, renew_lease_message(storage_index)), + partial(self._get_passes, renew_lease_message(storage_index).encode("utf-8")), ) returnValue(result) @@ -377,7 +421,7 @@ class ZKAPAuthorizerStorageClient(object): result = yield call_with_passes( lambda passes: rref.callRemote( "slot_testv_and_readv_and_writev", - passes, + _encode_passes(passes), storage_index, secrets, tw_vectors, @@ -385,8 +429,8 @@ class ZKAPAuthorizerStorageClient(object): ), num_passes, partial( - self._get_encoded_passes, - slot_testv_and_readv_and_writev_message(storage_index), + self._get_passes, + slot_testv_and_readv_and_writev_message(storage_index).encode("utf-8"), ), ) returnValue(result) diff --git a/src/_zkapauthorizer/eliot.py b/src/_zkapauthorizer/eliot.py index da3960d07a9f2ef09fb112fed4dd33bae61bc545..a2e99d8b89e2c885ab83f2a54c83b1a07333f376 100644 --- a/src/_zkapauthorizer/eliot.py +++ b/src/_zkapauthorizer/eliot.py @@ -32,6 +32,12 @@ PRIVACYPASS_MESSAGE = Field( u"The PrivacyPass request-binding data associated with a pass.", ) +INVALID_REASON = Field( + u"reason", + unicode, + u"The reason given by the server for rejecting a pass as invalid.", +) + PASS_COUNT = Field( u"count", int, @@ -41,7 +47,25 @@ PASS_COUNT = Field( GET_PASSES = MessageType( u"zkapauthorizer:get-passes", [PRIVACYPASS_MESSAGE, PASS_COUNT], - u"Passes are being spent.", + u"An attempt to spend passes is beginning.", +) + +SPENT_PASSES = MessageType( + u"zkapauthorizer:spent-passes", + [PASS_COUNT], + u"An attempt to spend passes has succeeded.", +) + +INVALID_PASSES = MessageType( + u"zkapauthorizer:invalid-passes", + [INVALID_REASON, PASS_COUNT], + u"An attempt to spend passes has found some to be invalid.", +) + +RESET_PASSES = MessageType( + u"zkapauthorizer:reset-passes", + [PRIVACYPASS_MESSAGE, PASS_COUNT], + u"Some passes involved in a failed spending attempt have not definitely been spent and are being returned for future use.", ) SIGNATURE_CHECK_FAILED = MessageType( diff --git a/src/_zkapauthorizer/model.py b/src/_zkapauthorizer/model.py index b7d590bec3e26bf7ac8f9c288ed92fad88079e62..12b0393b1eddc8869b55e759f4ebbfb99e03773a 100644 --- a/src/_zkapauthorizer/model.py +++ b/src/_zkapauthorizer/model.py @@ -144,14 +144,60 @@ def open_and_initialize(path, connect=None): actual_version = get_schema_version(cursor) schema_upgrades = list(get_schema_upgrades(actual_version)) run_schema_upgrades(schema_upgrades, cursor) + + # Create some tables that only exist (along with their contents) for + # this connection. These are outside of the schema because they are not + # persistent. We can change them any time we like without worrying about + # upgrade logic because we re-create them on every connection. + conn.execute( + """ + -- Track tokens in use by the process holding this connection. + CREATE TEMPORARY TABLE [in-use] ( + [unblinded-token] text, -- The base64 encoded unblinded token. + + PRIMARY KEY([unblinded-token]) + -- A foreign key on unblinded-token to [unblinded-tokens]([token]) + -- would be alright - however SQLite3 foreign key constraints + -- can't cross databases (and temporary tables are considered to + -- be in a different database than normal tables). + ) + """, + ) + conn.execute( + """ + -- Track tokens that we want to remove from the database. Mainly just + -- works around the awkward DB-API interface for dealing with deleting + -- many rows. + CREATE TEMPORARY TABLE [to-discard] ( + [unblinded-token] text + ) + """, + ) + conn.execute( + """ + -- Track tokens that we want to remove from the [in-use] set. Similar + -- to [to-discard]. + CREATE TEMPORARY TABLE [to-reset] ( + [unblinded-token] text + ) + """, + ) return conn def with_cursor(f): + """ + Decorate a function so it is automatically passed a cursor with an active + transaction as the first positional argument. If the function returns + normally then the transaction will be committed. Otherwise, the + transaction will be rolled back. + """ @wraps(f) def with_cursor(self, *a, **kw): with self._connection: - return f(self, self._connection.cursor(), *a, **kw) + cursor = self._connection.cursor() + cursor.execute("BEGIN IMMEDIATE TRANSACTION") + return f(self, cursor, *a, **kw) return with_cursor @@ -162,6 +208,11 @@ def memory_connect(path, *a, **kw): return _connect(":memory:", *a, **kw) +# The largest integer SQLite3 can represent in an integer column. Larger than +# this an the representation loses precision as a floating point. +_SQLITE3_INTEGER_MAX = 2 ** 63 - 1 + + @attr.s(frozen=True) class VoucherStore(object): """ @@ -260,10 +311,9 @@ class VoucherStore(object): if not isinstance(now, datetime): raise TypeError("{} returned {}, expected datetime".format(self.now, now)) - cursor.execute("BEGIN IMMEDIATE TRANSACTION") cursor.execute( """ - SELECT ([text]) + SELECT [text] FROM [tokens] WHERE [voucher] = ? AND [counter] = ? """, @@ -306,7 +356,6 @@ class VoucherStore(object): in tokens ), ) - cursor.connection.commit() return tokens @with_cursor @@ -446,56 +495,147 @@ class VoucherStore(object): ), ) - @with_cursor - def extract_unblinded_tokens(self, cursor, count): + def get_unblinded_tokens(self, cursor, count): """ - Remove and return some unblinded tokens. + Get some unblinded tokens. - :param int count: The maximum number of unblinded tokens to remove and - return. If fewer than this are available, only as many as are - available are returned. + These tokens are not removed from the store but they will not be + returned from a future call to ``get_unblinded_tokens`` *on this + ``VoucherStore`` instance* unless ``reset_unblinded_tokens`` is used + to reset their state. + + If the underlying storage is access via another ``VoucherStore`` + instance then the behavior of this method will be as if all tokens + which have not had their state changed to invalid or spent have been + reset. :return list[UnblindedTokens]: The removed unblinded tokens. """ + if count > _SQLITE3_INTEGER_MAX: + # An unreasonable number of tokens and also large enough to + # provoke undesirable behavior from the database. + raise NotEnoughTokens() + cursor.execute( """ - SELECT COUNT(token) + SELECT [token] FROM [unblinded-tokens] + WHERE [token] NOT IN [in-use] + LIMIT ? """, + (count,), ) - [(existing_tokens,)] = cursor.fetchall() - if existing_tokens < count: + texts = cursor.fetchall() + if len(texts) < count: raise NotEnoughTokens() + cursor.executemany( + """ + INSERT INTO [in-use] VALUES (?) + """, + texts, + ) + return list( + UnblindedToken(t) + for (t,) + in texts + ) + + @with_cursor + def discard_unblinded_tokens(self, cursor, unblinded_tokens): + """ + Get rid of some unblinded tokens. The tokens will be completely removed + from the system. This is useful when the tokens have been + successfully spent. + + :param list[UnblindedToken] unblinded_tokens: The tokens to discard. + + :return: ``None`` + """ + cursor.executemany( + """ + INSERT INTO [to-discard] VALUES (?) + """, + list((token.unblinded_token,) for token in unblinded_tokens), + ) cursor.execute( """ - CREATE TEMPORARY TABLE [extracting] - AS - SELECT [token] FROM [unblinded-tokens] LIMIT ? + DELETE FROM [in-use] + WHERE [unblinded-token] IN [to-discard] """, - (count,), ) cursor.execute( """ - DELETE FROM [unblinded-tokens] WHERE [token] IN [extracting] + DELETE FROM [unblinded-tokens] + WHERE [token] IN [to-discard] """, ) cursor.execute( """ - SELECT [token] FROM [extracting] + DELETE FROM [to-discard] """, ) - texts = cursor.fetchall() + + @with_cursor + def invalidate_unblinded_tokens(self, cursor, reason, unblinded_tokens): + """ + Mark some unblinded tokens as invalid and unusable. Some record of the + tokens may be retained for future inspection. These tokens will not + be returned by any future ``get_unblinded_tokens`` call. This is + useful when an attempt to spend a token has met with rejection by the + validator. + + :param list[UnblindedToken] unblinded_tokens: The tokens to mark. + + :return: ``None`` + """ + cursor.executemany( + """ + INSERT INTO [invalid-unblinded-tokens] VALUES (?, ?) + """, + list( + (token.unblinded_token, reason) + for token + in unblinded_tokens + ), + ) cursor.execute( """ - DROP TABLE [extracting] + DELETE FROM [in-use] + WHERE [unblinded-token] IN (SELECT [token] FROM [invalid-unblinded-tokens]) """, ) - return list( - UnblindedToken(t) - for (t,) - in texts + cursor.execute( + """ + DELETE FROM [unblinded-tokens] + WHERE [token] IN (SELECT [token] FROM [invalid-unblinded-tokens]) + """, + ) + + @with_cursor + def reset_unblinded_tokens(self, cursor, unblinded_tokens): + """ + Make some unblinded tokens available to be retrieved from the store again. + This is useful if a spending operation has failed with a transient + error. + """ + cursor.executemany( + """ + INSERT INTO [to-reset] VALUES (?) + """, + list((token.unblinded_token,) for token in unblinded_tokens), + ) + cursor.execute( + """ + DELETE FROM [in-use] + WHERE [unblinded-token] IN [to-reset] + """, + ) + cursor.execute( + """ + DELETE FROM [to-reset] + """, ) @with_cursor diff --git a/src/_zkapauthorizer/schema.py b/src/_zkapauthorizer/schema.py index a23d3373c9a230d874710183e046d9e9cef954e6..5044153e08c31a211d8bcaf35dbb2efffea46626 100644 --- a/src/_zkapauthorizer/schema.py +++ b/src/_zkapauthorizer/schema.py @@ -156,4 +156,15 @@ _UPGRADES = { ALTER TABLE [vouchers] ADD COLUMN [expected-tokens] integer NOT NULL DEFAULT 32768 """, ], + + 4: [ + """ + CREATE TABLE [invalid-unblinded-tokens] ( + [token] text, -- The base64 encoded unblinded token. + [reason] text, -- The reason given for it being considered invalid. + + PRIMARY KEY([token]) + ) + """, + ], } diff --git a/src/_zkapauthorizer/spending.py b/src/_zkapauthorizer/spending.py new file mode 100644 index 0000000000000000000000000000000000000000..20f0f775f17622f868e19790a35f7725b53f4759 --- /dev/null +++ b/src/_zkapauthorizer/spending.py @@ -0,0 +1,217 @@ +# Copyright 2019 PrivateStorage.io, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A module for logic controlling the manner in which ZKAPs are spent. +""" + +from zope.interface import ( + Interface, + Attribute, + implementer, +) + +import attr + +from .eliot import ( + GET_PASSES, + SPENT_PASSES, + INVALID_PASSES, + RESET_PASSES, +) + +class IPassGroup(Interface): + """ + A group of passed meant to be spent together. + """ + passes = Attribute(":ivar list[Pass] passes: The passes themselves.") + + def split(select_indices): + """ + Create two new ``IPassGroup`` providers. The first contains all passes in + this group at the given indices. The second contains all the others. + + :param list[int] select_indices: The indices of the passes to include + in the first resulting group. + + :return (IPassGroup, IPassGroup): The two new groups. + """ + + def expand(by_amount): + """ + Create a new ``IPassGroup`` provider which contains all of this group's + passes and some more. + + :param int by_amount: The number of additional passes the resulting + group should contain. + + :return IPassGroup: The new group. + """ + + def mark_spent(): + """ + The passes have been spent successfully. Ensure none of them appear in + any ``IPassGroup`` provider created in the future. + + :return: ``None`` + """ + + def mark_invalid(reason): + """ + The passes could not be spent. Ensure none of them appear in any + ``IPassGroup`` provider created in the future. + + :param unicode reason: A short description of the reason the passes + could not be spent. + + :return: ``None`` + """ + + def reset(): + """ + The passes have not been spent. Return them to for use in a future + ``IPassGroup`` provider. + + :return: ``None`` + """ + + +class IPassFactory(Interface): + """ + An object which can create passes. + """ + def get(message, num_passes): + """ + :param unicode message: A request-binding message for the resulting passes. + + :param int num_passes: The number of passes to request. + + :return IPassGroup: A group of passes bound to the given message and + of the requested size. + """ + + +@implementer(IPassGroup) +@attr.s +class PassGroup(object): + """ + Track the state of a group of passes intended as payment for an operation. + + :ivar unicode _message: The request binding message for this group of + passes. + + :ivar IPassFactory _factory: The factory which created this pass group. + + :ivar list[Pass] passes: The passes of which this group consists. + """ + _message = attr.ib() + _factory = attr.ib() + _tokens = attr.ib() + + @property + def passes(self): + return list( + pass_ + for (unblinded_token, pass_) + in self._tokens + ) + + @property + def unblinded_tokens(self): + return list( + unblinded_token + for (unblinded_token, pass_) + in self._tokens + ) + + def split(self, select_indices): + selected = [] + unselected = [] + for idx, t in enumerate(self._tokens): + if idx in select_indices: + selected.append(t) + else: + unselected.append(t) + return ( + attr.evolve(self, tokens=selected), + attr.evolve(self, tokens=unselected), + ) + + def expand(self, by_amount): + return attr.evolve( + self, + tokens=self._tokens + self._factory.get(self._message, by_amount)._tokens, + ) + + def mark_spent(self): + self._factory._mark_spent(self.unblinded_tokens) + + def mark_invalid(self, reason): + self._factory._mark_invalid(reason, self.unblinded_tokens) + + def reset(self): + self._factory._reset(self.unblinded_tokens) + + +@implementer(IPassFactory) +@attr.s +class SpendingController(object): + """ + A ``SpendingController`` gives out ZKAPs and arranges for re-spend + attempts when necessary. + """ + get_unblinded_tokens = attr.ib() + discard_unblinded_tokens = attr.ib() + invalidate_unblinded_tokens = attr.ib() + reset_unblinded_tokens = attr.ib() + + tokens_to_passes = attr.ib() + + @classmethod + def for_store(cls, tokens_to_passes, store): + return cls( + get_unblinded_tokens=store.get_unblinded_tokens, + discard_unblinded_tokens=store.discard_unblinded_tokens, + invalidate_unblinded_tokens=store.invalidate_unblinded_tokens, + reset_unblinded_tokens=store.reset_unblinded_tokens, + tokens_to_passes=tokens_to_passes, + ) + + def get(self, message, num_passes): + unblinded_tokens = self.get_unblinded_tokens(num_passes) + passes = self.tokens_to_passes(message, unblinded_tokens) + GET_PASSES.log( + message=message, + count=num_passes, + ) + return PassGroup(message, self, zip(unblinded_tokens, passes)) + + def _mark_spent(self, unblinded_tokens): + SPENT_PASSES.log( + count=len(unblinded_tokens), + ) + self.discard_unblinded_tokens(unblinded_tokens) + + def _mark_invalid(self, reason, unblinded_tokens): + INVALID_PASSES.log( + reason=reason, + count=len(unblinded_tokens), + ) + self.invalidate_unblinded_tokens(reason, unblinded_tokens) + + def _reset(self, unblinded_tokens): + RESET_PASSES.log( + count=len(unblinded_tokens), + ) + self.reset_unblinded_tokens(unblinded_tokens) diff --git a/src/_zkapauthorizer/tests/__init__.py b/src/_zkapauthorizer/tests/__init__.py index 0f9529a87a1b9837c7b34798450815918ee5a9f5..102647a022c45553eb27ea9da5dfd0e433a11941 100644 --- a/src/_zkapauthorizer/tests/__init__.py +++ b/src/_zkapauthorizer/tests/__init__.py @@ -57,6 +57,12 @@ def _configure_hypothesis(): settings.register_profile( "big", max_examples=10000, + # The only rule-based state machine we have now is quite simple and + # can probably be completely explored in about 5 steps. Give it some + # headroom beyond that in case I'm wrong but don't let it run to the + # full 50 because, combined with searching for 10000 successful + # examples this makes the stateful test take *ages* to complete. + stateful_step_count=15, **base ) diff --git a/src/_zkapauthorizer/tests/fixtures.py b/src/_zkapauthorizer/tests/fixtures.py index eb64887b6798e2d2b164e289bf095b5777276f66..00be5b25283194c4a9454d6fa5314a6695b60650 100644 --- a/src/_zkapauthorizer/tests/fixtures.py +++ b/src/_zkapauthorizer/tests/fixtures.py @@ -37,8 +37,12 @@ from allmydata.storage.server import ( from ..model import ( VoucherStore, + open_and_initialize, memory_connect, ) +from ..controller import ( + PaymentController, +) class AnonymousStorageServer(Fixture): """ @@ -82,3 +86,45 @@ class TemporaryVoucherStore(Fixture): self.get_now, memory_connect, ) + + +@attr.s +class ConfiglessMemoryVoucherStore(Fixture): + """ + Create a ``VoucherStore`` backed by an in-memory database and with no + associated Tahoe-LAFS configuration or node. + + This is like ``TemporaryVoucherStore`` but faster because it skips the + Tahoe-LAFS parts. + """ + redeemer = attr.ib() + get_now = attr.ib() + + def _setUp(self): + here = FilePath(u".") + self.store = VoucherStore( + pass_value=2 ** 15, + database_path=here, + now=self.get_now, + connection=open_and_initialize(here, memory_connect), + ) + + def redeem(self, voucher, num_passes): + """ + Redeem a voucher for some passes. + + :return: A ``Deferred`` that fires with the redemption result. + """ + return PaymentController( + self.store, + self.redeemer, + # Have to pass it here or to redeem, doesn't matter which. + default_token_count=num_passes, + # No value in splitting it into smaller groups in this case. + # Doing so only complicates the test by imposing a different + # minimum token count requirement (can't have fewer tokens + # than groups). + num_redemption_groups=1, + ).redeem( + voucher, + ) diff --git a/src/_zkapauthorizer/tests/matchers.py b/src/_zkapauthorizer/tests/matchers.py index 6c7ab457c04c6971965779bb4445517decf9e933..5ea2613373b6b2b10bb91113031f45ad8fbfd42c 100644 --- a/src/_zkapauthorizer/tests/matchers.py +++ b/src/_zkapauthorizer/tests/matchers.py @@ -54,7 +54,7 @@ class Provides(object): """ Match objects that provide all of a list of Zope Interface interfaces. """ - interfaces = attr.ib() + interfaces = attr.ib(validator=attr.validators.instance_of(list)) def match(self, obj): missing = set() @@ -154,3 +154,23 @@ def leases_current(relevant_storage_indexes, now, min_lease_remaining): ), ), ) + + +def even(): + """ + Matches even integers. + """ + return AfterPreprocessing( + lambda n: n % 2, + Equals(0), + ) + + +def odd(): + """ + Matches odd integers. + """ + return AfterPreprocessing( + lambda n: n % 2, + Equals(1), + ) diff --git a/src/_zkapauthorizer/tests/storage_common.py b/src/_zkapauthorizer/tests/storage_common.py index d00a580c29adf51f1d39583012fbe09b11555678..ddc58f3fc10b1afca250f0dbaf2377f88279509e 100644 --- a/src/_zkapauthorizer/tests/storage_common.py +++ b/src/_zkapauthorizer/tests/storage_common.py @@ -16,6 +16,10 @@ ``allmydata.storage``-related helpers shared across the test suite. """ +from functools import ( + partial, +) + from os import ( SEEK_CUR, ) @@ -23,15 +27,43 @@ from struct import ( pack, ) +from itertools import ( + islice, +) + +import attr + +from zope.interface import ( + implementer, +) + from twisted.python.filepath import ( FilePath, ) +from challenge_bypass_ristretto import ( + RandomToken, +) + from .strategies import ( # Not really a strategy... bytes_for_share, ) +from .privacypass import ( + make_passes, +) + +from ..model import ( + NotEnoughTokens, + Pass, +) + +from ..spending import ( + IPassFactory, + PassGroup, +) + # Hard-coded in Tahoe-LAFS LEASE_INTERVAL = 60 * 60 * 24 * 31 @@ -133,3 +165,129 @@ def whitebox_write_sparse_share(sharepath, version, size, leases, now): in leases ), ) + + +def integer_passes(limit): + """ + :return: A function which can be used to get a number of passes. The + function accepts a unicode request-binding message and an integer + number of passes. It returns a list of integers which serve as + passes. Successive calls to the function return unique pass values. + """ + counter = iter(range(limit)) + def get_passes(message, num_passes): + result = list(islice(counter, num_passes)) + if len(result) < num_passes: + raise NotEnoughTokens() + return result + return get_passes + + +def get_passes(message, count, signing_key): + """ + :param unicode message: Request-binding message for PrivacyPass. + + :param int count: The number of passes to get. + + :param SigningKey signing_key: The key to use to sign the passes. + + :return list[Pass]: ``count`` new random passes signed with the given key + and bound to the given message. + """ + return list( + Pass(*pass_.split(u" ")) + for pass_ + in make_passes( + signing_key, + message, + list(RandomToken.create() for n in range(count)), + ) + ) + + +def privacypass_passes(signing_key): + """ + Get a PrivacyPass issuing function. + + :param SigningKey signing_key: The key to use to issue passes. + + :return: Return a function which can be used to get a number of passes. + The function accepts a unicode request-binding message and an integer + number of passes. It returns a list of real pass values signed by the + given key. Successive calls to the function return unique passes. + """ + return partial(get_passes, signing_key=signing_key) + + +def pass_factory(get_passes): + """ + Get a new factory for passes. + + :param (unicode -> int -> [pass]) get_passes: A function the factory can + use to get new passes. + """ + return _PassFactory(get_passes=get_passes) + + +@implementer(IPassFactory) +@attr.s +class _PassFactory(object): + """ + A stateful pass issuer. + + :ivar (unicode -> int -> [bytes]) _get_passes: A function for getting + passes. + + :ivar set[int] in_use: All of the passes given out without a confirmed + terminal state. + + :ivar dict[int, unicode] invalid: All of the passes given out and returned + using ``IPassGroup.invalid`` mapped to the reason given. + + :ivar set[int] spent: All of the passes given out and returned via + ``IPassGroup.mark_spent``. + + :ivar set[int] issued: All of the passes ever given out. + + :ivar list[int] returned: A list of passes which were given out but then + returned via ``IPassGroup.reset``. + """ + _get_passes = attr.ib() + + returned = attr.ib(default=attr.Factory(list), init=False) + in_use = attr.ib(default=attr.Factory(set), init=False) + invalid = attr.ib(default=attr.Factory(dict), init=False) + spent = attr.ib(default=attr.Factory(set), init=False) + issued = attr.ib(default=attr.Factory(set), init=False) + + def get(self, message, num_passes): + passes = [] + if self.returned: + passes.extend(self.returned[:num_passes]) + del self.returned[:num_passes] + num_passes -= len(passes) + passes.extend(self._get_passes(message, num_passes)) + self.issued.update(passes) + self.in_use.update(passes) + return PassGroup(message, self, zip(passes, passes)) + + def _mark_spent(self, passes): + for p in passes: + if p not in self.in_use: + raise ValueError("Pass {} cannot be spent, it is not in use.".format(p)) + self.spent.update(passes) + self.in_use.difference_update(passes) + + def _mark_invalid(self, reason, passes): + for p in passes: + if p not in self.in_use: + raise ValueError("Pass {} cannot be invalid, it is not in use.".format(p)) + self.invalid.update(dict.fromkeys(passes, reason)) + self.in_use.difference_update(passes) + + def _reset(self, passes): + for p in passes: + if p not in self.in_use: + raise ValueError("Pass {} cannot be reset, it is not in use.".format(p)) + self.returned.extend(passes) + self.in_use.difference_update(passes) diff --git a/src/_zkapauthorizer/tests/strategies.py b/src/_zkapauthorizer/tests/strategies.py index 28028fd87ad78348725343f9ac19bf710c6eb040..0c448cda3cce269ab18715c4de2fa560837b80d8 100644 --- a/src/_zkapauthorizer/tests/strategies.py +++ b/src/_zkapauthorizer/tests/strategies.py @@ -813,3 +813,12 @@ def node_hierarchies(): ).filter( storage_indexes_are_distinct, ) + + +def pass_counts(): + """ + Build integers usable as a number of passes to work on. There is always + at least one pass in a group and there are never "too many", whatever that + means. + """ + return integers(min_value=1, max_value=2 ** 8) diff --git a/src/_zkapauthorizer/tests/test_client_resource.py b/src/_zkapauthorizer/tests/test_client_resource.py index 7aabbdb359b3ad9a70644cfd806520884d84fdb3..9ff7ffb7f1e246ff8b5093a64661b2af291d5f8a 100644 --- a/src/_zkapauthorizer/tests/test_client_resource.py +++ b/src/_zkapauthorizer/tests/test_client_resource.py @@ -523,7 +523,9 @@ class UnblindedTokenTests(TestCase): return d def use_a_token(): - root.store.extract_unblinded_tokens(1) + root.store.discard_unblinded_tokens( + root.store.get_unblinded_tokens(1), + ) tempdir = self.useFixture(TempDir()) config = get_config(tempdir.join(b"tahoe"), b"tub.port") diff --git a/src/_zkapauthorizer/tests/test_model.py b/src/_zkapauthorizer/tests/test_model.py index e13856f349176a47cd1f4347cc1f471d38a66945..46a794e7f99d09f151ebd876242ce831a1ebb11c 100644 --- a/src/_zkapauthorizer/tests/test_model.py +++ b/src/_zkapauthorizer/tests/test_model.py @@ -28,6 +28,7 @@ from errno import ( EACCES, ) from datetime import ( + datetime, timedelta, ) @@ -39,6 +40,8 @@ from testtools import ( TestCase, ) from testtools.matchers import ( + Always, + HasLength, AfterPreprocessing, MatchesStructure, MatchesAll, @@ -46,15 +49,26 @@ from testtools.matchers import ( Raises, IsInstance, ) +from testtools.twistedsupport import ( + succeeded, +) from fixtures import ( TempDir, ) from hypothesis import ( + note, given, + assume, +) +from hypothesis.stateful import ( + RuleBasedStateMachine, + rule, + precondition, + invariant, + run_state_machine_as_test ) - from hypothesis.strategies import ( data, booleans, @@ -63,6 +77,7 @@ from hypothesis.strategies import ( datetimes, timedeltas, integers, + randoms, ) from twisted.python.runtime import ( @@ -80,7 +95,9 @@ from ..model import ( LeaseMaintenanceActivity, memory_connect, ) - +from ..controller import ( + DummyRedeemer, +) from .strategies import ( tahoe_configs, vouchers, @@ -90,9 +107,11 @@ from .strategies import ( unblinded_tokens, posix_safe_datetimes, dummy_ristretto_keys, + pass_counts, ) from .fixtures import ( TemporaryVoucherStore, + ConfiglessMemoryVoucherStore, ) from .matchers import ( raises, @@ -314,7 +333,7 @@ class VoucherStoreTests(TestCase): def test_spend_order_equals_backup_order(self, get_config, voucher_value, public_key, now, data): """ Unblinded tokens returned by ``VoucherStore.backup`` appear in the same - order as they are returned ``VoucherStore.extract_unblinded_tokens``. + order as they are returned by ``VoucherStore.get_unblinded_tokens``. """ backed_up_tokens, spent_tokens, inserted_tokens = self._spend_order_test( get_config, @@ -332,7 +351,7 @@ class VoucherStoreTests(TestCase): @given(tahoe_configs(), vouchers(), dummy_ristretto_keys(), datetimes(), data()) def test_spend_order_equals_insert_order(self, get_config, voucher_value, public_key, now, data): """ - Unblinded tokens returned by ``VoucherStore.extract_unblinded_tokens`` + Unblinded tokens returned by ``VoucherStore.get_unblinded_tokens`` appear in the same order as they were inserted. """ backed_up_tokens, spent_tokens, inserted_tokens = self._spend_order_test( @@ -386,7 +405,7 @@ class VoucherStoreTests(TestCase): extracted_tokens.extend( token.unblinded_token for token - in store.extract_unblinded_tokens(to_spend) + in store.get_unblinded_tokens(to_spend) ) tokens_remaining -= to_spend @@ -397,6 +416,185 @@ class VoucherStoreTests(TestCase): ) +class UnblindedTokenStateMachine(RuleBasedStateMachine): + """ + Transition rules for a state machine corresponding to the state of + unblinded tokens in a ``VoucherStore`` - usable, in-use, spent, invalid, + etc. + """ + def __init__(self, case): + super(UnblindedTokenStateMachine, self).__init__() + self.case = case + self.redeemer = DummyRedeemer() + self.configless = ConfiglessMemoryVoucherStore( + self.redeemer, + # Time probably not actually relevant to this state machine. + datetime.now, + ) + self.configless.setUp() + + self.available = 0 + self.using = [] + self.spent = [] + self.invalid = [] + + def teardown(self): + self.configless.cleanUp() + + @rule(voucher=vouchers(), num_passes=pass_counts()) + def redeem_voucher(self, voucher, num_passes): + """ + A voucher can be redeemed, adding more unblinded tokens to the store. + """ + try: + self.configless.store.get(voucher) + except KeyError: + pass + else: + # Cannot redeem a voucher more than once. We redeemed this one + # already. + assume(False) + + self.case.assertThat( + self.configless.redeem(voucher, num_passes), + succeeded(Always()), + ) + self.available += num_passes + + @rule(num_passes=pass_counts()) + def get_passes(self, num_passes): + """ + Some passes can be requested from the store. The resulting passes are not + spent, invalid, or already in-use. + """ + assume(num_passes <= self.available) + tokens = self.configless.store.get_unblinded_tokens(num_passes) + note("get_passes: {}".format(tokens)) + + # No tokens we are currently using may be returned again. Nor may + # tokens which have reached a terminal state of spent or invalid. + unavailable = set(self.using) | set(self.spent) | set(self.invalid) + + self.case.assertThat( + tokens, + MatchesAll( + HasLength(num_passes), + AfterPreprocessing( + lambda t: set(t) & unavailable, + Equals(set()), + ), + ), + ) + self.using.extend(tokens) + self.available -= num_passes + + @rule(excess_passes=pass_counts()) + def not_enough_passes(self, excess_passes): + """ + If an attempt is made to get more passes than are available, + ``get_unblinded_tokens`` raises ``NotEnoughTokens``. + """ + self.case.assertThat( + lambda: self.configless.store.get_unblinded_tokens( + self.available + excess_passes, + ), + raises(NotEnoughTokens), + ) + + @precondition(lambda self: len(self.using) > 0) + @rule(random=randoms(), data=data()) + def spend_passes(self, random, data): + """ + Some in-use passes can be discarded. + """ + self.using, to_spend = random_slice(self.using, random, data) + note("spend_passes: {}".format(to_spend)) + self.configless.store.discard_unblinded_tokens(to_spend) + + @precondition(lambda self: len(self.using) > 0) + @rule(random=randoms(), data=data()) + def reset_passes(self, random, data): + """ + Some in-use passes can be returned to not-in-use state. + """ + self.using, to_reset = random_slice(self.using, random, data) + note("reset_passes: {}".format(to_reset)) + self.configless.store.reset_unblinded_tokens(to_reset) + self.available += len(to_reset) + + @precondition(lambda self: len(self.using) > 0) + @rule(random=randoms(), data=data()) + def invalidate_passes(self, random, data): + """ + Some in-use passes are unusable and should be set aside. + """ + self.using, to_invalidate = random_slice(self.using, random, data) + note("invalidate_passes: {}".format(to_invalidate)) + self.configless.store.invalidate_unblinded_tokens( + u"reason", + to_invalidate, + ) + self.invalid.extend(to_invalidate) + + @rule() + def discard_ephemeral_state(self): + """ + Reset all state that cannot outlive a single process, simulating a + restart. + + XXX We have to reach into the guts of ``VoucherStore`` to do this + because we're using an in-memory database. We can't just open a new + ``VoucherStore``. :/ Perhaps we should use an on-disk database... Or + maybe this is a good argument for using an explicitly attached + temporary database instead of the built-in ``temp`` database. + """ + with self.configless.store._connection: + self.configless.store._connection.execute( + """ + DELETE FROM [in-use] + """, + ) + self.available += len(self.using) + del self.using[:] + + @invariant() + def report_state(self): + note("available={} using={} invalid={} spent={}".format( + self.available, + len(self.using), + len(self.invalid), + len(self.spent), + )) + + +def random_slice(taken_from, random, data): + """ + Divide ``taken_from`` into two pieces with elements randomly assigned to + one piece or the other. + + :param list taken_from: A list of elements to divide. This will be + mutated. + + :param random: A ``random`` module-alike. + + :param data: A Hypothesis data object for drawing values. + + :return: A two-tuple of the two resulting lists. + """ + count = data.draw(integers(min_value=1, max_value=len(taken_from))) + random.shuffle(taken_from) + remaining = taken_from[:-count] + sliced = taken_from[-count:] + return remaining, sliced + + +class UnblindedTokenStateTests(TestCase): + """ + Glue ``UnblindedTokenStateTests`` into our test runner. + """ + def test_states(self): + run_state_machine_as_test(lambda: UnblindedTokenStateMachine(self)) + class LeaseMaintenanceTests(TestCase): """ @@ -552,19 +750,13 @@ class UnblindedTokenStoreTests(TestCase): store = self.useFixture(TemporaryVoucherStore(get_config, lambda: now)).store store.add(voucher_value, len(random_tokens), 0, lambda: random_tokens) store.insert_unblinded_tokens_for_voucher(voucher_value, public_key, unblinded_tokens, completed) - retrieved_tokens = store.extract_unblinded_tokens(len(random_tokens)) + retrieved_tokens = store.get_unblinded_tokens(len(random_tokens)) self.expectThat( set(unblinded_tokens), Equals(set(retrieved_tokens)), ) - # After extraction, the unblinded tokens are no longer available. - self.assertThat( - lambda: store.extract_unblinded_tokens(1), - raises(NotEnoughTokens), - ) - @given( tahoe_configs(), datetimes(), @@ -692,44 +884,30 @@ class UnblindedTokenStoreTests(TestCase): vouchers(), dummy_ristretto_keys(), booleans(), - integers(min_value=1, max_value=100), integers(min_value=1), data(), ) - def test_not_enough_unblinded_tokens(self, get_config, now, voucher_value, public_key, completed, num_tokens, extra, data): + def test_not_enough_unblinded_tokens(self, get_config, now, voucher_value, public_key, completed, extra, data): """ - ``extract_unblinded_tokens`` raises ``NotEnoughTokens`` if ``count`` is + ``get_unblinded_tokens`` raises ``NotEnoughTokens`` if ``count`` is greater than the number of unblinded tokens in the store. """ - random = data.draw( - lists( - random_tokens(), - min_size=num_tokens, - max_size=num_tokens, - unique=True, - ), - ) - unblinded = data.draw( - lists( - unblinded_tokens(), - min_size=num_tokens, - max_size=num_tokens, - unique=True, - ), - ) + random, unblinded = paired_tokens(data) + num_tokens = len(random) store = self.useFixture(TemporaryVoucherStore(get_config, lambda: now)).store store.add(voucher_value, len(random), 0, lambda: random) - store.insert_unblinded_tokens_for_voucher(voucher_value, public_key, unblinded, completed) - + store.insert_unblinded_tokens_for_voucher( + voucher_value, + public_key, + unblinded, + completed, + ) self.assertThat( - lambda: store.extract_unblinded_tokens(num_tokens + extra), + lambda: store.get_unblinded_tokens(num_tokens + extra), raises(NotEnoughTokens), ) - # TODO: Other error states and transient states - - def store_for_test(testcase, get_config, get_now): """ Create a ``VoucherStore`` in a temporary directory associated with the diff --git a/src/_zkapauthorizer/tests/test_plugin.py b/src/_zkapauthorizer/tests/test_plugin.py index ebd714863a3ab95a590826698001ba4cac469965..ce04c94e3bc9fbae49b38e1f7c5d0d3f1e1adaa2 100644 --- a/src/_zkapauthorizer/tests/test_plugin.py +++ b/src/_zkapauthorizer/tests/test_plugin.py @@ -104,7 +104,7 @@ from twisted.plugins.zkapauthorizer import ( storage_server, ) -from .._plugin import ( +from ..spending import ( GET_PASSES, ) @@ -415,7 +415,7 @@ class ClientPluginTests(TestCase): size=sizes(), ) @capture_logging(lambda self, logger: logger.validate()) - def test_unblinded_tokens_extracted( + def test_unblinded_tokens_spent( self, logger, get_config, @@ -430,7 +430,7 @@ class ClientPluginTests(TestCase): ): """ The ``ZKAPAuthorizerStorageServer`` returned by ``get_storage_client`` - extracts unblinded tokens from the plugin database. + spends unblinded tokens from the plugin database. """ tempdir = self.useFixture(TempDir()) node_config = get_config( @@ -476,7 +476,7 @@ class ClientPluginTests(TestCase): # There should be no unblinded tokens left to extract. self.assertThat( - lambda: store.extract_unblinded_tokens(1), + lambda: store.get_unblinded_tokens(1), raises(NotEnoughTokens), ) diff --git a/src/_zkapauthorizer/tests/test_spending.py b/src/_zkapauthorizer/tests/test_spending.py new file mode 100644 index 0000000000000000000000000000000000000000..e55f289a3936a709566101f2effe35fecb2855dc --- /dev/null +++ b/src/_zkapauthorizer/tests/test_spending.py @@ -0,0 +1,211 @@ +# Copyright 2019 PrivateStorage.io, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for ``_zkapauthorizer.spending``. +""" + +from testtools import ( + TestCase, +) +from testtools.matchers import ( + Always, + Equals, + MatchesAll, + MatchesStructure, + HasLength, + AfterPreprocessing, +) +from testtools.twistedsupport import ( + succeeded, +) + +from hypothesis import ( + given, +) +from hypothesis.strategies import ( + integers, + randoms, + data, +) + +from .strategies import ( + vouchers, + pass_counts, + posix_safe_datetimes, +) +from .matchers import ( + Provides, +) +from .fixtures import ( + ConfiglessMemoryVoucherStore, +) +from ..controller import ( + DummyRedeemer, +) +from ..spending import ( + IPassGroup, + SpendingController, +) + +class PassGroupTests(TestCase): + """ + Tests for ``IPassGroup`` and the factories that create them. + """ + @given(vouchers(), pass_counts(), posix_safe_datetimes()) + def test_get(self, voucher, num_passes, now): + """ + ``IPassFactory.get`` returns an ``IPassGroup`` provider containing the + requested number of passes. + """ + configless = self.useFixture( + ConfiglessMemoryVoucherStore( + DummyRedeemer(), + lambda: now, + ), + ) + # Make sure there are enough tokens for us to extract! + self.assertThat( + configless.redeem(voucher, num_passes), + succeeded(Always()), + ) + + pass_factory = SpendingController.for_store( + tokens_to_passes=configless.redeemer.tokens_to_passes, + store=configless.store, + ) + + group = pass_factory.get(u"message", num_passes) + self.assertThat( + group, + MatchesAll( + Provides([IPassGroup]), + MatchesStructure( + passes=HasLength(num_passes), + ), + ), + ) + + def _test_token_group_operation( + self, + operation, + matches_tokens, + voucher, + num_passes, + now, + random, + data, + ): + configless = self.useFixture( + ConfiglessMemoryVoucherStore( + DummyRedeemer(), + lambda: now, + ), + ) + # Make sure there are enough tokens for us to use! + self.assertThat( + configless.redeem(voucher, num_passes), + succeeded(Always()), + ) + + # Figure out some subset, maybe empty, of passes from the group that + # we will try to operate on. + group_size = data.draw(integers(min_value=0, max_value=num_passes)) + indices = range(num_passes) + random.shuffle(indices) + spent_indices = indices[:group_size] + + # Get some passes and perform the operation. + pass_factory = SpendingController.for_store( + tokens_to_passes=configless.redeemer.tokens_to_passes, + store=configless.store, + ) + group = pass_factory.get(u"message", num_passes) + spent, rest = group.split(spent_indices) + operation(spent) + + # Verify the expected outcome of the operation using the supplied + # matcher factory. + self.assertThat( + configless.store, + matches_tokens(num_passes, spent), + ) + + @given(vouchers(), pass_counts(), posix_safe_datetimes(), randoms(), data()) + def test_spent(self, voucher, num_passes, now, random, data): + """ + Passes in a group can be marked as successfully spent to prevent them from + being re-used by a future ``get`` call. + """ + def matches_tokens(num_passes, group): + return AfterPreprocessing( + # The use of `backup` here to check is questionable. TODO: + # Straight-up query interface for tokens in different states. + lambda store: store.backup()[u"unblinded-tokens"], + HasLength(num_passes - len(group.passes)), + ) + return self._test_token_group_operation( + lambda group: group.mark_spent(), + matches_tokens, + voucher, + num_passes, + now, + random, + data, + ) + + @given(vouchers(), pass_counts(), posix_safe_datetimes(), randoms(), data()) + def test_invalid(self, voucher, num_passes, now, random, data): + """ + Passes in a group can be marked as invalid to prevent them from being + re-used by a future ``get`` call. + """ + def matches_tokens(num_passes, group): + return AfterPreprocessing( + # The use of `backup` here to check is questionable. TODO: + # Straight-up query interface for tokens in different states. + lambda store: store.backup()[u"unblinded-tokens"], + HasLength(num_passes - len(group.passes)), + ) + return self._test_token_group_operation( + lambda group: group.mark_invalid(u"reason"), + matches_tokens, + voucher, + num_passes, + now, + random, + data, + ) + + @given(vouchers(), pass_counts(), posix_safe_datetimes(), randoms(), data()) + def test_reset(self, voucher, num_passes, now, random, data): + """ + Passes in a group can be reset to allow them to be re-used by a future + ``get`` call. + """ + def matches_tokens(num_passes, group): + return AfterPreprocessing( + # They've been reset so we should be able to re-get them. + lambda store: store.get_unblinded_tokens(len(group.passes)), + Equals(group.unblinded_tokens), + ) + return self._test_token_group_operation( + lambda group: group.reset(), + matches_tokens, + voucher, + num_passes, + now, + random, + data, + ) diff --git a/src/_zkapauthorizer/tests/test_storage_client.py b/src/_zkapauthorizer/tests/test_storage_client.py index 77018c2928f1c7cd991093ed416e12c5ce3a4dff..611bc6127d252a4d28eceda6caa1c3faacfe4886 100644 --- a/src/_zkapauthorizer/tests/test_storage_client.py +++ b/src/_zkapauthorizer/tests/test_storage_client.py @@ -16,11 +16,12 @@ Tests for ``_zkapauthorizer._storage_client``. """ -import attr +from __future__ import ( + division, +) -from itertools import ( - count, - islice, +from functools import ( + partial, ) from testtools import ( @@ -31,6 +32,11 @@ from testtools.matchers import ( Is, Equals, AfterPreprocessing, + MatchesStructure, + HasLength, + MatchesAll, + AllMatch, + IsInstance, ) from testtools.twistedsupport import ( succeeded, @@ -41,7 +47,7 @@ from hypothesis import ( given, ) from hypothesis.strategies import ( - integers, + sampled_from, ) from twisted.internet.defer import ( @@ -49,10 +55,22 @@ from twisted.internet.defer import ( fail, ) +from .matchers import ( + even, + odd, + raises, +) + +from .strategies import ( + pass_counts, +) + from ..api import ( MorePassesRequired, ) - +from ..model import ( + NotEnoughTokens, +) from .._storage_client import ( call_with_passes, ) @@ -60,30 +78,10 @@ from .._storage_server import ( _ValidationResult, ) -def pass_counts(): - return integers(min_value=1, max_value=2 ** 8) - - -def pass_factory(): - return _PassFactory() - -@attr.s -class _PassFactory(object): - """ - A stateful pass issuer. - - :ivar list spent: All of the passes ever issued. - - :ivar _fountain: A counter for making each new pass issued unique. - """ - spent = attr.ib(default=attr.Factory(list)) - - _fountain = attr.ib(default=attr.Factory(count)) - - def get(self, num_passes): - passes = list(islice(self._fountain, num_passes)) - self.spent.extend(passes) - return passes +from .storage_common import ( + pass_factory, + integer_passes, +) class CallWithPassesTests(TestCase): @@ -100,9 +98,9 @@ class CallWithPassesTests(TestCase): result = object() self.assertThat( call_with_passes( - lambda passes: succeed(result), + lambda group: succeed(result), num_passes, - pass_factory().get, + partial(pass_factory(integer_passes(num_passes)).get, u"message"), ), succeeded(Is(result)), ) @@ -117,9 +115,9 @@ class CallWithPassesTests(TestCase): result = Exception() self.assertThat( call_with_passes( - lambda passes: fail(result), + lambda group: fail(result), num_passes, - pass_factory().get, + partial(pass_factory(integer_passes(num_passes)).get, u"message"), ), failed( AfterPreprocessing( @@ -130,27 +128,71 @@ class CallWithPassesTests(TestCase): ) @given(pass_counts()) - def test_passes(self, num_passes): + def test_passes_issued(self, num_passes): """ - ``call_with_passes`` calls the given method with a list of passes - containing ``num_passes`` created by the function passed for + ``call_with_passes`` calls the given method with an ``IPassGroup`` + provider containing ``num_passes`` created by the function passed for ``get_passes``. """ - passes = pass_factory() + passes = pass_factory(integer_passes(num_passes)) self.assertThat( call_with_passes( - lambda passes: succeed(passes), + lambda group: succeed(group.passes), num_passes, - passes.get, + partial(passes.get, u"message"), ), succeeded( Equals( - passes.spent, + sorted(passes.issued), ), ), ) + @given(pass_counts()) + def test_passes_spent_on_success(self, num_passes): + """ + ``call_with_passes`` marks the passes it uses as spent if the operation + succeeds. + """ + passes = pass_factory(integer_passes(num_passes)) + + self.assertThat( + call_with_passes( + lambda group: None, + num_passes, + partial(passes.get, u"message"), + ), + succeeded(Always()), + ) + self.assertThat( + passes.issued, + Equals(passes.spent), + ) + + @given(pass_counts()) + def test_passes_returned_on_failure(self, num_passes): + """ + ``call_with_passes`` returns the passes it uses if the operation fails. + """ + passes = pass_factory(integer_passes(num_passes)) + + self.assertThat( + call_with_passes( + lambda group: fail(Exception("Anything")), + num_passes, + partial(passes.get, u"message"), + ), + failed(Always()), + ) + self.assertThat( + passes, + MatchesStructure( + issued=Equals(set(passes.returned)), + spent=Equals(set()), + ), + ) + @given(pass_counts()) def test_retry_on_rejected_passes(self, num_passes): """ @@ -158,9 +200,12 @@ class CallWithPassesTests(TestCase): of passes, still of length ```num_passes``, but without the passes which were rejected on the first try. """ - passes = pass_factory() + # Half of the passes are going to be rejected so make twice as many as + # the operation uses available. + passes = pass_factory(integer_passes(num_passes * 2)) - def reject_even_pass_values(passes): + def reject_even_pass_values(group): + passes = group.passes good_passes = list(idx for (idx, p) in enumerate(passes) if p % 2) bad_passes = list(idx for (idx, p) in enumerate(passes) if idx not in good_passes) if len(good_passes) < num_passes: @@ -174,10 +219,26 @@ class CallWithPassesTests(TestCase): call_with_passes( reject_even_pass_values, num_passes, - passes.get, + partial(passes.get, u"message"), ), succeeded(Always()), ) + self.assertThat( + passes, + MatchesStructure( + returned=HasLength(0), + in_use=HasLength(0), + invalid=MatchesAll( + HasLength(num_passes), + AllMatch(even()), + ), + spent=MatchesAll( + HasLength(num_passes), + AllMatch(odd()), + ), + issued=Equals(passes.spent | set(passes.invalid.keys())), + ), + ) @given(pass_counts()) def test_pass_through_too_few_passes(self, num_passes): @@ -186,9 +247,10 @@ class CallWithPassesTests(TestCase): no passes have been marked as invalid. This happens if all passes given were valid but too fewer were given. """ - passes = pass_factory() + passes = pass_factory(integer_passes(num_passes)) - def reject_passes(passes): + def reject_passes(group): + passes = group.passes _ValidationResult( valid=range(len(passes)), signature_check_failed=[], @@ -198,7 +260,7 @@ class CallWithPassesTests(TestCase): call_with_passes( reject_passes, num_passes, - passes.get, + partial(passes.get, u"message"), ), failed( AfterPreprocessing( @@ -213,3 +275,172 @@ class CallWithPassesTests(TestCase): ), ), ) + + # The passes in the group that was rejected are also returned for + # later use. + self.assertThat( + passes, + MatchesStructure( + spent=HasLength(0), + returned=HasLength(num_passes), + ), + ) + + @given(pass_counts(), pass_counts()) + def test_not_enough_tokens_for_retry(self, num_passes, extras): + """ + When there are not enough tokens to successfully complete a retry with the + required number of passes, ``call_with_passes`` marks all passes + reported as invalid during its efforts as such and resets all other + passes it acquired. + """ + passes = pass_factory(integer_passes(num_passes + extras)) + rejected = [] + accepted = [] + + def reject_half_passes(group): + num = len(group.passes) + # Floor division will always short-change valid here, even for a + # group size of 1. Therefore there will always be some passes + # marked as invalid. + accept_indexes = range(num // 2) + reject_indexes = range(num // 2, num) + # Only keep this iteration's accepted passes. We'll want to see + # that the final iteration's passes are all returned. Passes from + # earlier iterations don't matter. + accepted[:] = list(group.passes[i] for i in accept_indexes) + # On the other hand, keep *all* rejected passes. They should all + # be marked as invalid and we want to make sure that's the case, + # no matter which iteration rejected them. + rejected.extend(group.passes[i] for i in reject_indexes) + _ValidationResult( + valid=accept_indexes, + signature_check_failed=reject_indexes, + ).raise_for(num) + + self.assertThat( + call_with_passes( + # Since half of every group is rejected, we'll eventually run + # out of passes no matter how many we start with. + reject_half_passes, + num_passes, + partial(passes.get, u"message"), + ), + failed( + AfterPreprocessing( + lambda f: f.value, + IsInstance(NotEnoughTokens), + ), + ), + ) + self.assertThat( + passes, + MatchesStructure( + # Whatever is left in the group when we run out of tokens must + # be returned. + returned=Equals(accepted), + in_use=HasLength(0), + invalid=AfterPreprocessing( + lambda invalid: invalid.keys(), + Equals(rejected), + ), + spent=HasLength(0), + issued=Equals(set(accepted + rejected)), + ), + ) + +def reset(group): + group.reset() + +def spend(group): + group.mark_spent() + +def invalidate(group): + group.mark_invalid(u"reason") + + +class PassFactoryTests(TestCase): + """ + Tests for ``pass_factory``. + + It is unfortunate that this isn't the same test suite as + ``test_spending.PassGroupTests``. + """ + @given(pass_counts(), pass_counts()) + def test_returned_passes_reused(self, num_passes_a, num_passes_b): + """ + ``IPassGroup.reset`` makes passes available to be returned by + ``IPassGroup.get`` again. + """ + message = u"message" + min_passes = min(num_passes_a, num_passes_b) + max_passes = max(num_passes_a, num_passes_b) + + factory = pass_factory(integer_passes(max_passes)) + group_a = factory.get(message, num_passes_a) + group_a.reset() + + group_b = factory.get(message, num_passes_b) + self.assertThat( + group_a.passes[:min_passes], + Equals(group_b.passes[:min_passes]), + ) + + def _test_disallowed_transition(self, num_passes, setup_op, invalid_op): + """ + Assert that after some setup operation completes, another operation raises + ``ValueError``. + + :param int num_passes: The number of passes to make available from the + factory. + + :param (IPassGroup -> None) setup_op: Some initial operation to + perform with the pass group. + + :param (IPassGroup -> None) invalid_op: Some follow-up operation to + perform with the pass group and to assert raises an exception. + """ + message = u"message" + factory = pass_factory(integer_passes(num_passes)) + group = factory.get(message, num_passes) + setup_op(group) + self.assertThat( + lambda: invalid_op(group), + raises(ValueError), + ) + + @given(pass_counts(), sampled_from([reset, spend, invalidate])) + def test_not_spendable(self, num_passes, setup_op): + """ + ``PassGroup.mark_spent`` raises ``ValueError`` if any passes in the group + are in a state other than in-use. + """ + self._test_disallowed_transition( + num_passes, + setup_op, + spend, + ) + + @given(pass_counts(), sampled_from([reset, spend, invalidate])) + def test_not_resetable(self, num_passes, setup_op): + """ + ``PassGroup.reset`` raises ``ValueError`` if any passes in the group are + in a state other than in-use. + """ + self._test_disallowed_transition( + num_passes, + setup_op, + reset, + ) + + @given(pass_counts(), sampled_from([reset, spend, invalidate])) + def test_not_invalidateable(self, num_passes, setup_op): + """ + ``PassGroup.mark_invalid`` raises ``ValueError`` if any passes in the + group are in a state other than in-use. + """ + self._test_disallowed_transition( + num_passes, + setup_op, + invalidate, + ) diff --git a/src/_zkapauthorizer/tests/test_storage_protocol.py b/src/_zkapauthorizer/tests/test_storage_protocol.py index bb79fb25e919110d0454595fe55cb81203dfbac9..b649753880ae5026f1adadf0ca5e5118244df791 100644 --- a/src/_zkapauthorizer/tests/test_storage_protocol.py +++ b/src/_zkapauthorizer/tests/test_storage_protocol.py @@ -27,6 +27,7 @@ from testtools import ( TestCase, ) from testtools.matchers import ( + Always, Equals, HasLength, IsInstance, @@ -67,7 +68,6 @@ from foolscap.referenceable import ( ) from challenge_bypass_ristretto import ( - RandomToken, random_signing_key, ) @@ -79,9 +79,6 @@ from .common import ( skipIf, ) -from .privacypass import ( - make_passes, -) from .strategies import ( storage_indexes, lease_renew_secrets, @@ -107,6 +104,9 @@ from .storage_common import ( cleanup_storage_server, write_toy_shares, whitebox_write_sparse_share, + get_passes, + privacypass_passes, + pass_factory, ) from .foolscap import ( LocalRemote, @@ -122,8 +122,8 @@ from ..storage_common import ( get_implied_data_length, required_passes, ) -from ..model import ( - Pass, +from .._storage_client import ( + _encode_passes, ) from ..foolscap import ( ShareStat, @@ -167,36 +167,12 @@ class RequiredPassesTests(TestCase): ) -def get_passes(message, count, signing_key): - """ - :param unicode message: Request-binding message for PrivacyPass. - - :param int count: The number of passes to get. - - :param SigningKEy signing_key: The key to use to sign the passes. - - :return list[Pass]: ``count`` new random passes signed with the given key - and bound to the given message. - """ - return list( - Pass(*pass_.split(u" ")) - for pass_ - in make_passes( - signing_key, - message, - list(RandomToken.create() for n in range(count)), - ) - ) - - class ShareTests(TestCase): """ Tests for interaction with shares. - :ivar int spent_passes: The number of passes which have been spent so far - in the course of a single test (in the case of Hypothesis, every - iteration of the test so far, probably; so make relative comparisons - instead of absolute ones). + :ivar pass_factory: An object which is responsible for creating passes + which are used by these tests. """ pass_value = 128 * 1024 @@ -205,11 +181,8 @@ class ShareTests(TestCase): self.canary = LocalReferenceable(None) self.anonymous_storage_server = self.useFixture(AnonymousStorageServer()).storage_server self.signing_key = random_signing_key() - self.spent_passes = 0 - def counting_get_passes(message, count): - self.spent_passes += count - return get_passes(message, count, self.signing_key) + self.pass_factory = pass_factory(get_passes=privacypass_passes(self.signing_key)) self.server = ZKAPAuthorizerStorageServer( self.anonymous_storage_server, @@ -220,7 +193,7 @@ class ShareTests(TestCase): self.client = ZKAPAuthorizerStorageClient( self.pass_value, get_rref=lambda: self.local_remote_server, - get_passes=counting_get_passes, + get_passes=self.pass_factory.get, ) def test_get_version(self): @@ -411,12 +384,13 @@ class ShareTests(TestCase): canary=self.canary, ) - extract_result( + self.assertThat( self.client.add_lease( storage_index, renew_lease_secret, cancel_secret, ), + succeeded(Always()), ) leases = list(self.anonymous_storage_server.get_leases(storage_index)) self.assertThat(leases, HasLength(2)) @@ -453,11 +427,12 @@ class ShareTests(TestCase): ) now += 100000 - extract_result( + self.assertThat( self.client.renew_lease( storage_index, renew_secret, ), + succeeded(Always()), ) [lease] = self.anonymous_storage_server.get_leases(storage_index) @@ -495,9 +470,6 @@ class ShareTests(TestCase): finally: patch.cleanUp() - stats = extract_result( - self.client.stat_shares([storage_index]), - ) expected = [{ sharenum: ShareStat( size=size, @@ -505,8 +477,8 @@ class ShareTests(TestCase): ), }] self.assertThat( - stats, - Equals(expected), + self.client.stat_shares([storage_index]), + succeeded(Equals(expected)), ) @given( @@ -722,9 +694,6 @@ class ShareTests(TestCase): u"Server rejected a write to a new mutable slot", ) - stats = extract_result( - self.client.stat_shares([storage_index]), - ) expected = [{ sharenum: ShareStat( size=get_implied_data_length( @@ -737,8 +706,8 @@ class ShareTests(TestCase): in test_and_write_vectors_for_shares.items() }] self.assertThat( - stats, - Equals(expected), + self.client.stat_shares([storage_index]), + succeeded(Equals(expected)), ) @@ -772,13 +741,14 @@ class ShareTests(TestCase): canary=self.canary, ) - extract_result( + self.assertThat( self.client.advise_corrupt_share( b"immutable", storage_index, sharenum, b"the bits look bad", ), + succeeded(Always()), ) self.assertThat( FilePath(self.anonymous_storage_server.corruption_advisory_dir).children(), @@ -826,12 +796,12 @@ class ShareTests(TestCase): u"Server gave back read results when we asked for none.", ) # Now we can read it back without spending any more passes. - before_spent_passes = self.spent_passes + before_passes = len(self.pass_factory.issued) assert_read_back_data(self, storage_index, secrets, test_and_write_vectors_for_shares) - after_spent_passes = self.spent_passes + after_passes = len(self.pass_factory.issued) self.assertThat( - before_spent_passes, - Equals(after_spent_passes), + before_passes, + Equals(after_passes), ) @given( @@ -924,9 +894,11 @@ class ShareTests(TestCase): d = self.local_remote_server.callRemote( "slot_testv_and_readv_and_writev", # passes - self.client._get_encoded_passes( - slot_testv_and_readv_and_writev_message(storage_index), - 1, + _encode_passes( + self.pass_factory.get( + slot_testv_and_readv_and_writev_message(storage_index), + 1, + ), ), # storage_index storage_index,