From 5e98b666e63ae3d8f2f050ace4d2b0ece70143cc Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Wed, 18 Mar 2020 09:51:13 -0400
Subject: [PATCH] SQLite3 Connection per redemption attempt

Connections could probably be in a per-thread cache or something but
redemption costs are going to dwarf database opening costs.

Also, filesystem caches and no benchmarks and all that.
---
 src/PaymentServer/Main.hs              |  4 +--
 src/PaymentServer/Persistence.hs       | 48 ++++++++++++++++----------
 src/PaymentServer/Processors/Stripe.hs |  3 +-
 src/PaymentServer/Ristretto.hs         |  2 --
 test/Persistence.hs                    |  4 +--
 5 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/src/PaymentServer/Main.hs b/src/PaymentServer/Main.hs
index 6655c8a..f32512a 100644
--- a/src/PaymentServer/Main.hs
+++ b/src/PaymentServer/Main.hs
@@ -64,7 +64,7 @@ import qualified Web.Stripe.Client as Stripe
 
 import PaymentServer.Persistence
   ( memory
-  , getDBConnection
+  , sqlite
   )
 import PaymentServer.Issuer
   ( trivialIssue
@@ -281,7 +281,7 @@ getApp config =
     getDatabase ServerConfig{ database, databasePath } =
       case (database, databasePath) of
         (Memory, Nothing) -> Right memory
-        (SQLite3, Just path) -> Right (getDBConnection path)
+        (SQLite3, Just path) -> Right (sqlite path)
         _ -> Left "invalid options"
 
     stripeConfig ServerConfig
diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs
index 324a888..fbb9720 100644
--- a/src/PaymentServer/Persistence.hs
+++ b/src/PaymentServer/Persistence.hs
@@ -1,5 +1,6 @@
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE TypeSynonymInstances #-}
+
 module PaymentServer.Persistence
   ( Voucher
   , Fingerprint
@@ -8,14 +9,14 @@ module PaymentServer.Persistence
   , VoucherDatabase(payForVoucher, redeemVoucher)
   , VoucherDatabaseState(MemoryDB, SQLiteDB)
   , memory
-  , getDBConnection
+  , sqlite
   ) where
 
 import Control.Exception
   ( Exception
   , throwIO
   , catch
-  , try
+  , bracket
   )
 
 import Data.Text
@@ -28,7 +29,6 @@ import Data.IORef
   ( IORef
   , newIORef
   , modifyIORef
-  , atomicModifyIORef'
   , readIORef
   )
 import qualified Database.SQLite.Simple as Sqlite
@@ -111,7 +111,7 @@ data VoucherDatabaseState =
     -- redemption.
   , redeemed :: IORef (Map.Map Voucher Fingerprint)
   }
-  | SQLiteDB { conn :: Sqlite.Connection }
+  | SQLiteDB { connect :: IO Sqlite.Connection }
 
 instance VoucherDatabase VoucherDatabaseState where
   payForVoucher MemoryDB{ paid = paidRef, redeemed = redeemed } voucher pay = do
@@ -127,7 +127,8 @@ instance VoucherDatabase VoucherDatabaseState where
         modifyIORef paidRef (Set.insert voucher)
         return result
 
-  payForVoucher SQLiteDB{ conn = conn } voucher pay =
+  payForVoucher SQLiteDB{ connect = connect } voucher pay =
+    bracket connect Sqlite.close $ \conn ->
     insertVoucher conn voucher pay
 
   redeemVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher fingerprint = do
@@ -136,11 +137,14 @@ instance VoucherDatabase VoucherDatabaseState where
     let insertFn = (modifyIORef redeemed .) . Map.insert
     redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn
 
