From cc450228463e7f8648aa703ddae4845009aa173f Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Fri, 11 Sep 2020 14:58:30 -0400
Subject: [PATCH] Render the database timeout error more nicely

---
 PaymentServer.cabal              |  1 +
 src/PaymentServer/Persistence.hs | 29 +++++++++++++++---------
 src/PaymentServer/Redemption.hs  | 17 ++++++++------
 test/Persistence.hs              | 38 +++++++++++++++++++++++++++-----
 4 files changed, 61 insertions(+), 24 deletions(-)

diff --git a/PaymentServer.cabal b/PaymentServer.cabal
index 2db7528..ad233b8 100644
--- a/PaymentServer.cabal
+++ b/PaymentServer.cabal
@@ -75,6 +75,7 @@ test-suite PaymentServer-tests
                      , tasty-hunit
                      , directory
                      , async
+                     , sqlite-simple
                      , PaymentServer
   default-language: Haskell2010
 
diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs
index 8a9fe6c..0c5ff96 100644
--- a/src/PaymentServer/Persistence.hs
+++ b/src/PaymentServer/Persistence.hs
@@ -5,7 +5,7 @@
 module PaymentServer.Persistence
   ( Voucher
   , Fingerprint
-  , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint)
+  , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint, DatabaseUnavailable)
   , PaymentError(AlreadyPaid, PaymentFailed)
   , VoucherDatabase(payForVoucher, redeemVoucher, redeemVoucherWithCounter)
   , VoucherDatabaseState(MemoryDB, SQLiteDB)
@@ -72,6 +72,8 @@ data RedeemError =
   -- fingerprint.  We check for this case to prevent a misbehaving client from
   -- accidentally creating worthless tokens.
   | DuplicateFingerprint
+  -- | The database is too busy right now.  Try again later.
+  | DatabaseUnavailable
   deriving (Show, Eq)
 
 -- | A fingerprint cryptographically identifies a redemption of a voucher.
@@ -191,16 +193,21 @@ instance VoucherDatabase VoucherDatabaseState where
       fingerprint
 
   redeemVoucherWithCounter SQLiteDB { connect = connect } voucher fingerprint counter =
-    bracket connect Sqlite.close $ \conn ->
-    Sqlite.withExclusiveTransaction conn $
-    redeemVoucherHelper
-    (isVoucherPaid conn)
-    (getVoucherFingerprint conn)
-    (getVoucherCounterForFingerprint conn)
-    (insertVoucherAndFingerprint conn)
-    voucher
-    counter
-    fingerprint
+    bracket connect Sqlite.close redeemIt `catch` transformBusy
+    where
+      redeemIt conn =
+        Sqlite.withExclusiveTransaction conn $
+        redeemVoucherHelper
+        (isVoucherPaid conn)
+        (getVoucherFingerprint conn)
+        (getVoucherCounterForFingerprint conn)
+        (insertVoucherAndFingerprint conn)
+        voucher
+        counter
+        fingerprint
+
+      transformBusy (Sqlite.SQLError Sqlite.ErrorBusy _ _) =
+        return . Left $ DatabaseUnavailable
 
 
 -- | Look up the voucher, counter tuple which previously performed a
diff --git a/src/PaymentServer/Redemption.hs b/src/PaymentServer/Redemption.hs
index 6c0cf7b..d1735fc 100644
--- a/src/PaymentServer/Redemption.hs
+++ b/src/PaymentServer/Redemption.hs
@@ -54,6 +54,7 @@ import Servant
   , Handler
   , ServerError(errBody, errHeaders)
   , err400
+  , err500
   , throwError
   )
 import Servant.API
@@ -68,7 +69,7 @@ import Crypto.Hash
   )
 import PaymentServer.Persistence
   ( VoucherDatabase(redeemVoucherWithCounter)
-  , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint)
+  , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint, DatabaseUnavailable)
   , Fingerprint
   , Voucher
   )
@@ -159,7 +160,7 @@ maxCounter = 16
 
 type RedemptionAPI = ReqBody '[JSON] Redeem :> Post '[JSON] Result
 
