From 3906c01c820490ae1f55f8c0a3089a09d20a946f Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Tue, 15 Sep 2020 11:32:45 -0400
Subject: [PATCH] Fix ShareStat serialization implementation

Also add a test that actually serializes it
---
 src/_zkapauthorizer/foolscap.py            |  15 +-
 src/_zkapauthorizer/tests/foolscap.py      |  70 ++++++++--
 src/_zkapauthorizer/tests/test_foolscap.py | 153 +++++++++++++++++++++
 3 files changed, 222 insertions(+), 16 deletions(-)

diff --git a/src/_zkapauthorizer/foolscap.py b/src/_zkapauthorizer/foolscap.py
index 3d69a53..29ed94e 100644
--- a/src/_zkapauthorizer/foolscap.py
+++ b/src/_zkapauthorizer/foolscap.py
@@ -31,6 +31,7 @@ from foolscap.api import (
     DictOf,
     ListOf,
     Copyable,
+    RemoteCopy,
 )
 from foolscap.remoteinterface import (
     RemoteMethodSchema,
@@ -44,7 +45,7 @@ from allmydata.interfaces import (
 )
 
 @attr.s
-class ShareStat(Copyable):
+class ShareStat(Copyable, RemoteCopy):
     """
     Represent some metadata about a share.
 
@@ -53,8 +54,16 @@ class ShareStat(Copyable):
     :ivar int lease_expiration: The POSIX timestamp of the time at which the
         lease on this share expires, or None if there is no lease.
     """
-    size = attr.ib()
-    lease_expiration = attr.ib()
+    typeToCopy = copytype = "ShareStat"
+
+    # To be a RemoteCopy it must be possible to instantiate this with no
+    # arguments. :/ So supply defaults for these attributes.
+    size = attr.ib(default=0)
+    lease_expiration = attr.ib(default=0)
+
+    # The RemoteCopy interface
+    def setCopyableState(self, state):
+        self.__dict__ = state
 
 
 # The Foolscap convention seems to be to try to constrain inputs to valid
diff --git a/src/_zkapauthorizer/tests/foolscap.py b/src/_zkapauthorizer/tests/foolscap.py
index c9bb485..3571699 100644
--- a/src/_zkapauthorizer/tests/foolscap.py
+++ b/src/_zkapauthorizer/tests/foolscap.py
@@ -27,11 +27,19 @@ from zope.interface import (
 import attr
 
 from twisted.internet.defer import (
-    execute,
+    succeed,
+    fail,
 )
 
 from foolscap.api import (
     RemoteInterface,
+    Referenceable,
+    Copyable,
+    Any,
+)
+from foolscap.copyable import (
+    ICopyable,
+    CopyableSlicer,
 )
 
 from allmydata.interfaces import (
@@ -41,6 +49,11 @@ from allmydata.interfaces import (
 class RIStub(RemoteInterface):
     pass
 
+
+class RIEcho(RemoteInterface):
+    def echo(argument=Any()):
+        return Any()
+
 @implementer(RIStorageServer)
 class StubStorageServer(object):
     pass
@@ -50,6 +63,18 @@ def get_anonymous_storage_server():
     return StubStorageServer()
 
 
+class BrokenCopyable(Copyable):
+    """
+    I don't have a ``typeToCopy`` so I can't be serialized.
+    """
+
+
+@implementer(RIEcho)
+class Echoer(Referenceable):
+    def remote_echo(self, argument):
+        return argument
+
+
 @attr.s
 class DummyReferenceable(object):
     _interface = attr.ib()
@@ -103,18 +128,37 @@ class LocalRemote(object):
         """
         Call the given method on the wrapped object, passing the given arguments.
 
-        Arguments are checked for conformance to the remote interface but the
-        return value is not (because I don't know how -exarkun).
+        Arguments and return are checked for conformance to the remote
+        interface but they are not actually serialized.
 
         :return Deferred: The result of the call on the wrapped object.
         """
-        schema = self._referenceable.getInterface()[methname]
-        if self.check_args:
-            schema.checkAllArgs(args, kwargs, inbound=False)
-        # TODO: Figure out how to call checkResults on the result.
-        return execute(
-            self._referenceable.doRemoteCall,
-            methname,
-            args,
-            kwargs,
-        )
+        try:
+            schema = self._referenceable.getInterface()[methname]
+            if self.check_args:
+                schema.checkAllArgs(args, kwargs, inbound=True)
+            _check_copyables(list(args) + kwargs.values())
+            result = self._referenceable.doRemoteCall(
+                methname,
+                args,
+                kwargs,
+            )
+            schema.checkResults(result, inbound=False)
+            _check_copyables([result])
+            return succeed(result)
+        except:
+            return fail()
+
+
+def _check_copyables(copyables):
+    """
+    Check each object to see if it is a copyable and if it is make sure it can
+    be sliced.
+    """
+    for obj in copyables:
+        if ICopyable.providedBy(obj):
+            list(CopyableSlicer(obj).slice(False, None))
+        elif isinstance(obj, dict):
+            _check_copyables(obj.values())
+        elif isinstance(obj, list):
+            _check_copyables(obj)
diff --git a/src/_zkapauthorizer/tests/test_foolscap.py b/src/_zkapauthorizer/tests/test_foolscap.py
index 388cc11..5912b35 100644
--- a/src/_zkapauthorizer/tests/test_foolscap.py
+++ b/src/_zkapauthorizer/tests/test_foolscap.py
@@ -20,16 +20,36 @@ from __future__ import (
     absolute_import,
 )
 
+from fixtures import (
+    Fixture,
+)
 from testtools import (
     TestCase,
 )
 from testtools.matchers import (
+    Equals,
     MatchesAll,
     AfterPreprocessing,
     Always,
     IsInstance,
 )
+from testtools.twistedsupport import (
+    succeeded,
+    failed,
+)
+
+from twisted.trial.unittest import (
+    TestCase as TrialTestCase,
+)
+from twisted.internet.defer import (
+    inlineCallbacks,
+)
 
+from foolscap.api import (
+    Violation,
+    RemoteInterface,
+    Any,
+)
 from foolscap.furl import (
     decode_furl,
 )
@@ -51,10 +71,27 @@ from hypothesis.strategies import (
 
 from .foolscap import (
     RIStub,
+    Echoer,
     LocalRemote,
+    BrokenCopyable,
     DummyReferenceable,
 )
 
+from ..foolscap import (
+    ShareStat,
+)
+
+class IHasSchema(RemoteInterface):
+    def method(arg=int):
+        return bytes
+
+    def good_method(arg=int):
+        return None
+
+    def whatever_method(arg=Any()):
+        return Any()
+
+
 def remote_reference():
     tub = Tub()
     tub.setLocation("127.0.0.1:12345")
@@ -95,3 +132,119 @@ class LocalRemoteTests(TestCase):
                 ),
             ),
         )
+
+    def test_arg_schema(self):
+        """
+        ``LocalRemote.callRemote`` returns a ``Deferred`` that fails with a
+        ``Violation`` if an parameter receives an argument which doesn't
+        conform to its schema.
+        """
+        ref = LocalRemote(DummyReferenceable(IHasSchema))
+        self.assertThat(
+            ref.callRemote("method", None),
+            failed(
+                AfterPreprocessing(
+                    lambda f: f.type,
+                    Equals(Violation),
+                ),
+            ),
+        )
+
+    def test_result_schema(self):
+        """
+        ``LocalRemote.callRemote`` returns a ``Deferred`` that fails with a
+        ``Violation`` if a method returns an object which doesn't conform to
+        the method's result schema.
+        """
+        ref = LocalRemote(DummyReferenceable(IHasSchema))
+        self.assertThat(
+            ref.callRemote("method", 0),
+            failed(
+                AfterPreprocessing(
+                    lambda f: f.type,
+                    Equals(Violation),
+                ),
+            ),
+        )
+
+    def test_successful_method(self):
+        """
+        ``LocalRemote.callRemote`` returns a ``Deferred`` that fires with the
+        remote method's result if the arguments and result conform to their
+        respective schemas.
+        """
+        ref = LocalRemote(DummyReferenceable(IHasSchema))
+        self.assertThat(
+            ref.callRemote("good_method", 0),
+            succeeded(Equals(None)),
+        )
+
+    def test_argument_serialization_failure(self):
+        """
+        ``LocalRemote.callRemote`` returns a ``Deferred`` that fires with a
+        failure if an argument cannot be serialized.
+        """
+        ref = LocalRemote(DummyReferenceable(IHasSchema))
+        self.assertThat(
+            ref.callRemote("whatever_method", BrokenCopyable()),
+            failed(Always()),
+        )
+
+    def test_result_serialization_failure(self):
+        """
+        ``LocalRemote.callRemote`` returns a ``Deferred`` that fires with a
+        failure if the method's result cannot be serialized.
+        """
+        class BrokenResultReferenceable(DummyReferenceable):
+            def doRemoteCall(self, *a, **kw):
+                return BrokenCopyable()
+
+        ref = LocalRemote(BrokenResultReferenceable(IHasSchema))
+        self.assertThat(
+            ref.callRemote("whatever_method", None),
+            failed(Always()),
+        )
+
+
+class EchoerFixture(Fixture):
+    def __init__(self, reactor, tub_path):
+        self.reactor = reactor
+        self.tub = Tub()
+        self.tub.setLocation(b"tcp:0")
+
+    def _setUp(self):
+        self.tub.startService()
+        self.furl = self.tub.registerReference(Echoer())
+
+    def _cleanUp(self):
+        return self.tub.stopService()
+
+
+class SerializationTests(TrialTestCase):
+    """
+    Tests for the serialization of types used in the Foolscap API.
+    """
+    def test_sharestat(self):
+        """
+        A ``ShareStat`` instance can be sent as an argument to and received in a
+        response from a Foolscap remote method call.
+        """
+        return self._roundtrip_test(ShareStat(1, 2))
+
+    @inlineCallbacks
+    def _roundtrip_test(self, obj):
+        """
+        Send ``obj`` over Foolscap and receive it back again, equal to itself.
+        """
+        # Foolscap Tub implementation just uses the global reactor...
+        from twisted.internet import reactor
+
+        # So sad.  No Deferred support in testtools.TestCase or
+        # fixture.Fixture, no fixture support in
+        # twisted.trial.unittest.TestCase.
+        fx = EchoerFixture(reactor, self.mktemp())
+        fx.setUp()
+        self.addCleanup(fx._cleanUp)
+        echoer = yield fx.tub.getReference(fx.furl)
+        received = yield echoer.callRemote("echo", obj)
+        self.assertEqual(obj, received)
-- 
GitLab