From 7b0e07ab474f0cb76578d9636188f0da2a2635b2 Mon Sep 17 00:00:00 2001 From: Jean-Paul Calderone <exarkun@twistedmatrix.com> Date: Mon, 25 Nov 2019 11:42:41 -0500 Subject: [PATCH] improved test suite and memory implementation fixes `catch` is not a good way to assert something is thrown --- src/PaymentServer/Persistence.hs | 18 ++++++++++++++---- test/Persistence.hs | 12 ++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs index e0759ba..b13d796 100644 --- a/src/PaymentServer/Persistence.hs +++ b/src/PaymentServer/Persistence.hs @@ -15,6 +15,7 @@ import Control.Exception ( Exception , throwIO , catch + , try ) import Data.Text @@ -27,6 +28,7 @@ import Data.IORef ( IORef , newIORef , modifyIORef + , atomicModifyIORef' , readIORef ) import qualified Database.SQLite.Simple as Sqlite @@ -110,10 +112,18 @@ data VoucherDatabaseState = | SQLiteDB { conn :: Sqlite.Connection } instance VoucherDatabase VoucherDatabaseState where - payForVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher pay = do - result <- pay - modifyIORef paid (Set.insert voucher) - return result + payForVoucher MemoryDB{ paid = paidRef, redeemed = redeemed } voucher pay = do + -- Surely far from ideal... + paid <- readIORef paidRef + if Set.member voucher paid + -- Avoid processing the payment if the voucher is already paid. + then throwIO AlreadyPaid + else + do + result <- pay + -- Only modify the paid set if the payment succeeds. + modifyIORef paidRef (Set.insert voucher) + return result payForVoucher SQLiteDB{ conn = conn } voucher pay = insertVoucher conn voucher pay diff --git a/test/Persistence.hs b/test/Persistence.hs index acc5a72..9ecb70e 100644 --- a/test/Persistence.hs +++ b/test/Persistence.hs @@ -12,7 +12,7 @@ import qualified Data.Text as Text import Control.Exception ( Exception , throwIO - , catch + , try ) import Test.Tasty @@ -55,7 +55,10 @@ tests = testGroup "Persistence" voucher = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" fingerprint = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" +paySuccessfully :: IO () paySuccessfully = return () + +failPayment :: IO () failPayment = throwIO ArbitraryException makeVoucherPaymentTests @@ -92,15 +95,16 @@ makeVoucherPaymentTests label makeDatabase = assertEqual "re-redeeming paid voucher" (Left AlreadyRedeemed) second , testCase "pay with error" $ do db <- makeDatabase - payForVoucher db voucher failPayment - `catch` assertEqual "failing a payment for a voucher" ArbitraryException + payResult <- try $ payForVoucher db voucher failPayment + assertEqual "failing a payment for a voucher" (Left ArbitraryException) payResult result <- redeemVoucher db voucher fingerprint assertEqual "redeeming voucher with failed payment" (Left NotPaid) result , testCase "disallowed double payment" $ do db <- makeDatabase let pay = payForVoucher db voucher paySuccessfully () <- pay - pay `catch` assertEqual "double-paying for a voucher" AlreadyPaid + payResult <- try pay + assertEqual "double-paying for a voucher" (Left AlreadyPaid) payResult redeemResult <- redeemVoucher db voucher fingerprint assertEqual "redeeming double-paid voucher" (Right ()) redeemResult ] -- GitLab