From c13fe216a0778b0db27d1f31e66d8e8a3d79f012 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Wed, 31 May 2023 20:33:56 -0400
Subject: [PATCH] start of SDMF support for the downloader

* check as many servers as necessary/possible for encoding parameters, instead
  of just the first
* add cap class instances for SDMF capability types
* teach the decoder about SDMF
---
 src/Tahoe/Download.hs | 90 ++++++++++++++++++++++++++++++-------------
 1 file changed, 63 insertions(+), 27 deletions(-)

diff --git a/src/Tahoe/Download.hs b/src/Tahoe/Download.hs
index 31697c3..4370665 100644
--- a/src/Tahoe/Download.hs
+++ b/src/Tahoe/Download.hs
@@ -17,7 +17,7 @@ module Tahoe.Download (
 
 import Control.Exception (Exception (displayException), SomeException, try)
 import Control.Monad.IO.Class (MonadIO (liftIO))
-import Data.Bifunctor (Bifunctor (first, second))
+import Data.Bifunctor (Bifunctor (bimap, first, second))
 import Data.Binary (Word16, decodeOrFail)
 import qualified Data.ByteString as B
 import qualified Data.ByteString.Lazy as LB
@@ -35,6 +35,8 @@ import qualified Tahoe.CHK.Share
 import Tahoe.CHK.Types (ShareNum, StorageIndex)
 import Tahoe.Download.Internal.Client
 import Tahoe.Download.Internal.Immutable
+import qualified Tahoe.SDMF as SDMF
+import qualified Tahoe.SDMF.Keys as SDMF.Keys
 
 print' :: MonadIO m => String -> m ()
 -- print' = liftIO . print
@@ -68,28 +70,38 @@ download servers cap lookupServer = do
     -- TODO: If getRequiredTotal fails on the first storage server, we may
     -- need to try more.  If it fails for all of them, we need to represent
     -- the failure coherently.
-    ss <- firstStorageServer (Map.elems servers) lookupServer
-    (required, _) <- getRequiredTotal verifier ss
-    locationE <- locateShares servers lookupServer storageIndex (fromIntegral required)
-    print' "Finished locating shares"
-    case locationE of
-        Left err -> do
-            print' "Got an error locating shares"
-            pure $ Left err
-        Right discovered -> do
-            print' "Found some shares, fetching them"
-            -- XXX note shares can contain failures
-            shares <- executeDownloadTasks storageIndex (makeDownloadTasks =<< discovered)
-            print' "Fetched the shares, decoding them"
-            s <- decodeShares cap shares required
-            print' "Decoded them"
-            pure s
+    someParam <- firstSuccessful lookupServer (getRequiredTotal verifier) (Map.elems servers)
+    case someParam of
+        Nothing -> pure $ Left NoConfiguredServers -- XXX Maybe not quite the right error
+        Just (required, _) -> do
+            locationE <- locateShares servers lookupServer storageIndex (fromIntegral required)
+            print' "Finished locating shares"
+            case locationE of
+                Left err -> do
+                    print' "Got an error locating shares"
+                    pure $ Left err
+                Right discovered -> do
+                    print' "Found some shares, fetching them"
+                    -- XXX note shares can contain failures
+                    shares <- executeDownloadTasks storageIndex (makeDownloadTasks =<< discovered)
+                    print' "Fetched the shares, decoding them"
+                    s <- decodeShares cap shares required
+                    print' "Decoded them"
+                    pure s
 
--- We also need "first successful share"!
-firstStorageServer :: Monad m => [StorageServerAnnouncement] -> LookupServer m -> m StorageServer
-firstStorageServer servers finder = do
-    responses <- mapM finder servers
-    pure $ head $ take 1 $ rights responses -- XXX don't do this at home kids, head isn't safe
+firstSuccessful :: MonadIO m => (a -> m (Either b c)) -> (c -> m (Maybe d)) -> [a] -> m (Maybe d)
+firstSuccessful _ _ [] = pure Nothing
+firstSuccessful f op (x : xs) = do
+    s <- f x
+    case s of
+        Left _ -> recurse
+        Right ss -> do
+            r <- op ss
+            case r of
+                Nothing -> recurse
+                d -> pure d
+  where
+    recurse = firstSuccessful f op xs
 
 -- | A capability which confers the ability to locate and verify some stored data.
 class Verifiable v where
@@ -101,7 +113,7 @@ class Verifiable v where
 
     -- | Get the encoding parameters used for the shares of this capability.
     -- The information is presented as a tuple of (required, total).
-    getRequiredTotal :: MonadIO m => v -> StorageServer -> m (Int, Int)
+    getRequiredTotal :: MonadIO m => v -> StorageServer -> m (Maybe (Int, Int))
 
     -- | Get the location information for shares of this capability.
     getStorageIndex :: v -> StorageIndex
@@ -111,7 +123,16 @@ instance Verifiable CHK.Verifier where
     getStorageIndex Verifier{storageIndex} = storageIndex
 
     -- CHK is pure, we don't have to ask the StorageServer
-    getRequiredTotal Verifier{required, total} _ = pure (fromIntegral required, fromIntegral total)
+    getRequiredTotal Verifier{required, total} _ = pure $ pure (fromIntegral required, fromIntegral total)
+
+instance Verifiable SDMF.Verifier where
+    getShareNumbers v s = liftIO $ storageServerGetBuckets s (SDMF.Keys.unStorageIndex $ SDMF.verifierStorageIndex v)
+    getStorageIndex = SDMF.Keys.unStorageIndex . SDMF.verifierStorageIndex
+    getRequiredTotal SDMF.Verifier{..} ss = do
+        shareBytes <- liftIO $ storageServerRead ss (SDMF.Keys.unStorageIndex verifierStorageIndex) 0
+        case decodeOrFail (LB.fromStrict shareBytes) of
+            Left _ -> pure Nothing
+            Right (_, _, sh) -> pure $ pure (fromIntegral $ SDMF.shareRequiredShares sh, fromIntegral $ SDMF.shareTotalShares sh)
 
 {- | A capability which confers the ability to interpret some stored data to
  recover the original plaintext.  Additionally, it can be attentuated to a
@@ -133,17 +154,30 @@ class (Verifiable v) => Readable r v | r -> v where
 instance Readable CHK.Reader CHK.Verifier where
     getVerifiable = verifier
     decodeShare r shareList = do
-        cipherText <- liftIO $ Tahoe.CHK.decode r (second unWhich <$> shareList)
+        cipherText <- liftIO $ Tahoe.CHK.decode r (second unWhichCHK <$> shareList)
         case cipherText of
             Nothing -> pure $ Left ShareDecodingFailed
             Just ct ->
                 pure . Right $ Tahoe.CHK.Encrypt.decrypt (readKey r) ct
 
+instance Readable SDMF.Reader SDMF.Verifier where
+    getVerifiable = SDMF.readerVerifier
+    decodeShare r shareList = do
+        cipherText <- Right <$> liftIO (SDMF.decode r (bimap fromIntegral unWhichSDMF <$> shareList))
+        case cipherText of
+            Left _ -> pure $ Left ShareDecodingFailed
+            Right ct ->
+                pure . Right $ SDMF.decrypt dataKey ct
+              where
+                Just dataKey = SDMF.Keys.deriveDataKey iv readKey
+                iv = SDMF.shareIV (unWhichSDMF . snd . head $ shareList)
+                readKey = SDMF.readerReadKey r
+
 {- | Represent the kind of share to operate on.  This forms a closed world of
  share types.  It might eventually be interesting to make an open world
  variation instead.
 -}
-newtype WhichShare = CHK {unWhich :: Tahoe.CHK.Share.Share} -- \| SDMF SDMF.Share
+data WhichShare = CHK {unWhichCHK :: Tahoe.CHK.Share.Share} | SDMF {unWhichSDMF :: SDMF.Share}
 
 {- | Deserialize some bytes representing some kind of share to that kind of
  share, if possible.
@@ -151,8 +185,10 @@ newtype WhichShare = CHK {unWhich :: Tahoe.CHK.Share.Share} -- \| SDMF SDMF.Shar
 bytesToShare :: LB.ByteString -> Either DeserializeError WhichShare
 bytesToShare bytes = do
     case decodeOrFail bytes of
-        Left _ -> Left UnknownDeserializeError
         Right (_, _, r) -> Right $ CHK r
+        Left _ -> case decodeOrFail bytes of
+            Right (_, _, r) -> Right $ SDMF r
+            Left _ -> Left UnknownDeserializeError
 
 {- | Execute each download task sequentially and return only the successful
  results.
-- 
GitLab