From 9818bf0b73523dcdd6ab38360e28fe20e9d3fd09 Mon Sep 17 00:00:00 2001
From: Tom Prince <tom.prince@private.storage>
Date: Thu, 29 Jul 2021 15:08:56 -0600
Subject: [PATCH] use deep_traverse

---
 src/_zkapauthorizer/lease_maintenance.py | 28 +++++++++++--------
 src/_zkapauthorizer/pricecalculator.py   | 35 ++++++++++++++++++++++++
 src/_zkapauthorizer/resource.py          | 24 ++++++++++++++++
 3 files changed, 75 insertions(+), 12 deletions(-)

diff --git a/src/_zkapauthorizer/lease_maintenance.py b/src/_zkapauthorizer/lease_maintenance.py
index 8a3a095..26101c4 100644
--- a/src/_zkapauthorizer/lease_maintenance.py
+++ b/src/_zkapauthorizer/lease_maintenance.py
@@ -49,7 +49,6 @@ from twisted.python.log import (
 )
 
 from allmydata.interfaces import (
-    IDirectoryNode,
     IFilesystemNode,
 )
 from allmydata.util.hashutil import (
@@ -92,17 +91,22 @@ def visit_storage_indexes(root_nodes, visit):
                 node,
             ))
 
-    stack = root_nodes[:]
-    while stack:
-        elem = stack.pop()
-        visit(elem.get_storage_index())
-        if IDirectoryNode.providedBy(elem):
-            children = yield elem.list()
-            # Produce consistent results by forcing some consistent ordering
-            # here.  This will sort by name.
-            stable_children = sorted(children.items())
-            for (name, (child_node, child_metadata)) in stable_children:
-                stack.append(child_node)
+    class Renewer(object):
+        def set_monitor(self, monitor):
+            self.monitor = monitor
+
+        def add_node(self, node, childpath):
+            visit(node.get_storage_index())
+
+        def enter_directory(self, parent, children):
+            pass
+
+        def finish(self):
+            pass
+
+    for root_node in root_nodes:
+        monitor = root_node.deep_traverse(Renewer)
+        yield monitor.when_done()
 
 
 def iter_storage_indexes(visit_assets):
diff --git a/src/_zkapauthorizer/pricecalculator.py b/src/_zkapauthorizer/pricecalculator.py
index 007ec9c..99213ad 100644
--- a/src/_zkapauthorizer/pricecalculator.py
+++ b/src/_zkapauthorizer/pricecalculator.py
@@ -28,12 +28,14 @@ calculator).
 """
 
 import attr
+from twisted.internet.defer import inlineCallbacks, returnValue
 
 from .storage_common import (
     required_passes,
     share_size_for_data,
 )
 
+
 @attr.s
 class PriceCalculator(object):
     """
@@ -71,3 +73,36 @@ class PriceCalculator(object):
         )
         price = sum(all_required_passes, 0)
         return price
+
+    def _calculate_for_size(self, size):
+        share_size = share_size_for_data(self._shares_needed, size)
+        passes = required_passes(self._pass_value, [share_size] * self._shares_total)
+        return passes
+
+    @inlineCallbacks
+    def calculate_from_node(self, root_node):
+        @attr.s
+        class Sizer(object):
+            total = attr.ib(init=False, default=0)
+
+            def set_monitor(iself, monitor):
+                iself.monitor = monitor
+
+            @inlineCallbacks
+            def add_node(iself, node, childpath):
+                if node.get_storate_index() is None:
+                    return
+                size = yield node.get_current_size()
+                price = self._calculate_for_size(size)
+                iself.total += price
+
+            def enter_directory(iself, parent, children):
+                pass
+
+            def finish(iself):
+                pass
+
+        sizer = Sizer()
+        monitor = root_node.deep_traverse(Sizer)
+        yield monitor.when_done()
+        returnValue(sizer.total)
diff --git a/src/_zkapauthorizer/resource.py b/src/_zkapauthorizer/resource.py
index e5e31ea..3d3f541 100644
--- a/src/_zkapauthorizer/resource.py
+++ b/src/_zkapauthorizer/resource.py
@@ -40,6 +40,7 @@ from twisted.logger import (
 )
 from twisted.web.http import (
     BAD_REQUEST,
+    INTERNAL_SERVER_ERROR,
 )
 from twisted.web.server import (
     NOT_DONE_YET,
@@ -262,6 +263,29 @@ class _CalculatePrice(Resource):
                 "error": "could not parse request body",
             })
 
+        try:
+            version = body_object[u"version"]
+        except KeyError:
+            pass
+        else:
+            if version == 2:
+                root_node = body_object[u"root_node"]
+                application_json(request)
+
+                price = self._price_calculator.calculate_from_node(root_node).addCallback(
+                    lambda price: dumps({
+                        u"price": price,
+                        u"period": self._lease_period,
+                    })
+                ).addCallback(
+                    request.write
+                ).addErrback(
+                    lambda ignored: request.setResponseCode(INTERNAL_SERVER_ERROR)
+                ).addCallback(
+                    lambda ignored: request.finish()
+                )
+                return NOT_DONE_YET
+
         try:
             version = body_object[u"version"]
             sizes = body_object[u"sizes"]
-- 
GitLab