From b9fa17f01833ec8ba60d957e8cfe4cd85635e9c7 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Tue, 10 Oct 2023 14:35:14 -0400
Subject: [PATCH] Move download and downloadDirectory into ExceptT

---
 CHANGELOG.md          |  2 ++
 gbs-downloader.cabal  |  2 ++
 src/Tahoe/Download.hs | 59 ++++++++++++++++---------------------------
 test/Spec.hs          | 19 +++++++-------
 4 files changed, 36 insertions(+), 46 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 70ee44b..89f37bd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@
 * The download APIs now only send requests to a storage server after that
   storage server is authenticated using information from the NURL.
 
+* ``Tahoe.Download.download`` and ``Tahoe.Download.downloadDirectory`` now return ``ExceptT``.
+
 ## 0.1.0.0 -- 2023-08-17
 
 * First version. Released on an unsuspecting world.
diff --git a/gbs-downloader.cabal b/gbs-downloader.cabal
index ee59009..2c306e2 100644
--- a/gbs-downloader.cabal
+++ b/gbs-downloader.cabal
@@ -118,6 +118,7 @@ library
     , http-client-tls          >=0.3.5.3  && <0.4
     , http-types               >=0.12.3   && <0.13
     , lens                     >=4.0      && <5.3
+    , mtl
     , network-uri              >=2.6.3    && <2.7
     , servant-client           >=0.16.0.1 && <0.21
     , servant-client-core      >=0.16     && <0.21
@@ -249,6 +250,7 @@ test-suite gbs-downloader-test
     , http-types           >=0.12.3   && <0.13
     , lens                 >=4.0      && <5.3
     , memory               >=0.15     && <0.17
+    , mtl
     , servant-client       >=0.16.0.1 && <0.21
     , servant-client-core  >=0.16     && <0.21
     , tahoe-chk            >=0.2      && <0.3
