diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs index f4ae39fc91d8ce877fa0b68f9bfd057d528c99bf..1b5ee93eda8fc9c50595d6764645b185ce79597c 100644 --- a/src/PaymentServer/Persistence.hs +++ b/src/PaymentServer/Persistence.hs @@ -11,6 +11,9 @@ module PaymentServer.Persistence , VoucherDatabaseState(MemoryDB, SQLiteDB) , memory , sqlite + , upgradeSchema + , latestVersion + , readVersion ) where import Control.Exception @@ -373,9 +376,7 @@ sqlite path = let exec = Sqlite.execute_ dbConn exec "PRAGMA busy_timeout = 60000" exec "PRAGMA foreign_keys = ON" - Sqlite.withExclusiveTransaction dbConn $ do - exec "CREATE TABLE IF NOT EXISTS vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" - exec "CREATE TABLE IF NOT EXISTS redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, counter INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" + Sqlite.withExclusiveTransaction dbConn (upgradeSchema latestVersion dbConn) return dbConn connect :: IO Sqlite.Connection @@ -383,3 +384,100 @@ sqlite path = bracketOnError (Sqlite.open . unpack $ path) Sqlite.close initialize in return . SQLiteDB $ connect + + +-- | updateVersions gives the SQL statements necessary to initialize the +-- database schema at each version that has ever existed. The first element +-- is a list of SQL statements that modify an empty schema to create the first +-- version. The second element is a list of SQL statements that modify the +-- first version to create the second version. etc. +updateVersions :: [[Sqlite.Query]] +updateVersions = + [ [ "CREATE TABLE vouchers (id INTEGER PRIMARY KEY, name TEXT UNIQUE)" + , "CREATE TABLE redeemed (id INTEGER PRIMARY KEY, voucher_id INTEGER, counter INTEGER, fingerprint TEXT, FOREIGN KEY (voucher_id) REFERENCES vouchers(id))" + ] + , [ "CREATE TABLE version AS SELECT 2 AS version" + , "ALTER TABLE vouchers ADD COLUMN charge_id" + ] + ] + +latestVersion :: Int +latestVersion = length updateVersions + +-- | readVersion reads the schema version out of a database using the given +-- query function. Since not all versions of the schema had an explicit +-- version marker, it digs around a little bit to find the answer. +readVersion :: Sqlite.Connection -> IO (Either UpgradeError Int) +readVersion conn = do + versionExists <- doesTableExist "version" + if versionExists + -- If there is a version table then it knows the answer. + then + do + versions <- Sqlite.query_ conn "SELECT version FROM version" :: IO [Sqlite.Only Int] + case versions of + [] -> return $ Left VersionMissing + (Sqlite.Only v):[] -> return $ Right v + vs -> return $ Left $ ExcessVersions (map Sqlite.fromOnly vs) + else + do + vouchersExists <- doesTableExist "vouchers" + if vouchersExists + -- If there is a vouchers table then we have version 1 + then return $ Right 1 + -- Otherwise we have version 0 + else return $ Right 0 + + where + doesTableExist :: Text -> IO Bool + doesTableExist name = do + (Sqlite.Only count):[] <- + Sqlite.query + conn + "SELECT COUNT(*) FROM [sqlite_master] WHERE [type] = 'table' AND [name] = ?" + (Sqlite.Only name) :: IO [Sqlite.Only Int] + return $ count > 0 + + + +-- | upgradeSchema determines what schema changes need to be applied to the +-- database associated with a connection to make the schema match the +-- requested version. +upgradeSchema :: Int -> Sqlite.Connection -> IO (Either UpgradeError ()) +upgradeSchema targetVersion conn = do + errOrCurrentVersion <- readVersion conn + case errOrCurrentVersion of + Left err -> return $ Left err + Right currentVersion -> + case compareVersion targetVersion currentVersion of + Lesser -> return $ Left DatabaseSchemaTooNew + Equal -> return $ Right () + Greater -> runUpgrades currentVersion + + where + runUpgrades :: Int -> IO (Either UpgradeError ()) + runUpgrades currentVersion = + let + upgrades :: [[Sqlite.Query]] + upgrades = drop currentVersion updateVersions + + oneStep :: [Sqlite.Query] -> IO [()] + oneStep = mapM $ Sqlite.execute_ conn + in do + mapM oneStep upgrades + return $ Right () + + +data UpgradeError + = VersionMissing + | ExcessVersions [Int] + | DatabaseSchemaTooNew + deriving (Show, Eq) + +data ComparisonResult = Lesser | Equal | Greater + +compareVersion :: Int -> Int -> ComparisonResult +compareVersion a b + | a < b = Lesser + | a == b = Equal + | otherwise = Greater diff --git a/test/Persistence.hs b/test/Persistence.hs index 4afbb11204b170bc5a3146362e4769f2bddbad3c..37f4c8e7bffa205e8fcd7dcb70fc035ffac781d6 100644 --- a/test/Persistence.hs +++ b/test/Persistence.hs @@ -47,6 +47,9 @@ import PaymentServer.Persistence , VoucherDatabaseState(SQLiteDB) , memory , sqlite + , upgradeSchema + , latestVersion + , readVersion ) data ArbitraryException = ArbitraryException @@ -58,6 +61,7 @@ tests :: TestTree tests = testGroup "Persistence" [ memoryDatabaseVoucherPaymentTests , sqlite3DatabaseVoucherPaymentTests + , sqlite3DatabaseSchemaTests ] -- Some dummy values that should be replaced by the use of QuickCheck. @@ -214,8 +218,9 @@ sqlite3DatabaseVoucherPaymentTests = normalConn <- connect fastBusyConn <- fastBusyConnection connect Sqlite.withExclusiveTransaction normalConn $ do + let expected = Left DatabaseUnavailable result <- redeemVoucher fastBusyConn voucher fingerprint - assertEqual "Redeeming voucher while database busy" result $ Left DatabaseUnavailable + assertEqual "Redeeming voucher while database busy" expected result ] where fastBusyConnection @@ -226,3 +231,20 @@ sqlite3DatabaseVoucherPaymentTests = -- Tweak the timeout down so the test completes quickly Sqlite.execute_ conn "PRAGMA busy_timeout = 0" return . SQLiteDB . return $ conn + + +sqlite3DatabaseSchemaTests :: TestTree +sqlite3DatabaseSchemaTests = + testGroup "SQLite3 schema" + [ testCase "initialize empty database" $ + -- upgradeSchema can start from nothing and upgrade the database to any + -- defined schema version. We upgrade to the latest version because that + -- implies upgrading all the intermediate versions. It probably wouldn't + -- hurt to target every intermediate version specifically, though. I + -- think that's what SmallCheck is for? + Sqlite.withConnection ":memory:" $ \conn -> do + upgradeSchema latestVersion conn + let expected = Right latestVersion + actual <- readVersion conn + assertEqual "The recorded schema version should be the latest value" expected actual + ]