-jsonErr400 reason = err400
+jsonErr err reason = err
   { errBody = encode reason
   , errHeaders = [ ("Content-Type", "application/json;charset=utf-8") ]
   }
@@ -192,23 +193,25 @@ retry op =
 redeem :: VoucherDatabase d => Issuer -> d -> Redeem -> Handler Result
 redeem issue database (Redeem voucher tokens counter) =
   if counter < 0 || counter >= maxCounter then
-    throwError $ jsonErr400 (CounterOutOfBounds 0 maxCounter counter)
+    throwError $ jsonErr err400 (CounterOutOfBounds 0 maxCounter counter)
   else do
 
     let fingerprint = fingerprintFromTokens tokens
     result <- liftIO . retry $ redeemVoucherWithCounter database voucher fingerprint counter
     case result of
       Left NotPaid -> do
-        throwError $ jsonErr400 Unpaid
+        throwError $ jsonErr err400 Unpaid
       Left AlreadyRedeemed -> do
-        throwError $ jsonErr400 DoubleSpend
+        throwError $ jsonErr err400 DoubleSpend
       Left DuplicateFingerprint -> do
-        throwError $ jsonErr400 $ OtherFailure "fingerprint already used"
+        throwError $ jsonErr err400 $ OtherFailure "fingerprint already used"
+      Left DatabaseUnavailable -> do
+        throwError $ jsonErr err500 $ OtherFailure "database temporarily unavailable"
       Right () -> do
         let result = issue tokens
         case result of
           Left reason -> do
-            throwError $ jsonErr400 $ OtherFailure reason
+            throwError $ jsonErr err400 $ OtherFailure reason
           Right (ChallengeBypass key signatures proof) ->
             return $ Succeeded key signatures proof
 
diff --git a/test/Persistence.hs b/test/Persistence.hs
index 7c0f975..bd28094 100644
--- a/test/Persistence.hs
+++ b/test/Persistence.hs
@@ -36,12 +36,15 @@ import System.Directory
   ( getTemporaryDirectory
   )
 
+import qualified Database.SQLite.Simple as Sqlite
+
 import PaymentServer.Persistence
   ( Voucher
   , Fingerprint
-  , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint)
+  , RedeemError(NotPaid, AlreadyRedeemed, DuplicateFingerprint, DatabaseUnavailable)
   , PaymentError(AlreadyPaid)
   , VoucherDatabase(payForVoucher, redeemVoucher, redeemVoucherWithCounter)
+  , VoucherDatabaseState(SQLiteDB)
   , memory
   , sqlite
   )
@@ -187,8 +190,31 @@ memoryDatabaseVoucherPaymentTests = makeVoucherPaymentTests "memory" $ do
 -- | Instantiate the persistence tests for the sqlite3 backend.
 sqlite3DatabaseVoucherPaymentTests :: TestTree
 sqlite3DatabaseVoucherPaymentTests =
-  makeVoucherPaymentTests "sqlite3" $
-  do
-    tempdir <- getTemporaryDirectory
-    (path, handle) <- openTempFile tempdir "voucher-.db"
-    return . sqlite . Text.pack $ path
+  testGroup ""
+  [ genericTests
+  , sqlite3Tests
+  ]
+  where
+    makeDatabase = do
+      tempdir <- getTemporaryDirectory
+      (path, handle) <- openTempFile tempdir "voucher-.db"
+      return . sqlite . Text.pack $ path
+
+    genericTests = makeVoucherPaymentTests "sqlite3" makeDatabase
+
+    sqlite3Tests =
+      testGroup "SQLite3-specific voucher"
+      [ testCase "database is busy" $ do
+          getDB <- makeDatabase
+          db <- getDB
+          case db of
+            (SQLiteDB connect) -> do
+              conn <- connect
+              -- Tweak the timeout down so the test completes quickly
+              Sqlite.execute_ conn "PRAGMA busy_timeout = 0"
+              -- Acquire a write lock before letting the application code run so that
+              -- the application code is denied the write lock.
+              Sqlite.withExclusiveTransaction conn $ do
+                result <- redeemVoucher db voucher fingerprint
+                assertEqual "Redeeming voucher while database busy" result $ Left DatabaseUnavailable
+      ]
-- 
GitLab