diff --git a/PaymentServer.cabal b/PaymentServer.cabal index 134ce45a529fc52c21f8806b3bc2299dd34f5ea6..c80b5951c0f6ad2e78b087a9cff2c9441d4ebaa1 100644 --- a/PaymentServer.cabal +++ b/PaymentServer.cabal @@ -86,6 +86,7 @@ test-suite PaymentServer-tests , wai-extra , servant-server , prometheus-client + , stripe-core , PaymentServer default-language: Haskell2010 diff --git a/nix/PaymentServer.nix b/nix/PaymentServer.nix index 861ed69234842b225433f4a88831db097ca551a7..a8df27a4643c47274a384d63fc4d6c9d8e99a530 100644 --- a/nix/PaymentServer.nix +++ b/nix/PaymentServer.nix @@ -115,6 +115,7 @@ in { system, compiler, flags, pkgs, hsPkgs, pkgconfPkgs, ... }: (hsPkgs."wai-extra" or (buildDepError "wai-extra")) (hsPkgs."servant-server" or (buildDepError "servant-server")) (hsPkgs."prometheus-client" or (buildDepError "prometheus-client")) + (hsPkgs."stripe-core" or (buildDepError "stripe-core")) (hsPkgs."PaymentServer" or (buildDepError "PaymentServer")) ]; }; diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs index f4ae39fc91d8ce877fa0b68f9bfd057d528c99bf..8e0b25eca7370f1adc1efac142065510a1befb57 100644 --- a/src/PaymentServer/Persistence.hs +++ b/src/PaymentServer/Persistence.hs @@ -7,10 +7,14 @@ module PaymentServer.Persistence , Fingerprint , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint, DatabaseUnavailable) , PaymentError(AlreadyPaid, PaymentFailed) + , ProcessorResult , VoucherDatabase(payForVoucher, redeemVoucher, redeemVoucherWithCounter) , VoucherDatabaseState(MemoryDB, SQLiteDB) , memory , sqlite + , upgradeSchema + , latestVersion + , readVersion ) where import Control.Exception @@ -47,6 +51,9 @@ import Data.Maybe import Web.Stripe.Error ( StripeError ) +import Web.Stripe.Types + ( ChargeId(ChargeId) + ) -- | A voucher is a unique identifier which can be associated with a payment. -- A paid voucher can be redeemed for ZKAPs which can themselves be exchanged @@ -108,19 +115,28 @@ type Fingerprint = Text -- allowed). type RedemptionKey = (Voucher, Integer) +-- | The result of completing payment processing. This is either an error +-- indicating that the payment has *not* been completed (funds will not move) +-- or a payment processor-specific identifier for the completed transaction +-- (funds will move). +type ProcessorResult = Either PaymentError ChargeId + -- | A VoucherDatabase provides persistence for state related to vouchers. class VoucherDatabase d where -- | Change the state of the given voucher to indicate that it has been paid. payForVoucher - :: d -- ^ The database in which to record the change - -> Voucher -- ^ A voucher which should be considered paid - -> IO a -- ^ An operation which completes the payment. This is - -- evaluated in the context of a database transaction so - -- that if it fails the voucher is not marked as paid in - -- the database but if it succeeds the database state is - -- not confused by a competing transaction run around the - -- same time. - -> IO a + :: d + -- ^ The database in which to record the change + -> Voucher + -- ^ A voucher which should be considered paid + -> IO ProcessorResult + -- ^ An operation which completes the payment. This is evaluated in the + -- context of a database transaction so that if it fails the voucher is + -- not marked as paid in the database but if it succeeds the database + -- state is not confused by a competing transaction run around the same + -- time. + -> IO ProcessorResult + -- ^ The result of the attempt to complete payment processing. -- | Attempt to redeem a voucher. If it has not been redeemed before or it -- has been redeemed with the same fingerprint, the redemption succeeds. @@ -177,8 +193,12 @@ instance VoucherDatabase VoucherDatabaseState where else do result <- pay - -- Only modify the paid set if the payment succeeds. - modifyIORef paidRef (Set.insert voucher) + case result of + Right chargeId -> + -- Only modify the paid set if the payment succeeds. + modifyIORef paidRef (Set.insert voucher) + + Left _ -> return () return result payForVoucher SQLiteDB{ connect = connect } voucher pay = @@ -326,32 +346,34 @@ getVoucherFingerprint dbConn (voucher, counter) = listToMaybe <$> Sqlite.query dbConn sql ((voucher :: Text), (counter :: Integer)) -- | Mark the given voucher as paid in the database. -insertVoucher :: Sqlite.Connection -> Voucher -> IO a -> IO a +insertVoucher :: Sqlite.Connection -> Voucher -> IO ProcessorResult -> IO ProcessorResult insertVoucher dbConn voucher pay = - -- Begin an immediate transaction so that it includes the IO. The first - -- thing we do is execute our one and only statement so the transaction is - -- immediate anyway but it doesn't hurt to be explicit. - Sqlite.withImmediateTransaction dbConn $ + -- Begin an immediate transaction so that it includes the IO. The + -- transaction is immediate so that we can first check that the voucher is + -- unique and then proceed to do the IO without worrying that another + -- request will concurrently begin operating on the same voucher. + Sqlite.withExclusiveTransaction dbConn $ do - -- Vouchers must be unique in this table. This might fail if someone is - -- trying to double-pay for a voucher. In this case, we won't ever - -- finalize the payment. - Sqlite.execute dbConn "INSERT INTO vouchers (name) VALUES (?)" (Sqlite.Only voucher) - `catch` handleConstraintError - -- If we managed to insert the voucher, try to finalize the payment. If - -- this succeeds, the transaction is committed and we expect the payment - -- system to actually be moving money around. If it fails, we expect the - -- payment system *not* to move money around and the voucher should not be - -- marked as paid. The transaction will be rolled back so, indeed, it - -- won't be marked thus. - pay - - where - handleConstraintError Sqlite.SQLError { Sqlite.sqlError = Sqlite.ErrorConstraint } = - throwIO AlreadyPaid - handleConstraintError e = - throwIO e - + -- Vouchers must be unique in this table. Check to see if this one + -- already exists. + rows <- Sqlite.query dbConn "SELECT 1 FROM vouchers WHERE name = ?" (Sqlite.Only voucher) :: IO [Sqlite.Only Int] + if length rows /= 0 + then throwIO AlreadyPaid + else + do + -- If the voucher isn't present yet, try to finalize the payment. If + -- this succeeds, the transaction is committed and we expect the + -- payment system to actually be moving money around. If it fails, we + -- expect the payment system *not* to move money around and the + -- voucher should not be marked as paid. The transaction will be + -- rolled back so, indeed, it won't be marked thus. + result <- pay + case result of + Right (ChargeId chargeId) -> do + Sqlite.execute dbConn "INSERT INTO vouchers (name, charge_id) VALUES (?, ?)" (voucher, chargeId) + return result + Left err -> + return result -- | Mark the given voucher as having been redeemed (with the given -- fingerprint) in the database. @@ -373,9 +395,7 @@ sqlite path = let exec = Sqlite.execute_ dbConn exec "PRAGMA busy_timeout = 60000" exec "PRAGMA foreign_keys = ON" - Sqlite.withExclusiveTransaction dbConn $ do - exec "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" - exec "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, counter INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" + Sqlite.withExclusiveTransaction dbConn (upgradeSchema latestVersion dbConn) return dbConn connect :: IO Sqlite.Connection @@ -383,3 +403,103 @@ sqlite path = bracketOnError (Sqlite.open . unpack $ path) Sqlite.close initialize in return . SQLiteDB $ connect + + +-- | updateVersions gives the SQL statements necessary to initialize the +-- database schema at each version that has ever existed. The first element +-- is a list of SQL statements that modify an empty schema to create the first +-- version. The second element is a list of SQL statements that modify the +-- first version to create the second version. etc. +updateVersions :: [[Sqlite.Query]] +updateVersions = + [ [ "CREATE TABLE vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" + , "CREATE TABLE redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, counter INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" + ] + , [ "CREATE TABLE version AS SELECT 2 AS version" + , "ALTER TABLE vouchers ADD COLUMN charge_id" + ] + ] + +latestVersion :: Int +latestVersion = length updateVersions + +-- | readVersion reads the schema version out of a database using the given +-- query function. Since not all versions of the schema had an explicit +-- version marker, it digs around a little bit to find the answer. +readVersion :: Sqlite.Connection -> IO (Either UpgradeError Int) +readVersion conn = do + versionExists <- doesTableExist "version" + if versionExists + -- If there is a version table then it knows the answer. + then + do + versions <- Sqlite.query_ conn "SELECT version FROM version" :: IO [Sqlite.Only Int] + case versions of + [] -> return $ Left VersionMissing + (Sqlite.Only v):[] -> return $ Right v + vs -> return $ Left $ ExcessVersions (map Sqlite.fromOnly vs) + else + do + vouchersExists <- doesTableExist "vouchers" + if vouchersExists + -- If there is a vouchers table then we have version 1 + then return $ Right 1 + -- Otherwise we have version 0 + else return $ Right 0 + + where + doesTableExist :: Text -> IO Bool + doesTableExist name = do + (Sqlite.Only count):[] <- + Sqlite.query + conn + "SELECT COUNT(*) FROM [sqlite_master] WHERE [type] = 'table' AND [name] = ?" + (Sqlite.Only name) :: IO [Sqlite.Only Int] + return $ count > 0 + + + +-- | upgradeSchema determines what schema changes need to be applied to the +-- database associated with a connection to make the schema match the +-- requested version. +upgradeSchema :: Int -> Sqlite.Connection -> IO (Either UpgradeError ()) +upgradeSchema targetVersion conn = do + errOrCurrentVersion <- readVersion conn + case errOrCurrentVersion of + Left err -> return $ Left err + Right currentVersion -> perhapsUpgrade targetVersion currentVersion + + where + perhapsUpgrade :: Int -> Int -> IO (Either UpgradeError ()) + perhapsUpgrade targetVersion currentVersion = + case compareVersion targetVersion currentVersion of + Lesser -> return $ Left DatabaseSchemaTooNew + Equal -> return $ Right () + Greater -> runUpgrades currentVersion targetVersion + + runUpgrades :: Int -> Int -> IO (Either UpgradeError ()) + runUpgrades currentVersion targetVersion = + let + upgrades :: [[Sqlite.Query]] + upgrades = drop currentVersion $ take targetVersion updateVersions + + oneStep :: [Sqlite.Query] -> IO [()] + oneStep = mapM $ Sqlite.execute_ conn + in do + mapM oneStep upgrades + return $ Right () + + +data UpgradeError + = VersionMissing + | ExcessVersions [Int] + | DatabaseSchemaTooNew + deriving (Show, Eq) + +data ComparisonResult = Lesser | Equal | Greater + +compareVersion :: Int -> Int -> ComparisonResult +compareVersion a b + | a < b = Lesser + | a == b = Equal + | otherwise = Greater diff --git a/src/PaymentServer/Processors/Stripe.hs b/src/PaymentServer/Processors/Stripe.hs index c279b4bbaf7f6601f55835dceb23a5b95f2de647..1696eaf619feb53b5ca46b246b13434a44572280 100644 --- a/src/PaymentServer/Processors/Stripe.hs +++ b/src/PaymentServer/Processors/Stripe.hs @@ -1,6 +1,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE NamedFieldPuns #-} module PaymentServer.Processors.Stripe ( StripeAPI @@ -47,7 +48,7 @@ import Web.Stripe.Error , StripeErrorType(InvalidRequest, APIError, ConnectionFailure, CardError) ) import Web.Stripe.Types - ( Charge(Charge, chargeMetaData) + ( Charge(Charge, chargeId) , MetaData(MetaData) , Currency ) @@ -70,6 +71,7 @@ import PaymentServer.Persistence ( Voucher , VoucherDatabase(payForVoucher) , PaymentError(AlreadyPaid, PaymentFailed) + , ProcessorResult ) data Acknowledgement = Ok @@ -149,7 +151,7 @@ withSuccessFailureMetrics attemptCount successCount op = do charge :: VoucherDatabase d => StripeConfig -> d -> Charges -> Handler Acknowledgement charge stripeConfig d (Charges token voucher amount currency) = do currency' <- getCurrency currency - result <- liftIO ((payForVoucher d voucher (completeStripeCharge currency')) :: IO (Either PaymentError Charge)) + result <- liftIO ((payForVoucher d voucher (completeStripeCharge currency')) :: IO ProcessorResult) case result of Left AlreadyPaid -> throwError voucherAlreadyPaid @@ -157,8 +159,7 @@ charge stripeConfig d (Charges token voucher amount currency) = do liftIO $ print "Stripe createCharge failed:" liftIO $ print msg throwError . errorForStripeType $ errorType - Right Charge { chargeMetaData = metadata } -> - checkVoucherMetadata metadata + Right chargeId -> return Ok where getCurrency :: Text -> Handler Currency getCurrency maybeCurrency = @@ -167,30 +168,20 @@ charge stripeConfig d (Charges token voucher amount currency) = do Nothing -> throwError unsupportedCurrency tokenId = TokenId token - completeStripeCharge :: Currency -> IO (Either PaymentError Charge) + completeStripeCharge :: Currency -> IO ProcessorResult completeStripeCharge currency' = do - result <- (stripe stripeConfig charge) :: IO (Either StripeError Charge) + result <- stripe stripeConfig charge case result of Left any -> return . Left $ PaymentFailed any - Right any -> - return . Right $ any + Right (Charge { chargeId }) -> + return . Right $ chargeId where charge = createCharge (Amount amount) currency' -&- tokenId -&- MetaData [("Voucher", voucher)] - checkVoucherMetadata :: MetaData -> Handler Acknowledgement - checkVoucherMetadata metadata = - -- verify that we are getting the same metadata that we sent. - case metadata of - MetaData [("Voucher", v)] -> - if v == voucher - then return Ok - else throwError voucherCodeMismatch - _ -> throwError voucherCodeNotFound - -- "Invalid request errors arise when your request has invalid parameters." errorForStripeType InvalidRequest = internalServerError @@ -211,7 +202,6 @@ charge stripeConfig d (Charges token voucher amount currency) = do serviceUnavailable = jsonErr 503 "Service temporarily unavailable" internalServerError = jsonErr 500 "Internal server error" - voucherCodeMismatch = jsonErr 500 "Voucher code mismatch" unsupportedCurrency = jsonErr 400 "Invalid currency specified" voucherCodeNotFound = jsonErr 400 "Voucher code not found" stripeChargeFailed = jsonErr 400 "Stripe charge didn't succeed" diff --git a/test/Persistence.hs b/test/Persistence.hs index 4afbb11204b170bc5a3146362e4769f2bddbad3c..ba91eecc7ca421f674ea7d9d94f65f99a61ccd51 100644 --- a/test/Persistence.hs +++ b/test/Persistence.hs @@ -38,15 +38,27 @@ import System.Directory import qualified Database.SQLite.Simple as Sqlite +import Web.Stripe.Types + ( ChargeId(ChargeId) + ) +import Web.Stripe.Error + ( StripeErrorType(CardError) + , StripeError(StripeError) + ) + import PaymentServer.Persistence ( Voucher , Fingerprint , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint, DatabaseUnavailable) - , PaymentError(AlreadyPaid) + , PaymentError(AlreadyPaid, PaymentFailed) , VoucherDatabase(payForVoucher, redeemVoucher, redeemVoucherWithCounter) , VoucherDatabaseState(SQLiteDB) + , ProcessorResult , memory , sqlite + , upgradeSchema + , latestVersion + , readVersion ) data ArbitraryException = ArbitraryException @@ -58,6 +70,7 @@ tests :: TestTree tests = testGroup "Persistence" [ memoryDatabaseVoucherPaymentTests , sqlite3DatabaseVoucherPaymentTests + , sqlite3DatabaseSchemaTests ] -- Some dummy values that should be replaced by the use of QuickCheck. @@ -66,14 +79,24 @@ anotherVoucher = "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz" fingerprint = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" anotherFingerprint = "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc" +aChargeId :: ChargeId +aChargeId = ChargeId "abc" + -- Mock a successful payment. -paySuccessfully :: IO () -paySuccessfully = return () +paySuccessfully :: IO ProcessorResult +paySuccessfully = return . Right $ aChargeId -- Mock a failed payment. -failPayment :: IO () +failPayment :: IO ProcessorResult failPayment = throwIO ArbitraryException +-- Mock a payment that fails at the processor rather than with an IO +-- exception. +aStripeError :: StripeError +aStripeError = StripeError CardError "Card rejected because reasons" Nothing Nothing Nothing +failPaymentProcessing :: IO ProcessorResult +failPaymentProcessing = return $ Left $ PaymentFailed aStripeError + -- | Create a group of tests related to voucher payment and redemption. makeVoucherPaymentTests :: VoucherDatabase d @@ -92,13 +115,13 @@ makeVoucherPaymentTests label makeDatabase = , testCase "paid for" $ do connect <- makeDatabase conn <- connect - () <- payForVoucher conn voucher paySuccessfully + Right _ <- payForVoucher conn voucher paySuccessfully result <- redeemVoucher conn voucher fingerprint assertEqual "redeeming paid voucher" (Right True) result , testCase "allowed double redemption" $ do connect <- makeDatabase conn <- connect - () <- payForVoucher conn voucher paySuccessfully + Right _ <- payForVoucher conn voucher paySuccessfully let redeem = redeemVoucher conn voucher fingerprint first <- redeem second <- redeem @@ -107,7 +130,7 @@ makeVoucherPaymentTests label makeDatabase = , testCase "disallowed double redemption" $ do connect <- makeDatabase conn <- connect - () <- payForVoucher conn voucher paySuccessfully + Right _ <- payForVoucher conn voucher paySuccessfully let redeem = redeemVoucher conn voucher first <- redeem fingerprint second <- redeem (Text.cons 'a' $ Text.tail fingerprint) @@ -116,7 +139,7 @@ makeVoucherPaymentTests label makeDatabase = , testCase "allowed redemption varying by counter" $ do connect <- makeDatabase conn <- connect - () <- payForVoucher conn voucher paySuccessfully + Right _ <- payForVoucher conn voucher paySuccessfully let redeem = redeemVoucherWithCounter conn voucher first <- redeem fingerprint 0 second <- redeem anotherFingerprint 1 @@ -125,12 +148,20 @@ makeVoucherPaymentTests label makeDatabase = , testCase "disallowed redemption varying by counter but not fingerprint" $ do connect <- makeDatabase conn <- connect - () <- payForVoucher conn voucher paySuccessfully + Right _ <- payForVoucher conn voucher paySuccessfully let redeem = redeemVoucherWithCounter conn voucher first <- redeem fingerprint 0 second <- redeem fingerprint 1 assertEqual "redeemed with counter 0" (Right True) first assertEqual "redeemed with counter 1" (Left DuplicateFingerprint) second + , testCase "pay with processor error" $ do + connect <- makeDatabase + conn <- connect + actual <- payForVoucher conn voucher failPaymentProcessing + let expected = Left $ PaymentFailed aStripeError + assertEqual "failing payment processing for a voucher" expected actual + result <- redeemVoucher conn voucher fingerprint + assertEqual "redeeming voucher with failed payment" (Left NotPaid) result , testCase "pay with exception" $ do connect <- makeDatabase conn <- connect @@ -142,7 +173,7 @@ makeVoucherPaymentTests label makeDatabase = connect <- makeDatabase conn <- connect let pay = payForVoucher conn voucher paySuccessfully - () <- pay + Right _ <- pay payResult <- try pay assertEqual "double-paying for a voucher" (Left AlreadyPaid) payResult redeemResult <- redeemVoucher conn voucher fingerprint @@ -159,15 +190,15 @@ makeVoucherPaymentTests label makeDatabase = withAsync anotherPayment $ \p2 -> do waitBoth p1 p2 - assertEqual "Both payments should succeed" ((), ()) result + assertEqual "Both payments should succeed" (Right aChargeId, Right aChargeId) result , testCase "concurrent redemption" $ do connect <- makeDatabase connA <- connect connB <- connect -- It doesn't matter which connection pays for the vouchers. They -- payments are concurrent and the connections are to the same database. - () <- payForVoucher connA voucher paySuccessfully - () <- payForVoucher connA anotherVoucher paySuccessfully + Right _ <- payForVoucher connA voucher paySuccessfully + Right _ <- payForVoucher connA anotherVoucher paySuccessfully -- It does matter which connection is used to redeem the voucher. A -- connection can only do one thing at a time. @@ -214,8 +245,9 @@ sqlite3DatabaseVoucherPaymentTests = normalConn <- connect fastBusyConn <- fastBusyConnection connect Sqlite.withExclusiveTransaction normalConn $ do + let expected = Left DatabaseUnavailable result <- redeemVoucher fastBusyConn voucher fingerprint - assertEqual "Redeeming voucher while database busy" result $ Left DatabaseUnavailable + assertEqual "Redeeming voucher while database busy" expected result ] where fastBusyConnection @@ -226,3 +258,44 @@ sqlite3DatabaseVoucherPaymentTests = -- Tweak the timeout down so the test completes quickly Sqlite.execute_ conn "PRAGMA busy_timeout = 0" return . SQLiteDB . return $ conn + + +sqlite3DatabaseSchemaTests :: TestTree +sqlite3DatabaseSchemaTests = + testGroup "SQLite3 schema" + [ testCase "initialize empty database" $ + -- upgradeSchema can start from nothing and upgrade the database to any + -- defined schema version. We upgrade to the latest version because that + -- implies upgrading all the intermediate versions. It probably wouldn't + -- hurt to target every intermediate version specifically, though. I + -- think that's what SmallCheck is for? + Sqlite.withConnection ":memory:" $ \conn -> do + upgradeSchema latestVersion conn + let expected = Right latestVersion + actual <- readVersion conn + assertEqual "The recorded schema version should be the latest value" expected actual + + , testCase "identify version 0" $ + -- readVersion identifies an empty database schema as version 0 + Sqlite.withConnection ":memory:" $ \conn -> do + let expected = Right 0 + actual <- readVersion conn + assertEqual "An empty database schema is version 0" expected actual + + , testCase "identify version 1" $ + -- readVersion identifies schema version 1 + Sqlite.withConnection ":memory:" $ \conn -> do + upgradeSchema 1 conn + let expected = Right 1 + actual <- readVersion conn + assertEqual "readVersion identifies database schema version 1" expected actual + + , testCase "identify version 2" $ + -- readVersion identifies schema version 1 + Sqlite.withConnection ":memory:" $ \conn -> do + upgradeSchema 2 conn + let expected = Right 2 + actual <- readVersion conn + assertEqual "readVersion identifies database schema version 2" expected actual + + ]