diff --git a/src/Tahoe/Download.hs b/src/Tahoe/Download.hs
index 4f0f41f..b7febdf 100644
--- a/src/Tahoe/Download.hs
+++ b/src/Tahoe/Download.hs
@@ -19,6 +19,7 @@ module Tahoe.Download (
 
 import Control.Concurrent.Async (mapConcurrently)
 import Control.Exception (Exception (displayException), SomeException, try)
+import Control.Monad.Except (ExceptT (..), lift, throwError, withExceptT)
 import Control.Monad.IO.Class (MonadIO (liftIO))
 import Data.Bifunctor (Bifunctor (first, second))
 import Data.Binary (Word16)
@@ -55,7 +56,7 @@ download ::
     -- the read capability has a Verifiable instance because Verifiable is
     -- what gives us the ability to locate the shares.  If we located
     -- separately from decrypting this might be simpler.
-    (MonadIO m, Readable readCap, Verifiable v, Verifier readCap ~ v) =>
+    (Readable readCap, Verifiable v, Verifier readCap ~ v) =>
     -- | Information about the servers from which to consider downloading shares
     -- representing the application data.
     Map.Map StorageServerID StorageServerAnnouncement ->
@@ -65,49 +66,35 @@ download ::
     LookupServer IO ->
     -- | Either a description of how the recovery failed or the recovered
     -- application data.
-    m (Either DownloadError LB.ByteString)
+    ExceptT DownloadError IO LB.ByteString
 download servers cap lookupServer = do
     print' ("Downloading: " <> show (getStorageIndex $ getVerifiable cap))
     let verifier = getVerifiable cap
     let storageIndex = getStorageIndex verifier
-    -- TODO: If getRequiredTotal fails on the first storage server, we may
-    -- need to try more.  If it fails for all of them, we need to represent
-    -- the failure coherently.
-    someParam <- liftIO $ firstRightM lookupServer (getRequiredTotal verifier) (Map.elems servers)
-    case someParam of
-        Left errs -> pure . Left $ if servers == mempty then NoConfiguredServers else NoReachableServers (StorageServerUnreachable <$> errs)
-        Right (required, _) -> do
-            locationE <- liftIO $ locateShares servers lookupServer storageIndex (fromIntegral required)
-            print' "Finished locating shares"
-            case locationE of
-                Left err -> do
-                    print' "Got an error locating shares"
-                    pure $ Left err
-                Right discovered -> do
-                    print' "Found some shares, fetching them"
-                    -- XXX note shares can contain failures
-                    shares <- liftIO $ executeDownloadTasks storageIndex (makeDownloadTasks =<< discovered)
-                    print' "Fetched the shares, decoding them"
-                    s <- liftIO $ decodeShares cap shares required
-                    print' "Decoded them"
-                    pure s
+    (required, _) <- withExceptT noReachableServers (firstRightM lookupServer (getRequiredTotal verifier) (Map.elems servers))
+    print' "Discovered required number of shares"
+    discovered <- ExceptT $ locateShares servers lookupServer storageIndex (fromIntegral required)
+    print' "Finished locating shares, fetching"
+    shares <- liftIO $ executeDownloadTasks storageIndex (makeDownloadTasks =<< discovered)
+    print' "Fetched the shares, decoding them"
+    ExceptT $ liftIO $ decodeShares cap shares required
+  where
+    noReachableServers = NoReachableServers . (StorageServerUnreachable <$>)
 
 {- | Apply a monadic operation to each element of a list and another monadic
  operation values in the resulting Rights.  If all of the results are Lefts or
  Nothings, return a list of the values in the Lefts.  Otherwise, return the
  *first* Right.
 -}
-firstRightM :: Monad m => (a -> m (Either b c)) -> (c -> m (Maybe d)) -> [a] -> m (Either [b] d)
-firstRightM _ _ [] = pure $ Left []
+firstRightM :: Monad m => (a -> m (Either b c)) -> (c -> m (Maybe d)) -> [a] -> ExceptT [b] m d
+firstRightM _ _ [] = throwError []
 firstRightM f op (x : xs) = do
-    s <- f x
+    s <- lift $ f x
     case s of
-        Left bs -> first (bs :) <$> recurse
+        Left bs -> (bs :) `withExceptT` recurse
         Right ss -> do
-            r <- op ss
-            case r of
-                Nothing -> recurse
-                Just d -> pure $ Right d
+            r <- lift $ op ss
+            maybe recurse pure r
   where
     recurse = firstRightM f op xs
 
@@ -232,7 +219,7 @@ downloadShare storageIndex (shareNum, s) = do
  as a collection of entries.
 -}
 downloadDirectory ::
-    (MonadIO m, Readable readCap, Verifiable v, Verifier readCap ~ v) =>
+    (Readable readCap, Verifiable v, Verifier readCap ~ v) =>
     -- | Information about the servers from which to consider downloading shares
     -- representing the application data.
     Map.Map StorageServerID StorageServerAnnouncement ->
@@ -242,12 +229,10 @@ downloadDirectory ::
     LookupServer IO ->
     -- | Either a description of how the recovery failed or the recovered
     -- application data.
-    m (Either DirectoryDownloadError Directory)
+    ExceptT DirectoryDownloadError IO Directory
 downloadDirectory anns (DirectoryCapability cap) lookupServer = do
