From a182d6fdee4c94285c2db6c12b0794d816e05f50 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Mon, 25 Nov 2019 11:11:09 -0500
Subject: [PATCH] Avoid completing payment if the voucher was already paid

---
 PaymentServer.cabal                    |  12 +++
 src/PaymentServer/Persistence.hs       |  60 +++++++++++--
 src/PaymentServer/Processors/Stripe.hs |  95 +++++++++-----------
 test/Persistence.hs                    | 118 +++++++++++++++++++++++++
 test/Spec.hs                           |  20 +++++
 5 files changed, 243 insertions(+), 62 deletions(-)
 create mode 100644 test/Persistence.hs
 create mode 100644 test/Spec.hs

diff --git a/PaymentServer.cabal b/PaymentServer.cabal
index 246e201..785e0f4 100644
--- a/PaymentServer.cabal
+++ b/PaymentServer.cabal
@@ -63,6 +63,18 @@ executable PaymentServer-generate-key
                      , PaymentServer
   default-language:    Haskell2010
 
+test-suite PaymentServer-tests
+  type:            exitcode-stdio-1.0
+  hs-source-dirs:  test
+  main-is:         Spec.hs
+  other-modules:   Persistence
+  build-depends:   base
+                 , text
+                 , tasty
+                 , tasty-hunit
+                 , directory
+                 , PaymentServer
+
 source-repository head
   type:     git
   location: https://github.com/privatestorageio/PaymentServer
diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs
index edefd3e..e0759ba 100644
--- a/src/PaymentServer/Persistence.hs
+++ b/src/PaymentServer/Persistence.hs
@@ -4,12 +4,19 @@ module PaymentServer.Persistence
   ( Voucher
   , Fingerprint
   , RedeemError(NotPaid, AlreadyRedeemed)
+  , PaymentError(AlreadyPaid)
   , VoucherDatabase(payForVoucher, redeemVoucher)
   , VoucherDatabaseState(MemoryDB, SQLiteDB)
   , memory
   , getDBConnection
   ) where
 
