diff --git a/PaymentServer.cabal b/PaymentServer.cabal index 0c5286c1282dc7f33ba50e06852314b5363a130f..dd51a39404009e27bf993c5eac9c2842374ef8e2 100644 --- a/PaymentServer.cabal +++ b/PaymentServer.cabal @@ -73,6 +73,7 @@ test-suite PaymentServer-tests , tasty , tasty-hunit , directory + , async , PaymentServer default-language: Haskell2010 diff --git a/src/PaymentServer/Main.hs b/src/PaymentServer/Main.hs index 6655c8aaba550bb36e275854ba6fcbdc83787842..f32512abaaeee4eaa4ceeba6623de581aed2bf16 100644 --- a/src/PaymentServer/Main.hs +++ b/src/PaymentServer/Main.hs @@ -64,7 +64,7 @@ import qualified Web.Stripe.Client as Stripe import PaymentServer.Persistence ( memory - , getDBConnection + , sqlite ) import PaymentServer.Issuer ( trivialIssue @@ -281,7 +281,7 @@ getApp config = getDatabase ServerConfig{ database, databasePath } = case (database, databasePath) of (Memory, Nothing) -> Right memory - (SQLite3, Just path) -> Right (getDBConnection path) + (SQLite3, Just path) -> Right (sqlite path) _ -> Left "invalid options" stripeConfig ServerConfig diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs index 324a8880e5ca8f6f2af12ee5e8435742f2c23acc..fbb9720caf7392115ba36dff1e39249424f01a3d 100644 --- a/src/PaymentServer/Persistence.hs +++ b/src/PaymentServer/Persistence.hs @@ -1,5 +1,6 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeSynonymInstances #-} + module PaymentServer.Persistence ( Voucher , Fingerprint @@ -8,14 +9,14 @@ module PaymentServer.Persistence , VoucherDatabase(payForVoucher, redeemVoucher) , VoucherDatabaseState(MemoryDB, SQLiteDB) , memory - , getDBConnection + , sqlite ) where import Control.Exception ( Exception , throwIO , catch - , try + , bracket ) import Data.Text @@ -28,7 +29,6 @@ import Data.IORef ( IORef , newIORef , modifyIORef - , atomicModifyIORef' , readIORef ) import qualified Database.SQLite.Simple as Sqlite @@ -111,7 +111,7 @@ data VoucherDatabaseState = -- redemption. , redeemed :: IORef (Map.Map Voucher Fingerprint) } - | SQLiteDB { conn :: Sqlite.Connection } + | SQLiteDB { connect :: IO Sqlite.Connection } instance VoucherDatabase VoucherDatabaseState where payForVoucher MemoryDB{ paid = paidRef, redeemed = redeemed } voucher pay = do @@ -127,7 +127,8 @@ instance VoucherDatabase VoucherDatabaseState where modifyIORef paidRef (Set.insert voucher) return result - payForVoucher SQLiteDB{ conn = conn } voucher pay = + payForVoucher SQLiteDB{ connect = connect } voucher pay = + bracket connect Sqlite.close $ \conn -> insertVoucher conn voucher pay redeemVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher fingerprint = do @@ -136,11 +137,14 @@ instance VoucherDatabase VoucherDatabaseState where let insertFn = (modifyIORef redeemed .) . Map.insert redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn - redeemVoucher SQLiteDB { conn = conn } voucher fingerprint = Sqlite.withExclusiveTransaction conn $ do - unpaid <- isVoucherUnpaid conn voucher - existingFingerprint <- getVoucherFingerprint conn voucher - let insertFn = insertVoucherAndFingerprint conn - redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn + redeemVoucher SQLiteDB { connect = connect } voucher fingerprint = + bracket connect Sqlite.close $ \conn -> + Sqlite.withExclusiveTransaction conn $ + do + unpaid <- isVoucherUnpaid conn voucher + existingFingerprint <- getVoucherFingerprint conn voucher + let insertFn = insertVoucherAndFingerprint conn + redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn -- | Allow a voucher to be redeemed if it has been paid for and not redeemed -- before or redeemed with the same fingerprint. @@ -225,12 +229,18 @@ insertVoucherAndFingerprint :: Sqlite.Connection -> Voucher -> Fingerprint -> IO insertVoucherAndFingerprint dbConn voucher fingerprint = Sqlite.execute dbConn "INSERT INTO redeemed (voucher_id, fingerprint) VALUES ((SELECT id FROM vouchers WHERE name = ?), ?)" (voucher, fingerprint) --- | Create and open a database with a given `name` and create the `voucher` --- table and `redeemed` table with the provided schema. -getDBConnection :: Text -> IO VoucherDatabaseState -getDBConnection path = do - dbConn <- Sqlite.open (unpack path) - Sqlite.execute_ dbConn "PRAGMA foreign_keys = ON" - Sqlite.execute_ dbConn "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" - Sqlite.execute_ dbConn "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" - return $ SQLiteDB dbConn +-- | Open and create (if necessary) a SQLite3 database which can persistently +-- store all of the relevant information about voucher state. +sqlite :: Text -> IO VoucherDatabaseState +sqlite path = + let + connect :: IO Sqlite.Connection + connect = do + dbConn <- Sqlite.open (unpack path) + let exec = Sqlite.execute_ dbConn + exec "PRAGMA foreign_keys = ON" + exec "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" + exec "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" + return dbConn + in + return . SQLiteDB $ connect diff --git a/src/PaymentServer/Processors/Stripe.hs b/src/PaymentServer/Processors/Stripe.hs index 96d01b23357937ec29d5dd4aa0e121086a8886b8..8506c6b697b2981627fee02edefceca617f2653b 100644 --- a/src/PaymentServer/Processors/Stripe.hs +++ b/src/PaymentServer/Processors/Stripe.hs @@ -73,8 +73,7 @@ import Web.Stripe.Charge , TokenId(TokenId) ) import Web.Stripe.Client - ( StripeConfig(StripeConfig) - , StripeKey(StripeKey) + ( StripeConfig ) import Web.Stripe ( stripe diff --git a/src/PaymentServer/Ristretto.hs b/src/PaymentServer/Ristretto.hs index a2968331eb42368772901ade53344f447ff7b84b..7c065206f5d93963bd410fbd588c17c41e726942 100644 --- a/src/PaymentServer/Ristretto.hs +++ b/src/PaymentServer/Ristretto.hs @@ -9,7 +9,6 @@ module PaymentServer.Ristretto import Control.Exception ( bracket - , assert ) import System.IO.Unsafe ( unsafePerformIO @@ -26,7 +25,6 @@ import Foreign.Ptr ) import Foreign.C.String ( CString - , withCString , newCString , peekCString ) diff --git a/test/Persistence.hs b/test/Persistence.hs index f425e2fd06738ca71739e91fbf8870038f8bd6b9..a66196f2433aed44fdf9bea7980c0337735e6ebc 100644 --- a/test/Persistence.hs +++ b/test/Persistence.hs @@ -15,6 +15,11 @@ import Control.Exception , try ) +import Control.Concurrent.Async + ( withAsync + , waitBoth + ) + import Test.Tasty ( TestTree , testGroup @@ -38,7 +43,7 @@ import PaymentServer.Persistence , PaymentError(AlreadyPaid) , VoucherDatabase(payForVoucher, redeemVoucher) , memory - , getDBConnection + , sqlite ) data ArbitraryException = ArbitraryException @@ -54,6 +59,7 @@ tests = testGroup "Persistence" -- Some dummy values that should be replaced by the use of QuickCheck. voucher = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" +anotherVoucher = "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz" fingerprint = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" -- Mock a successful payment. @@ -68,55 +74,96 @@ failPayment = throwIO ArbitraryException makeVoucherPaymentTests :: VoucherDatabase d => Text.Text -- ^ A distinctive identifier for this group's label. - -> IO d -- ^ An operation that creates a new, empty voucher - -- database. + -> IO (IO d) -- ^ An operation that creates a new, empty voucher + -- database and results in an operation that creates + -- a new connection to that database. -> TestTree makeVoucherPaymentTests label makeDatabase = testGroup ("voucher payments (" ++ Text.unpack label ++ ")") [ testCase "not paid for" $ do - db <- makeDatabase - result <- redeemVoucher db voucher fingerprint + connect <- makeDatabase + conn <- connect + result <- redeemVoucher conn voucher fingerprint assertEqual "redeeming unpaid voucher" (Left NotPaid) result , testCase "paid for" $ do - db <- makeDatabase - () <- payForVoucher db voucher paySuccessfully - result <- redeemVoucher db voucher fingerprint + connect <- makeDatabase + conn <- connect + () <- payForVoucher conn voucher paySuccessfully + result <- redeemVoucher conn voucher fingerprint assertEqual "redeeming paid voucher" (Right ()) result , testCase "allowed double redemption" $ do - db <- makeDatabase - () <- payForVoucher db voucher paySuccessfully - let redeem = redeemVoucher db voucher fingerprint + connect <- makeDatabase + conn <- connect + () <- payForVoucher conn voucher paySuccessfully + let redeem = redeemVoucher conn voucher fingerprint first <- redeem second <- redeem assertEqual "redeeming paid voucher" (Right ()) first assertEqual "re-redeeming paid voucher" (Right ()) second , testCase "disallowed double redemption" $ do - db <- makeDatabase - () <- payForVoucher db voucher paySuccessfully - let redeem = redeemVoucher db voucher + connect <- makeDatabase + conn <- connect + () <- payForVoucher conn voucher paySuccessfully + let redeem = redeemVoucher conn voucher first <- redeem fingerprint second <- redeem (Text.cons 'a' $ Text.tail fingerprint) assertEqual "redeeming paid voucher" (Right ()) first assertEqual "re-redeeming paid voucher" (Left AlreadyRedeemed) second , testCase "pay with exception" $ do - db <- makeDatabase - payResult <- try $ payForVoucher db voucher failPayment + connect <- makeDatabase + conn <- connect + payResult <- try $ payForVoucher conn voucher failPayment assertEqual "failing a payment for a voucher" (Left ArbitraryException) payResult - result <- redeemVoucher db voucher fingerprint + result <- redeemVoucher conn voucher fingerprint assertEqual "redeeming voucher with failed payment" (Left NotPaid) result , testCase "disallowed double payment" $ do - db <- makeDatabase - let pay = payForVoucher db voucher paySuccessfully + connect <- makeDatabase + conn <- connect + let pay = payForVoucher conn voucher paySuccessfully () <- pay payResult <- try pay assertEqual "double-paying for a voucher" (Left AlreadyPaid) payResult - redeemResult <- redeemVoucher db voucher fingerprint + redeemResult <- redeemVoucher conn voucher fingerprint assertEqual "redeeming double-paid voucher" (Right ()) redeemResult + , testCase "concurrent payment" $ do + connect <- makeDatabase + connA <- connect + connB <- connect + + let payment = payForVoucher connA voucher paySuccessfully + let anotherPayment = payForVoucher connB anotherVoucher paySuccessfully + + result <- withAsync payment $ \p1 -> do + withAsync anotherPayment $ \p2 -> do + waitBoth p1 p2 + + assertEqual "Both payments should succeed" ((), ()) result + , testCase "concurrent redemption" $ do + connect <- makeDatabase + connA <- connect + connB <- connect + -- It doesn't matter which connection pays for the vouchers. They + -- payments are concurrent and the connections are to the same database. + () <- payForVoucher connA voucher paySuccessfully + () <- payForVoucher connA anotherVoucher paySuccessfully + + -- It does matter which connection is used to redeem the voucher. A + -- connection can only do one thing at a time. + let redeem = redeemVoucher connA voucher fingerprint + let anotherRedeem = redeemVoucher connB anotherVoucher fingerprint + + result <- withAsync redeem $ \r1 -> do + withAsync anotherRedeem $ \r2 -> do + waitBoth r1 r2 + + assertEqual "Both redemptions should succeed" (Right (), Right ()) result ] -- | Instantiate the persistence tests for the memory backend. memoryDatabaseVoucherPaymentTests :: TestTree -memoryDatabaseVoucherPaymentTests = makeVoucherPaymentTests "memory" memory +memoryDatabaseVoucherPaymentTests = makeVoucherPaymentTests "memory" $ do + db <- memory + return $ return db -- | Instantiate the persistence tests for the sqlite3 backend. sqlite3DatabaseVoucherPaymentTests :: TestTree @@ -125,4 +172,4 @@ sqlite3DatabaseVoucherPaymentTests = do tempdir <- getTemporaryDirectory (path, handle) <- openTempFile tempdir "voucher-.db" - getDBConnection $ Text.pack path + return . sqlite . Text.pack $ path