From 749871d18aabfac8b9698b36c6ff8d5da0dac783 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Fri, 2 Jun 2023 14:49:19 -0400
Subject: [PATCH] Turn WhichShare into an associated type family

This makes share decoding open instead of closed!
---
 src/Tahoe/Download.hs                     | 17 ++--------
 src/Tahoe/Download/Internal/Capability.hs | 40 +++++++++++++++--------
 2 files changed, 29 insertions(+), 28 deletions(-)

diff --git a/src/Tahoe/Download.hs b/src/Tahoe/Download.hs
index 93488ce..84720d8 100644
--- a/src/Tahoe/Download.hs
+++ b/src/Tahoe/Download.hs
@@ -1,5 +1,3 @@
-{-# LANGUAGE FlexibleContexts #-}
-
 {- | A high-level interface to downloading share data as bytes from storage
  servers.
 -}
@@ -18,7 +16,7 @@ module Tahoe.Download (
 import Control.Exception (Exception (displayException), SomeException, try)
 import Control.Monad.IO.Class (MonadIO (liftIO))
 import Data.Bifunctor (Bifunctor (first, second))
-import Data.Binary (Word16, decodeOrFail)
+import Data.Binary (Word16)
 import qualified Data.ByteString as B
 import qualified Data.ByteString.Lazy as LB
 import Data.Either (partitionEithers, rights)
@@ -99,17 +97,6 @@ firstRightM f op (x : xs) = do
   where
     recurse = firstRightM f op xs
 
-{- | Deserialize some bytes representing some kind of share to that kind of
- share, if possible.
--}
-bytesToShare :: LB.ByteString -> Either DeserializeError WhichShare
-bytesToShare bytes = do
-    case decodeOrFail bytes of
-        Right (_, _, r) -> Right $ CHK r
-        Left _ -> case decodeOrFail bytes of
-            Right (_, _, r) -> Right $ SDMF r
-            Left _ -> Left UnknownDeserializeError
-
 {- | Execute each download task sequentially and return only the successful
  results.
 -}
@@ -172,7 +159,7 @@ decodeShares ::
     m (Either DownloadError LB.ByteString)
 decodeShares r shares required = do
     -- Filter down to shares we actually got.
-    let fewerShares :: [(ShareNum, Either DeserializeError WhichShare)] = second bytesToShare <$> shares
+    let fewerShares = second (deserializeShare (getVerifiable r)) <$> shares
         onlyDecoded = rights $ (\(a, b) -> (fromIntegral a,) <$> b) <$> fewerShares
     if length onlyDecoded < required
         then pure $ Left NotEnoughDecodedShares{notEnoughDecodedSharesNeeded = fromIntegral required, notEnoughDecodedSharesFound = length onlyDecoded}
diff --git a/src/Tahoe/Download/Internal/Capability.hs b/src/Tahoe/Download/Internal/Capability.hs
index 4351387..84c575f 100644
--- a/src/Tahoe/Download/Internal/Capability.hs
+++ b/src/Tahoe/Download/Internal/Capability.hs
@@ -1,10 +1,12 @@
 {-# LANGUAGE FunctionalDependencies #-}
+{-# LANGUAGE TypeFamilies #-}
 
 module Tahoe.Download.Internal.Capability where
 
 import Control.Monad.IO.Class
 import Data.Bifunctor (Bifunctor (..))
 import Data.Binary (decodeOrFail)
+import Data.Binary.Get (ByteOffset)
 import qualified Data.ByteString.Lazy as LB
 import qualified Data.Set as Set
 import qualified Tahoe.CHK
@@ -17,16 +19,11 @@ import Tahoe.Download.Internal.Client
 import qualified Tahoe.SDMF as SDMF
 import qualified Tahoe.SDMF.Internal.Keys as SDMF.Keys
 
-{- | Represent the kind of share to operate on.  This forms a closed world of
- share types.  It might eventually be interesting to make an open world
- variation instead.
--}
-data WhichShare
-    = CHK {unWhichCHK :: Tahoe.CHK.Share.Share}
-    | SDMF {unWhichSDMF :: SDMF.Share}
-
 -- | A capability which confers the ability to locate and verify some stored data.
 class Verifiable v where
+    -- | Represent the type of share to operate on.
+    type ShareT v
+
     -- | Ask a storage server which share numbers related to this capability it
     -- is holding.  This is an unverified result and the storage server could
     -- present incorrect information.  Even if it correctly reports that it
@@ -40,6 +37,15 @@ class Verifiable v where
     -- | Get the location information for shares of this capability.
     getStorageIndex :: v -> StorageIndex
 
+    -- | Deserialize some bytes representing some kind of share to the kind of
+    -- share associated with this capability type, if possible.
+    deserializeShare ::
+        -- | A type witness revealing what type of share to decode to.
+        v ->
+        -- | The bytes of the serialized share.
+        LB.ByteString ->
+        Either (LB.ByteString, ByteOffset, String) (ShareT v)
+
 class (Verifiable v) => Readable r v | r -> v where
     -- | Attentuate the capability.
     getVerifiable :: r -> v
@@ -48,18 +54,22 @@ class (Verifiable v) => Readable r v | r -> v where
     --
     -- Note: might want to split the two functions below out of decodeShare
     --
-    -- shareToCipherText :: r -> [(Int, WhichShare)] -> LB.ByteString
+    -- shareToCipherText :: r -> [(Int, ShareT r)] -> LB.ByteString
     --
     -- cipherTextToPlainText :: r -> LB.ByteString -> LB.ByteString
-    decodeShare :: MonadIO m => r -> [(Int, WhichShare)] -> m (Either DownloadError LB.ByteString)
+    decodeShare :: MonadIO m => r -> [(Int, ShareT v)] -> m (Either DownloadError LB.ByteString)
 
 instance Verifiable CHK.Verifier where
+    type ShareT CHK.Verifier = Tahoe.CHK.Share.Share
+
     getShareNumbers v s = liftIO $ storageServerGetBuckets s (CHK.storageIndex v)
     getStorageIndex CHK.Verifier{storageIndex} = storageIndex
 
     -- CHK is pure, we don't have to ask the StorageServer
     getRequiredTotal CHK.Verifier{required, total} _ = pure $ pure (fromIntegral required, fromIntegral total)
 
+    deserializeShare _ = fmap (\(_, _, c) -> c) . decodeOrFail
+
 {- | A capability which confers the ability to interpret some stored data to
  recover the original plaintext.  Additionally, it can be attentuated to a
  Verifiable.
@@ -67,13 +77,15 @@ instance Verifiable CHK.Verifier where
 instance Readable CHK.Reader CHK.Verifier where
     getVerifiable = CHK.verifier
     decodeShare r shareList = do
-        cipherText <- liftIO $ Tahoe.CHK.decode r (second unWhichCHK <$> shareList)
+        cipherText <- liftIO $ Tahoe.CHK.decode r shareList
         case cipherText of
             Nothing -> pure $ Left ShareDecodingFailed
             Just ct ->
                 pure . Right $ Tahoe.CHK.Encrypt.decrypt (CHK.readKey r) ct
 
 instance Verifiable SDMF.Verifier where
+    type ShareT SDMF.Verifier = SDMF.Share
+
     getShareNumbers v s = liftIO $ storageServerGetBuckets s (SDMF.Keys.unStorageIndex $ SDMF.verifierStorageIndex v)
     getStorageIndex = SDMF.Keys.unStorageIndex . SDMF.verifierStorageIndex
     getRequiredTotal SDMF.Verifier{..} ss = do
@@ -82,10 +94,12 @@ instance Verifiable SDMF.Verifier where
             Left _ -> pure Nothing
             Right (_, _, sh) -> pure $ pure (fromIntegral $ SDMF.shareRequiredShares sh, fromIntegral $ SDMF.shareTotalShares sh)
 
+    deserializeShare _ = fmap (\(_, _, c) -> c) . decodeOrFail
+
 instance Readable SDMF.Reader SDMF.Verifier where
     getVerifiable = SDMF.readerVerifier
     decodeShare r shareList = do
-        cipherText <- Right <$> liftIO (SDMF.decode r (bimap fromIntegral unWhichSDMF <$> shareList))
+        cipherText <- Right <$> liftIO (SDMF.decode r (first fromIntegral <$> shareList))
         case cipherText of
             Left _ -> pure $ Left ShareDecodingFailed
             Right ct -> do
@@ -94,7 +108,7 @@ instance Readable SDMF.Reader SDMF.Verifier where
                 pure . Right $ SDMF.decrypt readKey iv ct
               where
                 readKey = SDMF.readerReadKey r
-                iv = SDMF.shareIV (unWhichSDMF . snd . head $ shareList)
+                iv = SDMF.shareIV (snd . head $ shareList)
 
 print' :: MonadIO m => String -> m ()
 -- print' = liftIO . putStrLn
-- 
GitLab