+import Control.Exception
+  ( Exception
+  , throwIO
+  , catch
+  )
+
 import Data.Text
   ( Text
   , unpack
@@ -36,6 +43,14 @@ import Data.Maybe
 -- voucher itself.
 type Voucher = Text
 
+-- | Reasons that a voucher cannot be paid for.
+data PaymentError =
+  -- | The voucher has already been paid for.
+  AlreadyPaid
+  deriving (Show, Eq)
+
+instance Exception PaymentError
+
 -- | Reasons that a voucher cannot be redeemed.
 data RedeemError =
   -- | The voucher has not been paid for.
@@ -61,7 +76,13 @@ class VoucherDatabase d where
   payForVoucher
     :: d             -- ^ The database in which to record the change
     -> Voucher       -- ^ A voucher which should be considered paid
-    -> IO ()
+    -> IO a          -- ^ An operation which completes the payment.  This is
+                     -- evaluated in the context of a database transaction so
+                     -- that if it fails the voucher is not marked as paid in
+                     -- the database but if it succeeds the database state is
+                     -- not confused by a competing transaction run around the
+                     -- same time.
+    -> IO a
 
   -- | Attempt to redeem a voucher.  If it has not been redeemed before or it
   -- has been redeemed with the same fingerprint, the redemption succeeds.
@@ -89,11 +110,13 @@ data VoucherDatabaseState =
   | SQLiteDB { conn :: Sqlite.Connection }
 
 instance VoucherDatabase VoucherDatabaseState where
-  payForVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher =
+  payForVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher pay = do
+    result <- pay
     modifyIORef paid (Set.insert voucher)
+    return result
 
-  payForVoucher SQLiteDB{ conn = conn } voucher =
-    insertVoucher conn voucher
+  payForVoucher SQLiteDB{ conn = conn } voucher pay =
+    insertVoucher conn voucher pay
 
   redeemVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher fingerprint = do
     unpaid <- Set.notMember voucher <$> readIORef paid
@@ -157,9 +180,32 @@ getVoucherFingerprint dbConn voucher =
   listToMaybe <$> Sqlite.query dbConn "SELECT redeemed.fingerprint FROM vouchers INNER JOIN redeemed ON vouchers.id = redeemed.voucher_id AND vouchers.name = ?" (Sqlite.Only voucher)
 
 -- | Mark the given voucher as paid in the database.
-insertVoucher :: Sqlite.Connection -> Voucher -> IO ()
-insertVoucher dbConn voucher =
-  Sqlite.execute dbConn "INSERT INTO vouchers (name) VALUES (?)" (Sqlite.Only voucher)
+insertVoucher :: Sqlite.Connection -> Voucher -> IO a -> IO a
+insertVoucher dbConn voucher pay =
+  -- Begin an immediate transaction so that it includes the IO.  The first
+  -- thing we do is execute our one and only statement so the transaction is
+  -- immediate anyway but it doesn't hurt to be explicit.
+  Sqlite.withImmediateTransaction dbConn $
+  do
+    -- Vouchers must be unique in this table.  This might fail if someone is
+    -- trying to double-pay for a voucher.  In this case, we won't ever
+    -- finalize the payment.
+    Sqlite.execute dbConn "INSERT INTO vouchers (name) VALUES (?)" (Sqlite.Only voucher)
+      `catch` handleConstraintError
+    -- If we managed to insert the voucher, try to finalize the payment.  If
+    -- this succeeds, the transaction is committed and we expect the payment
+    -- system to actually be moving money around.  If it fails, we expect the
+    -- payment system *not* to move money around and the voucher should not be
+    -- marked as paid.  The transaction will be rolled back so, indeed, it
+    -- won't be marked thus.
+    pay
+
+  where
+    handleConstraintError Sqlite.SQLError { Sqlite.sqlError = Sqlite.ErrorConstraint } =
+      throwIO AlreadyPaid
+    handleConstraintError e =
+      throwIO e
+
 
 -- | Mark the given voucher as having been redeemed (with the given
 -- fingerprint) in the database.
diff --git a/src/PaymentServer/Processors/Stripe.hs b/src/PaymentServer/Processors/Stripe.hs
index a087904..bee378c 100644
--- a/src/PaymentServer/Processors/Stripe.hs
+++ b/src/PaymentServer/Processors/Stripe.hs
@@ -86,10 +86,7 @@ data Acknowledgement = Ok
 instance ToJSON Acknowledgement where
   toJSON Ok = object []
 
-type StripeAPI = WebhookAPI
-               :<|> ChargesAPI
-
-type WebhookAPI = "webhook" :> ReqBody '[JSON] Event :> Post '[JSON] Acknowledgement
+type StripeAPI = ChargesAPI
 
 -- | getVoucher finds the metadata item with the key `"Voucher"` and returns
 -- the corresponding value, or Nothing.
@@ -99,31 +96,7 @@ getVoucher (MetaData (("Voucher", value):xs)) = Just value
 getVoucher (MetaData (x:xs)) = getVoucher (MetaData xs)
 
 stripeServer :: VoucherDatabase d => StripeSecretKey -> d -> Server StripeAPI
-stripeServer key d = webhook d
-                     :<|> charge d key
-
--- | Process charge succeeded events
-webhook :: VoucherDatabase d => d -> Event -> Handler Acknowledgement
-webhook d Event{eventId=Just (EventId eventId), eventType=ChargeSucceededEvent, eventData=(ChargeEvent charge)} =
-  case getVoucher $ chargeMetaData charge of
-    Nothing ->
-      -- TODO: Record the eventId somewhere.  In all cases where we don't
-      -- associate the value of the charge with something in our system, we
-      -- probably need enough information to issue a refund.  We're early
-      -- enough in the system here that refunds are possible and not even
-      -- particularly difficult.
-      return Ok
-    Just v  -> do
-      -- TODO: What if it is a duplicate payment?  payForVoucher should be
-      -- able to indicate error I guess.
-      () <- liftIO $ payForVoucher d v
-      return Ok
-
--- Disregard anything else - but return success so that Stripe doesn't retry.
-webhook d _ =
-  -- TODO: Record the eventId somewhere.
-  return Ok
-
+stripeServer key d = charge d key
 
 -- | Browser facing API that takes token, voucher and a few other information
 -- and calls stripe charges API. If payment succeeds, then the voucher is stored
@@ -151,38 +124,50 @@ instance FromJSON Charges where
 -- and if the Charge is okay, then set the voucher as "paid" in the database.
 charge :: VoucherDatabase d => d -> StripeSecretKey -> Charges -> Handler Acknowledgement
 charge d key (Charges token voucher amount currency) = do
-  let config = StripeConfig (StripeKey key) Nothing
-      tokenId = TokenId token
   currency' <- getCurrency currency
-  result <- liftIO $ stripe config $
-    createCharge (Amount amount) currency'
-      -&- tokenId
-      -&- MetaData [("Voucher", voucher)]
+  result <- liftIO $ payForVoucher d voucher (completeStripeCharge currency')
   case result of
     Right Charge { chargeMetaData = metadata } ->
-      -- verify that we are getting the same metadata that we sent.
-      case metadata of
-        MetaData [("Voucher", v)] ->
-          if v == voucher
-            then
-            do
-              -- TODO Handle payForVoucher errors
-              liftIO $ payForVoucher d voucher
-              return Ok
-            else
-            throwError err500 { errBody = "Voucher code mismatch" }
-        _ -> throwError err400 { errBody = "Voucher code not found" }
+      checkVoucherMetadata metadata
     Left StripeError {} ->
-      let
-        errCode = (read "foo") :: Int
-      in
-        throwError err400
-        { errHTTPCode = errCode
-        , errBody = "Stripe charge didn't succeed"
-        }
+      throwError stripeChargeFailed
     where
       getCurrency :: Text -> Handler Currency
       getCurrency maybeCurrency =
         case readMaybe (unpack currency) of
           Just currency' -> return currency'
-          Nothing -> throwError err400 { errBody = "Invalid currency specified" }
+          Nothing -> throwError unsupportedCurrency
+
+      config = StripeConfig (StripeKey key) Nothing
+      tokenId = TokenId token
+      completeStripeCharge currency' = stripe config $
+        createCharge (Amount amount) currency'
+        -&- tokenId
+        -&- MetaData [("Voucher", voucher)]
+
+      checkVoucherMetadata :: MetaData -> Handler Acknowledgement
+      checkVoucherMetadata metadata =
+        -- verify that we are getting the same metadata that we sent.
+        case metadata of
+          MetaData [("Voucher", v)] ->
+            if v == voucher
+            then return Ok
+            else throwError voucherCodeMismatch
+          _ -> throwError voucherCodeNotFound
+
+      unsupportedCurrency =
+        err400
+        { errBody = "Invalid currency specified"
+        }
+      voucherCodeNotFound =
+        err400
+        { errBody = "Voucher code not found"
+        }
+      voucherCodeMismatch =
+        err500
+        { errBody = "Voucher code mismatch"
+        }
+      stripeChargeFailed =
+        err400
+        { errBody = "Stripe charge didn't succeed"
+        }
diff --git a/test/Persistence.hs b/test/Persistence.hs
new file mode 100644
index 0000000..acc5a72
--- /dev/null
+++ b/test/Persistence.hs
@@ -0,0 +1,118 @@
+{-# LANGUAGE OverloadedStrings #-}
+
+-- | Tests related to PaymentServer.Persistence and the persistence system in
+-- general.
+
+module Persistence
+  ( tests
+  ) where
+
+import qualified Data.Text as Text
+
+import Control.Exception
+  ( Exception
+  , throwIO
+  , catch
+  )
+
+import Test.Tasty
+  ( TestTree
+  , testGroup
+  )
+import Test.Tasty.HUnit
+  ( testCase
+  , assertEqual
+  )
+
+import System.IO
+  ( openTempFile
+  )
+import System.Directory
+  ( getTemporaryDirectory
+  )
+
+import PaymentServer.Persistence
+  ( Voucher
+  , Fingerprint
+  , RedeemError(NotPaid, AlreadyRedeemed)
+  , PaymentError(AlreadyPaid)
+  , VoucherDatabase(payForVoucher, redeemVoucher)
+  , memory
+  , getDBConnection
+  )
+
+data ArbitraryException = ArbitraryException
+  deriving (Show, Eq)
+
+instance Exception ArbitraryException
+
+tests :: TestTree
+tests = testGroup "Persistence"
+  [ memoryDatabaseVoucherPaymentTests
+  , sqlite3DatabaseVoucherPaymentTests
+  ]
+
+voucher = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+fingerprint = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
+
+paySuccessfully = return ()
+failPayment = throwIO ArbitraryException
+
+makeVoucherPaymentTests
+  :: VoucherDatabase d
+  => Text.Text
+  -> IO d
+  -> TestTree
+makeVoucherPaymentTests label makeDatabase =
+  testGroup ("voucher payments (" ++ Text.unpack label ++ ")")
+  [ testCase "not paid for" $ do
+      db <- makeDatabase
+      result <- redeemVoucher db voucher fingerprint
+      assertEqual "redeeming unpaid voucher" (Left NotPaid) result
+  , testCase "paid for" $ do
+      db <- makeDatabase
+      () <- payForVoucher db voucher paySuccessfully
+      result <- redeemVoucher db voucher fingerprint
+      assertEqual "redeeming paid voucher" (Right ()) result
+  , testCase "allowed double redemption" $ do
+      db <- makeDatabase
+      () <- payForVoucher db voucher paySuccessfully
+      let redeem = redeemVoucher db voucher fingerprint
+      first <- redeem
+      second <- redeem
+      assertEqual "redeeming paid voucher" (Right ()) first
+      assertEqual "re-redeeming paid voucher" (Right ()) second
+  , testCase "disallowed double redemption" $ do
+      db <- makeDatabase
+      () <- payForVoucher db voucher paySuccessfully
+      let redeem = redeemVoucher db voucher
+      first <- redeem fingerprint
+      second <- redeem (Text.cons 'a' $ Text.tail fingerprint)
+      assertEqual "redeeming paid voucher" (Right ()) first
+      assertEqual "re-redeeming paid voucher" (Left AlreadyRedeemed) second
+  , testCase "pay with error" $ do
+      db <- makeDatabase
+      payForVoucher db voucher failPayment
+        `catch` assertEqual "failing a payment for a voucher" ArbitraryException
+      result <- redeemVoucher db voucher fingerprint
+      assertEqual "redeeming voucher with failed payment" (Left NotPaid) result
+  , testCase "disallowed double payment" $ do
+      db <- makeDatabase
+      let pay = payForVoucher db voucher paySuccessfully
+      () <- pay
+      pay `catch`  assertEqual "double-paying for a voucher" AlreadyPaid
+      redeemResult <- redeemVoucher db voucher fingerprint
+      assertEqual "redeeming double-paid voucher" (Right ()) redeemResult
+  ]
+
+
+memoryDatabaseVoucherPaymentTests :: TestTree
+memoryDatabaseVoucherPaymentTests = makeVoucherPaymentTests "memory" memory
+
+sqlite3DatabaseVoucherPaymentTests :: TestTree
+sqlite3DatabaseVoucherPaymentTests =
+  makeVoucherPaymentTests "sqlite3" $
+  do
+    tempdir <- getTemporaryDirectory
+    (path, handle) <- openTempFile tempdir "voucher-.db"
+    getDBConnection $ Text.pack path
diff --git a/test/Spec.hs b/test/Spec.hs
new file mode 100644
index 0000000..5821f65
--- /dev/null
+++ b/test/Spec.hs
@@ -0,0 +1,20 @@
+-- | Collect all of the various test groups into a single tree.
+
+module Main
+  ( main
+  ) where
+
+import Test.Tasty
+  ( TestTree
+  , testGroup
+  , defaultMain
+  )
+
+import qualified Persistence
+
+tests :: TestTree
+tests = testGroup "Tests"
+  [ Persistence.tests
+  ]
+
+main = defaultMain tests
-- 
GitLab