diff --git a/src/PaymentServer/Main.hs b/src/PaymentServer/Main.hs index 6655c8aaba550bb36e275854ba6fcbdc83787842..f32512abaaeee4eaa4ceeba6623de581aed2bf16 100644 --- a/src/PaymentServer/Main.hs +++ b/src/PaymentServer/Main.hs @@ -64,7 +64,7 @@ import qualified Web.Stripe.Client as Stripe import PaymentServer.Persistence ( memory - , getDBConnection + , sqlite ) import PaymentServer.Issuer ( trivialIssue @@ -281,7 +281,7 @@ getApp config = getDatabase ServerConfig{ database, databasePath } = case (database, databasePath) of (Memory, Nothing) -> Right memory - (SQLite3, Just path) -> Right (getDBConnection path) + (SQLite3, Just path) -> Right (sqlite path) _ -> Left "invalid options" stripeConfig ServerConfig diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs index 324a8880e5ca8f6f2af12ee5e8435742f2c23acc..fbb9720caf7392115ba36dff1e39249424f01a3d 100644 --- a/src/PaymentServer/Persistence.hs +++ b/src/PaymentServer/Persistence.hs @@ -1,5 +1,6 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeSynonymInstances #-} + module PaymentServer.Persistence ( Voucher , Fingerprint @@ -8,14 +9,14 @@ module PaymentServer.Persistence , VoucherDatabase(payForVoucher, redeemVoucher) , VoucherDatabaseState(MemoryDB, SQLiteDB) , memory - , getDBConnection + , sqlite ) where import Control.Exception ( Exception , throwIO , catch - , try + , bracket ) import Data.Text @@ -28,7 +29,6 @@ import Data.IORef ( IORef , newIORef , modifyIORef - , atomicModifyIORef' , readIORef ) import qualified Database.SQLite.Simple as Sqlite @@ -111,7 +111,7 @@ data VoucherDatabaseState = -- redemption. , redeemed :: IORef (Map.Map Voucher Fingerprint) } - | SQLiteDB { conn :: Sqlite.Connection } + | SQLiteDB { connect :: IO Sqlite.Connection } instance VoucherDatabase VoucherDatabaseState where payForVoucher MemoryDB{ paid = paidRef, redeemed = redeemed } voucher pay = do @@ -127,7 +127,8 @@ instance VoucherDatabase VoucherDatabaseState where modifyIORef paidRef (Set.insert voucher) return result - payForVoucher SQLiteDB{ conn = conn } voucher pay = + payForVoucher SQLiteDB{ connect = connect } voucher pay = + bracket connect Sqlite.close $ \conn -> insertVoucher conn voucher pay redeemVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher fingerprint = do @@ -136,11 +137,14 @@ instance VoucherDatabase VoucherDatabaseState where let insertFn = (modifyIORef redeemed .) . Map.insert redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn - redeemVoucher SQLiteDB { conn = conn } voucher fingerprint = Sqlite.withExclusiveTransaction conn $ do - unpaid <- isVoucherUnpaid conn voucher - existingFingerprint <- getVoucherFingerprint conn voucher - let insertFn = insertVoucherAndFingerprint conn - redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn + redeemVoucher SQLiteDB { connect = connect } voucher fingerprint = + bracket connect Sqlite.close $ \conn -> + Sqlite.withExclusiveTransaction conn $ + do + unpaid <- isVoucherUnpaid conn voucher + existingFingerprint <- getVoucherFingerprint conn voucher + let insertFn = insertVoucherAndFingerprint conn + redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn -- | Allow a voucher to be redeemed if it has been paid for and not redeemed -- before or redeemed with the same fingerprint. @@ -225,12 +229,18 @@ insertVoucherAndFingerprint :: Sqlite.Connection -> Voucher -> Fingerprint -> IO insertVoucherAndFingerprint dbConn voucher fingerprint = Sqlite.execute dbConn "INSERT INTO redeemed (voucher_id, fingerprint) VALUES ((SELECT id FROM vouchers WHERE name = ?), ?)" (voucher, fingerprint) --- | Create and open a database with a given `name` and create the `voucher` --- table and `redeemed` table with the provided schema. -getDBConnection :: Text -> IO VoucherDatabaseState -getDBConnection path = do - dbConn <- Sqlite.open (unpack path) - Sqlite.execute_ dbConn "PRAGMA foreign_keys = ON" - Sqlite.execute_ dbConn "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" - Sqlite.execute_ dbConn "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" - return $ SQLiteDB dbConn +-- | Open and create (if necessary) a SQLite3 database which can persistently +-- store all of the relevant information about voucher state. +sqlite :: Text -> IO VoucherDatabaseState +sqlite path = + let + connect :: IO Sqlite.Connection + connect = do + dbConn <- Sqlite.open (unpack path) + let exec = Sqlite.execute_ dbConn + exec "PRAGMA foreign_keys = ON" + exec "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" + exec "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" + return dbConn + in + return . SQLiteDB $ connect diff --git a/src/PaymentServer/Processors/Stripe.hs b/src/PaymentServer/Processors/Stripe.hs index 96d01b23357937ec29d5dd4aa0e121086a8886b8..8506c6b697b2981627fee02edefceca617f2653b 100644 --- a/src/PaymentServer/Processors/Stripe.hs +++ b/src/PaymentServer/Processors/Stripe.hs @@ -73,8 +73,7 @@ import Web.Stripe.Charge , TokenId(TokenId) ) import Web.Stripe.Client - ( StripeConfig(StripeConfig) - , StripeKey(StripeKey) + ( StripeConfig ) import Web.Stripe ( stripe diff --git a/src/PaymentServer/Ristretto.hs b/src/PaymentServer/Ristretto.hs index a2968331eb42368772901ade53344f447ff7b84b..7c065206f5d93963bd410fbd588c17c41e726942 100644 --- a/src/PaymentServer/Ristretto.hs +++ b/src/PaymentServer/Ristretto.hs @@ -9,7 +9,6 @@ module PaymentServer.Ristretto import Control.Exception ( bracket - , assert ) import System.IO.Unsafe ( unsafePerformIO @@ -26,7 +25,6 @@ import Foreign.Ptr ) import Foreign.C.String ( CString - , withCString , newCString , peekCString ) diff --git a/test/Persistence.hs b/test/Persistence.hs index 4d435e5b058847d2c39b3e1a59fd767dcf34f4e4..caf789fe9a57972f5a7081ccce3d3db4a44796b5 100644 --- a/test/Persistence.hs +++ b/test/Persistence.hs @@ -43,7 +43,7 @@ import PaymentServer.Persistence , PaymentError(AlreadyPaid) , VoucherDatabase(payForVoucher, redeemVoucher) , memory - , getDBConnection + , sqlite ) data ArbitraryException = ArbitraryException @@ -159,4 +159,4 @@ sqlite3DatabaseVoucherPaymentTests = do tempdir <- getTemporaryDirectory (path, handle) <- openTempFile tempdir "voucher-.db" - return . getDBConnection . Text.pack $ path + return . sqlite . Text.pack $ path