-    bs <- download anns cap lookupServer
-    pure $ do
-        bs' <- first UnderlyingDownloadError bs
-        first (const DecodingError) $ Directory.parse (LB.toStrict bs')
+    bs <- UnderlyingDownloadError `withExceptT` download anns cap lookupServer
+    ExceptT . pure . first (const DecodingError) . Directory.parse . LB.toStrict $ bs
 
 data DirectoryDownloadError
     = UnderlyingDownloadError DownloadError
diff --git a/test/Spec.hs b/test/Spec.hs
index c5ba773..269248f 100644
--- a/test/Spec.hs
+++ b/test/Spec.hs
@@ -5,6 +5,7 @@ module Main where
 import Control.Exception (Exception, throwIO)
 import Control.Lens (view)
 import Control.Monad (replicateM, when)
+import Control.Monad.Except (MonadTrans (lift), runExceptT)
 import Control.Monad.IO.Class (liftIO)
 import Crypto.Cipher.AES (AES128)
 import Crypto.Cipher.Types (Cipher (cipherInit, cipherKeySize), KeySizeSpecifier (KeySizeEnum, KeySizeFixed, KeySizeRange), nullIV)
@@ -144,10 +145,10 @@ tests =
             $ do
                 -- If there are no servers then we can't possibly get enough
                 -- shares to recover the application data.
-                result <- liftIO $ download mempty (trivialCap 1 1) noServers
+                result <- runExceptT $ download mempty (trivialCap 1 1) noServers
                 assertEqual
                     "download should fail with no servers"
-                    (Left NoConfiguredServers)
+                    (Left (NoReachableServers []))
                     result
         , testCase "no reachable servers" $ do
             -- If we can't contact any configured server then we can't
@@ -158,7 +159,7 @@ tests =
                         [ ("v0-abc123", ann)
                         ]
 
-            result <- liftIO $ download anns (trivialCap 1 1) noServers
+            result <- runExceptT $ download anns (trivialCap 1 1) noServers
             assertEqual
                 "download should fail with no reachable servers"
                 (Left $ NoReachableServers [StorageServerUnreachable (URIParseError ann)])
@@ -179,7 +180,7 @@ tests =
             let openServer = simpleLookup [("somewhere", server)]
 
             -- Try to download the cap which requires three shares to reconstruct.
-            result <- liftIO $ download anns cap openServer
+            result <- runExceptT $ download anns cap openServer
             assertEqual
                 "download should fail with not enough shares"
                 (Left NotEnoughShares{notEnoughSharesNeeded = 3, notEnoughSharesFound = 2})
@@ -210,7 +211,7 @@ tests =
             let openServer = simpleLookup [("somewhere", somewhere), ("elsewhere", elsewhere)]
 
             -- Try to download the cap which requires three shares to reconstruct.
-            result <- liftIO $ download anns cap openServer
+            result <- runExceptT $ download anns cap openServer
             assertEqual
                 "download should fail with not enough shares"
                 (Left NotEnoughShares{notEnoughSharesNeeded = 3, notEnoughSharesFound = 2})
@@ -228,7 +229,7 @@ tests =
             let cap = trivialCap 3 13
 
             -- Try to download the cap which requires three shares to reconstruct.
-            result <- liftIO $ download anns cap openServer
+            result <- runExceptT $ download anns cap openServer
             assertEqual
                 "download should fail with details about unreachable server"
                 (Left (NoReachableServers [StorageServerCommunicationError "BespokeFailure"]))
@@ -291,7 +292,7 @@ tests =
 
             -- Try to download the cap which requires three shares to reconstruct.
 
-            result <- liftIO $ download anns cap openServer
+            result <- runExceptT $ download anns cap openServer
             assertEqual
                 "download should fail with details about unreachable server"
                 (Left (NotEnoughDecodedShares{notEnoughDecodedSharesNeeded = 3, notEnoughDecodedSharesFound = 0}))
@@ -341,7 +342,7 @@ tests =
                     serverAnnouncements = Map.fromSet makeAnn serverIDs'
 
                 -- Recover the plaintext from the servers.
-                result <- liftIO $ download serverAnnouncements cap lookupServer
+                result <- lift $ runExceptT $ download serverAnnouncements cap lookupServer
                 diff (Right plaintext) (==) result
         , testProperty "ssk success" $
             property $ do
@@ -385,7 +386,7 @@ tests =
                     serverAnnouncements = Map.fromSet makeAnn serverIDs'
 
                 -- Recover the plaintext from the servers.
-                result <- liftIO $ download serverAnnouncements readCap lookupServer
+                result <- lift $ runExceptT $ download serverAnnouncements readCap lookupServer
                 diff (Right plaintext) (==) result
         , testCase "immutable upload/download to using Great Black Swamp" $ do
             pure ()
-- 
GitLab