diff --git a/src/_zkapauthorizer/storage_common.py b/src/_zkapauthorizer/storage_common.py index d5a24ff7ec0deb9241f8ce4b99dc8ed185fcff46..78b8dc5fa8c2b74cb84111e4bd836fd7010ebcff 100644 --- a/src/_zkapauthorizer/storage_common.py +++ b/src/_zkapauthorizer/storage_common.py @@ -55,7 +55,7 @@ def required_passes(bytes_per_pass, share_sizes): :param int bytes_per_pass: The number of bytes the storage of which for one lease period one pass covers. - :param set[int] share_sizes: The sizes of the shared which will be stored. + :param list[int] share_sizes: The sizes of the shared which will be stored. :return int: The number of passes required to cover the storage cost. """ @@ -64,6 +64,12 @@ def required_passes(bytes_per_pass, share_sizes): sum(share_sizes, 0) / bytes_per_pass, ), ) + if not isinstance(share_sizes, list): + raise TypeError( + "Share sizes must be a list of integers, got {!r} instead".format( + share_sizes, + ), + ) # print("required_passes({}, {}) == {}".format(bytes_per_pass, share_sizes, result)) return result diff --git a/src/_zkapauthorizer/tests/test_storage_protocol.py b/src/_zkapauthorizer/tests/test_storage_protocol.py index 53a272a1219950b775ded88c7b78478b0b49fafd..b0c83d7fc4041ea82bbcff2c9d3f26de457a917e 100644 --- a/src/_zkapauthorizer/tests/test_storage_protocol.py +++ b/src/_zkapauthorizer/tests/test_storage_protocol.py @@ -33,6 +33,7 @@ from testtools.matchers import ( HasLength, IsInstance, AfterPreprocessing, + raises, ) from testtools.twistedsupport import ( succeeded, @@ -49,7 +50,9 @@ from hypothesis import ( assume, ) from hypothesis.strategies import ( + sets, tuples, + integers, ) from twisted.python.filepath import ( @@ -101,6 +104,7 @@ from ..api import ( from ..storage_common import ( slot_testv_and_readv_and_writev_message, get_implied_data_length, + required_passes, ) from ..model import ( Pass, @@ -148,6 +152,21 @@ class LocalRemote(object): ) +class RequiredPassesTests(TestCase): + """ + Tests for ``required_passes``. + """ + @given(integers(min_value=1), sets(integers(min_value=0))) + def test_incorrect_types(self, bytes_per_pass, share_sizes): + """ + ``required_passes`` raises ``TypeError`` if passed a ``set`` for + ``share_sizes``. + """ + self.assertThat( + lambda: required_passes(bytes_per_pass, share_sizes), + raises(TypeError), + ) + class ShareTests(TestCase): """ Tests for interaction with shares.