Skip to content
Snippets Groups Projects
Unverified Commit b1a7d1c9 authored by Jean-Paul Calderone's avatar Jean-Paul Calderone
Browse files

Handle tokens running out while doing retries

parent 05b4220e
No related branches found
No related tags found
1 merge request!155Client-side two-phase spending protocol to reduce pass loss
......@@ -84,10 +84,9 @@ class IncorrectStorageServerReference(Exception):
)
def replace_invalid_passes_with_new_passes(passes, more_passes_required):
def invalidate_rejected_passes(passes, more_passes_required):
"""
Replace all rejected passes in the given pass group with new ones. Mark
any rejected passes as rejected.
Return a new ``IPassGroup`` with all rejected passes removed from it.
:param IPassGroup passes: A group of passes, some of which may have been
rejected.
......@@ -115,7 +114,14 @@ def replace_invalid_passes_with_new_passes(passes, more_passes_required):
SIGNATURE_CHECK_FAILED.log(count=num_failed)
rejected_passes, okay_passes = passes.split(more_passes_required.signature_check_failed)
rejected_passes.mark_invalid(u"signature check failed")
return okay_passes.expand(len(more_passes_required.signature_check_failed))
# It would be great to just expand okay_passes right here. However, if
# that fails (eg because we don't have enough tokens remaining) then the
# caller will have a hard time figuring out which okay passes remain that
# it needs to reset. :/ So, instead, pass back the complete okay set. The
# caller can figure out by how much to expand it by considering its size
# and the original number of passes it requested.
return okay_passes
@inline_callbacks
......@@ -142,28 +148,33 @@ def call_with_passes(method, num_passes, get_passes):
that trigger a retry).
"""
with CALL_WITH_PASSES(count=num_passes):
passes = get_passes(num_passes)
pass_group = get_passes(num_passes)
try:
# Try and repeat as necessary.
while True:
try:
result = yield method(passes)
result = yield method(pass_group)
except MorePassesRequired as e:
updated_passes = replace_invalid_passes_with_new_passes(
passes,
okay_pass_group = invalidate_rejected_passes(
pass_group,
e,
)
if updated_passes is None:
if okay_pass_group is None:
raise
else:
passes = updated_passes
# Update the local in case we end up going to the
# except suite below.
pass_group = okay_pass_group
# Add the necessary number of new passes. This might
# fail if we don't have enough tokens.
pass_group = pass_group.expand(num_passes - len(pass_group.passes))
else:
# Commit the spend of the passes when the operation finally succeeds.
passes.mark_spent()
pass_group.mark_spent()
break
except:
# Something went wrong that we can't address with a retry.
passes.reset()
pass_group.reset()
raise
# Give the operation's result to the caller.
......
......@@ -28,7 +28,6 @@ from struct import (
)
from itertools import (
count,
islice,
)
......@@ -56,6 +55,7 @@ from .privacypass import (
)
from ..model import (
NotEnoughTokens,
Pass,
)
......@@ -167,16 +167,19 @@ def whitebox_write_sparse_share(sharepath, version, size, leases, now):
)
def integer_passes():
def integer_passes(limit):
"""
:return: Return a function which can be used to get a number of passes.
The function accepts a unicode request-binding message and an integer
number of passes. It returns a list of integers which serve as passes.
Successive calls to the function return unique pass values.
"""
counter = count(0)
counter = iter(range(limit))
def get_passes(message, num_passes):
return list(islice(counter, num_passes))
result = list(islice(counter, num_passes))
if len(result) < num_passes:
raise NotEnoughTokens()
return result
return get_passes
......@@ -216,15 +219,13 @@ def privacypass_passes(signing_key):
return partial(get_passes, signing_key=signing_key)
def pass_factory(get_passes=None):
def pass_factory(get_passes):
"""
Get a new factory for passes.
:param (unicode -> int -> [pass]) get_passes: A function the factory can
use to get new passes.
"""
if get_passes is None:
get_passes = integer_passes()
return _PassFactory(get_passes=get_passes)
......
......@@ -16,6 +16,10 @@
Tests for ``_zkapauthorizer._storage_client``.
"""
from __future__ import (
division,
)
from functools import (
partial,
)
......@@ -32,6 +36,7 @@ from testtools.matchers import (
HasLength,
MatchesAll,
AllMatch,
IsInstance,
)
from testtools.twistedsupport import (
succeeded,
......@@ -59,7 +64,9 @@ from .strategies import (
from ..api import (
MorePassesRequired,
)
from ..model import (
NotEnoughTokens,
)
from .._storage_client import (
call_with_passes,
)
......@@ -69,6 +76,7 @@ from .._storage_server import (
from .storage_common import (
pass_factory,
integer_passes,
)
......@@ -88,7 +96,7 @@ class CallWithPassesTests(TestCase):
call_with_passes(
lambda group: succeed(result),
num_passes,
partial(pass_factory().get, u"message"),
partial(pass_factory(integer_passes(num_passes)).get, u"message"),
),
succeeded(Is(result)),
)
......@@ -105,7 +113,7 @@ class CallWithPassesTests(TestCase):
call_with_passes(
lambda group: fail(result),
num_passes,
partial(pass_factory().get, u"message"),
partial(pass_factory(integer_passes(num_passes)).get, u"message"),
),
failed(
AfterPreprocessing(
......@@ -122,7 +130,7 @@ class CallWithPassesTests(TestCase):
provider containing ``num_passes`` created by the function passed for
``get_passes``.
"""
passes = pass_factory()
passes = pass_factory(integer_passes(num_passes))
self.assertThat(
call_with_passes(
......@@ -143,7 +151,7 @@ class CallWithPassesTests(TestCase):
``call_with_passes`` marks the passes it uses as spent if the operation
succeeds.
"""
passes = pass_factory()
passes = pass_factory(integer_passes(num_passes))
self.assertThat(
call_with_passes(
......@@ -163,7 +171,7 @@ class CallWithPassesTests(TestCase):
"""
``call_with_passes`` returns the passes it uses if the operation fails.
"""
passes = pass_factory()
passes = pass_factory(integer_passes(num_passes))
self.assertThat(
call_with_passes(
......@@ -188,7 +196,7 @@ class CallWithPassesTests(TestCase):
of passes, still of length ```num_passes``, but without the passes
which were rejected on the first try.
"""
passes = pass_factory()
passes = pass_factory(integer_passes(num_passes * 2))
def reject_even_pass_values(group):
passes = group.passes
......@@ -233,7 +241,7 @@ class CallWithPassesTests(TestCase):
no passes have been marked as invalid. This happens if all passes
given were valid but too fewer were given.
"""
passes = pass_factory()
passes = pass_factory(integer_passes(num_passes))
def reject_passes(group):
passes = group.passes
......@@ -271,3 +279,66 @@ class CallWithPassesTests(TestCase):
returned=HasLength(num_passes),
),
)
@given(pass_counts(), pass_counts())
def test_not_enough_tokens_for_retry(self, num_passes, extras):
"""
When there are not enough tokens to successfully complete a retry with the
required number of passes, ``call_with_passes`` marks all passes
reported as invalid during its efforts as such and resets all other
passes it acquired.
"""
passes = pass_factory(integer_passes(num_passes + extras))
rejected = []
accepted = []
def reject_half_passes(group):
num = len(group.passes)
# Floor division will always short-change valid here, even for a
# group size of 1. Therefore there will always be some passes
# marked as invalid.
accept_indexes = range(num // 2)
reject_indexes = range(num // 2, num)
# Only keep this iteration's accepted passes. We'll want to see
# that the final iteration's passes are all returned. Passes from
# earlier iterations don't matter.
accepted[:] = list(group.passes[i] for i in accept_indexes)
# On the other hand, keep *all* rejected passes. They should all
# be marked as invalid and we want to make sure that's the case,
# no matter which iteration rejected them.
rejected.extend(group.passes[i] for i in reject_indexes)
_ValidationResult(
valid=accept_indexes,
signature_check_failed=reject_indexes,
).raise_for(num)
self.assertThat(
call_with_passes(
# Since half of every group is rejected, we'll eventually run
# out of passes no matter how many we start with.
reject_half_passes,
num_passes,
partial(passes.get, u"message"),
),
failed(
AfterPreprocessing(
lambda f: f.value,
IsInstance(NotEnoughTokens),
),
),
)
self.assertThat(
passes,
MatchesStructure(
# Whatever is left in the group when we run out of tokens must
# be returned.
returned=Equals(accepted),
in_use=HasLength(0),
invalid=AfterPreprocessing(
lambda invalid: invalid.keys(),
Equals(rejected),
),
spent=HasLength(0),
issued=Equals(set(accepted + rejected)),
),
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment