diff --git a/src/_zkapauthorizer/storage_common.py b/src/_zkapauthorizer/storage_common.py index d5a24ff7ec0deb9241f8ce4b99dc8ed185fcff46..9bf9435e69e5429cf7bdf596d7e1b18fe0472da1 100644 --- a/src/_zkapauthorizer/storage_common.py +++ b/src/_zkapauthorizer/storage_common.py @@ -24,10 +24,6 @@ from base64 import ( b64encode, ) -from math import ( - ceil, -) - def _message_maker(label): def make_message(storage_index): return u"{label} {storage_index}".format( @@ -55,15 +51,20 @@ def required_passes(bytes_per_pass, share_sizes): :param int bytes_per_pass: The number of bytes the storage of which for one lease period one pass covers. - :param set[int] share_sizes: The sizes of the shared which will be stored. + :param list[int] share_sizes: The sizes of the shared which will be stored. :return int: The number of passes required to cover the storage cost. """ - result = int( - ceil( - sum(share_sizes, 0) / bytes_per_pass, - ), - ) + if not isinstance(share_sizes, list): + raise TypeError( + "Share sizes must be a list of integers, got {!r} instead".format( + share_sizes, + ), + ) + result, b = divmod(sum(share_sizes, 0), bytes_per_pass) + if b: + result += 1 + # print("required_passes({}, {}) == {}".format(bytes_per_pass, share_sizes, result)) return result diff --git a/src/_zkapauthorizer/tests/test_storage_protocol.py b/src/_zkapauthorizer/tests/test_storage_protocol.py index 53a272a1219950b775ded88c7b78478b0b49fafd..d595e7c104c3ac0edf79d374131ed36a505f8711 100644 --- a/src/_zkapauthorizer/tests/test_storage_protocol.py +++ b/src/_zkapauthorizer/tests/test_storage_protocol.py @@ -33,6 +33,7 @@ from testtools.matchers import ( HasLength, IsInstance, AfterPreprocessing, + raises, ) from testtools.twistedsupport import ( succeeded, @@ -49,7 +50,10 @@ from hypothesis import ( assume, ) from hypothesis.strategies import ( + sets, + lists, tuples, + integers, ) from twisted.python.filepath import ( @@ -101,6 +105,7 @@ from ..api import ( from ..storage_common import ( slot_testv_and_readv_and_writev_message, get_implied_data_length, + required_passes, ) from ..model import ( Pass, @@ -148,6 +153,45 @@ class LocalRemote(object): ) +class RequiredPassesTests(TestCase): + """ + Tests for ``required_passes``. + """ + @given(integers(min_value=1), sets(integers(min_value=0))) + def test_incorrect_types(self, bytes_per_pass, share_sizes): + """ + ``required_passes`` raises ``TypeError`` if passed a ``set`` for + ``share_sizes``. + """ + self.assertThat( + lambda: required_passes(bytes_per_pass, share_sizes), + raises(TypeError), + ) + + @given( + bytes_per_pass=integers(min_value=1), + expected_per_share=lists(integers(min_value=1), min_size=1), + ) + def test_minimum_result(self, bytes_per_pass, expected_per_share): + """ + ``required_passes`` returns an integer giving the fewest passes required + to pay for the storage represented by the given share sizes. + """ + actual = required_passes( + bytes_per_pass, + list( + passes * bytes_per_pass + for passes + in expected_per_share + ), + ) + self.assertThat( + actual, + Equals(sum(expected_per_share)), + ) + + + class ShareTests(TestCase): """ Tests for interaction with shares.