diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs index e0759ba9ced296fd254ee33109adf1def404ecca..b13d79668f49bb31f8da165e05b199a92df24a00 100644 --- a/src/PaymentServer/Persistence.hs +++ b/src/PaymentServer/Persistence.hs @@ -15,6 +15,7 @@ import Control.Exception ( Exception , throwIO , catch + , try ) import Data.Text @@ -27,6 +28,7 @@ import Data.IORef ( IORef , newIORef , modifyIORef + , atomicModifyIORef' , readIORef ) import qualified Database.SQLite.Simple as Sqlite @@ -110,10 +112,18 @@ data VoucherDatabaseState = | SQLiteDB { conn :: Sqlite.Connection } instance VoucherDatabase VoucherDatabaseState where - payForVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher pay = do - result <- pay - modifyIORef paid (Set.insert voucher) - return result + payForVoucher MemoryDB{ paid = paidRef, redeemed = redeemed } voucher pay = do + -- Surely far from ideal... + paid <- readIORef paidRef + if Set.member voucher paid + -- Avoid processing the payment if the voucher is already paid. + then throwIO AlreadyPaid + else + do + result <- pay + -- Only modify the paid set if the payment succeeds. + modifyIORef paidRef (Set.insert voucher) + return result payForVoucher SQLiteDB{ conn = conn } voucher pay = insertVoucher conn voucher pay diff --git a/test/Persistence.hs b/test/Persistence.hs index acc5a722d8fc573c152742a51f80476c178c882d..9ecb70ecb079494dcb4397868e866b9b6dec8a48 100644 --- a/test/Persistence.hs +++ b/test/Persistence.hs @@ -12,7 +12,7 @@ import qualified Data.Text as Text import Control.Exception ( Exception , throwIO - , catch + , try ) import Test.Tasty @@ -55,7 +55,10 @@ tests = testGroup "Persistence" voucher = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" fingerprint = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" +paySuccessfully :: IO () paySuccessfully = return () + +failPayment :: IO () failPayment = throwIO ArbitraryException makeVoucherPaymentTests @@ -92,15 +95,16 @@ makeVoucherPaymentTests label makeDatabase = assertEqual "re-redeeming paid voucher" (Left AlreadyRedeemed) second , testCase "pay with error" $ do db <- makeDatabase - payForVoucher db voucher failPayment - `catch` assertEqual "failing a payment for a voucher" ArbitraryException + payResult <- try $ payForVoucher db voucher failPayment + assertEqual "failing a payment for a voucher" (Left ArbitraryException) payResult result <- redeemVoucher db voucher fingerprint assertEqual "redeeming voucher with failed payment" (Left NotPaid) result , testCase "disallowed double payment" $ do db <- makeDatabase let pay = payForVoucher db voucher paySuccessfully () <- pay - pay `catch` assertEqual "double-paying for a voucher" AlreadyPaid + payResult <- try pay + assertEqual "double-paying for a voucher" (Left AlreadyPaid) payResult redeemResult <- redeemVoucher db voucher fingerprint assertEqual "redeeming double-paid voucher" (Right ()) redeemResult ]