-  redeemVoucher SQLiteDB { conn = conn } voucher fingerprint = Sqlite.withExclusiveTransaction conn $ do
-    unpaid <- isVoucherUnpaid conn voucher
-    existingFingerprint <- getVoucherFingerprint conn voucher
-    let insertFn = insertVoucherAndFingerprint conn
-    redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn
+  redeemVoucher SQLiteDB { connect = connect } voucher fingerprint =
+    bracket connect Sqlite.close $ \conn ->
+    Sqlite.withExclusiveTransaction conn $
+    do
+      unpaid <- isVoucherUnpaid conn voucher
+      existingFingerprint <- getVoucherFingerprint conn voucher
+      let insertFn = insertVoucherAndFingerprint conn
+      redeemVoucherHelper unpaid existingFingerprint voucher fingerprint insertFn
 
 -- | Allow a voucher to be redeemed if it has been paid for and not redeemed
 -- before or redeemed with the same fingerprint.
@@ -225,12 +229,18 @@ insertVoucherAndFingerprint :: Sqlite.Connection -> Voucher -> Fingerprint -> IO
 insertVoucherAndFingerprint dbConn voucher fingerprint =
   Sqlite.execute dbConn "INSERT INTO redeemed (voucher_id, fingerprint) VALUES ((SELECT id FROM vouchers WHERE name = ?), ?)" (voucher, fingerprint)
 
--- | Create and open a database with a given `name` and create the `voucher`
--- table and `redeemed` table with the provided schema.
-getDBConnection :: Text -> IO VoucherDatabaseState
-getDBConnection path = do
-  dbConn <- Sqlite.open (unpack path)
-  Sqlite.execute_ dbConn "PRAGMA foreign_keys = ON"
-  Sqlite.execute_ dbConn "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)"
-  Sqlite.execute_ dbConn "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))"
-  return $ SQLiteDB dbConn
+-- | Open and create (if necessary) a SQLite3 database which can persistently
+-- store all of the relevant information about voucher state.
+sqlite :: Text -> IO VoucherDatabaseState
+sqlite path =
+  let
+    connect :: IO Sqlite.Connection
+    connect = do
+      dbConn <- Sqlite.open (unpack path)
+      let exec = Sqlite.execute_ dbConn
+      exec "PRAGMA foreign_keys = ON"
+      exec "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)"
+      exec "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))"
+      return dbConn
+  in
+    return . SQLiteDB $ connect
diff --git a/src/PaymentServer/Processors/Stripe.hs b/src/PaymentServer/Processors/Stripe.hs
index 96d01b2..8506c6b 100644
--- a/src/PaymentServer/Processors/Stripe.hs
+++ b/src/PaymentServer/Processors/Stripe.hs
@@ -73,8 +73,7 @@ import Web.Stripe.Charge
   , TokenId(TokenId)
   )
 import Web.Stripe.Client
-  ( StripeConfig(StripeConfig)
-  , StripeKey(StripeKey)
+  ( StripeConfig
   )
 import Web.Stripe
   ( stripe
diff --git a/src/PaymentServer/Ristretto.hs b/src/PaymentServer/Ristretto.hs
index a296833..7c06520 100644
--- a/src/PaymentServer/Ristretto.hs
+++ b/src/PaymentServer/Ristretto.hs
@@ -9,7 +9,6 @@ module PaymentServer.Ristretto
 
 import Control.Exception
   ( bracket
-  , assert
   )
 import System.IO.Unsafe
   ( unsafePerformIO
@@ -26,7 +25,6 @@ import Foreign.Ptr
   )
 import Foreign.C.String
   ( CString
-  , withCString
   , newCString
   , peekCString
   )
diff --git a/test/Persistence.hs b/test/Persistence.hs
index 4d435e5..caf789f 100644
--- a/test/Persistence.hs
+++ b/test/Persistence.hs
@@ -43,7 +43,7 @@ import PaymentServer.Persistence
   , PaymentError(AlreadyPaid)
   , VoucherDatabase(payForVoucher, redeemVoucher)
   , memory
-  , getDBConnection
+  , sqlite
   )
 
 data ArbitraryException = ArbitraryException
@@ -159,4 +159,4 @@ sqlite3DatabaseVoucherPaymentTests =
   do
     tempdir <- getTemporaryDirectory
     (path, handle) <- openTempFile tempdir "voucher-.db"
-    return . getDBConnection . Text.pack $ path
+    return . sqlite . Text.pack $ path
-- 
GitLab