diff --git a/.circleci/config.yml b/.circleci/config.yml index dd2666b3265e1386ac519609de76357d4747d81c..c628d7449f0d1eff3227518db9f28757f8fbc16d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -146,6 +146,7 @@ jobs: --no-terminal \ --haddock \ --haddock-internal \ + --test \ --no-haddock-deps" nix-shell shell.nix --run "$BUILD" diff --git a/PaymentServer.cabal b/PaymentServer.cabal index b6cd94e9b7bb2c86a72be8778c7f1c385d2ae295..0c5286c1282dc7f33ba50e06852314b5363a130f 100644 --- a/PaymentServer.cabal +++ b/PaymentServer.cabal @@ -26,10 +26,13 @@ library , optparse-applicative , aeson , bytestring + , utf8-string , servant , servant-server + , http-types , wai , wai-extra + , wai-cors , data-default , warp , warp-tls @@ -60,6 +63,19 @@ executable PaymentServer-generate-key , PaymentServer default-language: Haskell2010 +test-suite PaymentServer-tests + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Spec.hs + other-modules: Persistence + build-depends: base + , text + , tasty + , tasty-hunit + , directory + , PaymentServer + default-language: Haskell2010 + source-repository head type: git location: https://github.com/privatestorageio/PaymentServer diff --git a/nix/PaymentServer.nix b/nix/PaymentServer.nix index b70939de79fd499051ef4b007e3704067d81e9c1..0816f96d9959f118ef572b0673ad83abbfd210e1 100644 --- a/nix/PaymentServer.nix +++ b/nix/PaymentServer.nix @@ -58,12 +58,15 @@ in { system, compiler, flags, pkgs, hsPkgs, pkgconfPkgs, ... }: depends = [ (hsPkgs."base" or (buildDepError "base")) (hsPkgs."bytestring" or (buildDepError "bytestring")) + (hsPkgs."utf8-string" or (buildDepError "utf8-string")) (hsPkgs."optparse-applicative" or (buildDepError "optparse-applicative")) (hsPkgs."aeson" or (buildDepError "aeson")) (hsPkgs."servant" or (buildDepError "servant")) (hsPkgs."servant-server" or (buildDepError "servant-server")) + (hsPkgs."http-types" or (buildDepError "http-types")) (hsPkgs."wai" or (buildDepError "wai")) (hsPkgs."wai-extra" or (buildDepError "wai-extra")) + (hsPkgs."wai-cors" or (buildDepError "wai-cors")) (hsPkgs."data-default" or (buildDepError "data-default")) (hsPkgs."warp" or (buildDepError "warp")) (hsPkgs."warp-tls" or (buildDepError "warp-tls")) diff --git a/src/PaymentServer/Main.hs b/src/PaymentServer/Main.hs index 941d4c245cc13a5b6f9a76aaaccf75cf2a405122..07c352136d1a8241e89a79fd235ee231dbcbc96d 100644 --- a/src/PaymentServer/Main.hs +++ b/src/PaymentServer/Main.hs @@ -6,6 +6,9 @@ module PaymentServer.Main ( main ) where +import Control.Exception.Base + ( SomeException + ) import Text.Printf ( printf ) @@ -18,11 +21,16 @@ import Data.Text import Data.Default ( def ) +import Network.HTTP.Types.Status + ( status500 + ) import Network.Wai.Handler.Warp ( Port , defaultSettings , setPort - , run + , setOnException + , setOnExceptionResponse + , runSettings ) import Network.Wai.Handler.WarpTLS ( runTLS @@ -30,6 +38,12 @@ import Network.Wai.Handler.WarpTLS ) import Network.Wai ( Application + , Request + , Response + , responseLBS + ) +import Network.Wai.Middleware.Cors + ( Origin ) import Network.Wai.Middleware.RequestLogger ( OutputFormat(Detailed) @@ -55,6 +69,7 @@ import Options.Applicative , option , auto , str + , many , optional , long , help @@ -75,6 +90,7 @@ import System.Exit import Data.Semigroup ((<>)) import qualified Data.Text.IO as TIO import qualified Data.ByteString as B +import qualified Data.ByteString.Lazy.UTF8 as LBS data Issuer = Trivial @@ -93,6 +109,7 @@ data ServerConfig = ServerConfig , databasePath :: Maybe Text , endpoint :: Endpoint , stripeKeyPath :: FilePath + , corsOrigins :: [Origin] } deriving (Show, Eq) @@ -165,6 +182,9 @@ sample = ServerConfig <*> option str ( long "stripe-key-path" <> help "Path to Stripe Secret key" ) + <*> many ( option str + ( long "cors-origin" + <> help "An allowed `Origin` for the purposes of CORS (zero or more)." ) ) opts :: ParserInfo ServerConfig opts = info (sample <**> helper) @@ -181,17 +201,33 @@ main = do logEndpoint (endpoint config) run app +getPortNumber (TCPEndpoint portNumber) = portNumber +getPortNumber (TLSEndpoint portNumber _ _ _) = portNumber + getRunner :: Endpoint -> (Application -> IO ()) getRunner endpoint = - case endpoint of - (TCPEndpoint portNumber) -> - run portNumber - (TLSEndpoint portNumber certificatePath chainPath keyPath) -> - let - tlsSettings = tlsSettingsChain certificatePath (maybeToList chainPath) keyPath - settings = setPort portNumber defaultSettings - in - runTLS tlsSettings settings + let + onException :: Maybe Request -> SomeException -> IO () + onException _ exc = do + print "onException" + print exc + return () + onExceptionResponse :: SomeException -> Response + onExceptionResponse = (responseLBS status500 []) . LBS.fromString . ("exception: " ++) . show + settings = + setPort (getPortNumber endpoint) . + setOnException onException . + setOnExceptionResponse onExceptionResponse $ + defaultSettings + in + case endpoint of + (TCPEndpoint _) -> + runSettings settings + (TLSEndpoint _ certificatePath chainPath keyPath) -> + let + tlsSettings = tlsSettingsChain certificatePath (maybeToList chainPath) keyPath + in + runTLS tlsSettings settings logEndpoint :: Endpoint -> IO () logEndpoint endpoint = @@ -230,6 +266,8 @@ getApp config = Right getDB -> do db <- getDB key <- B.readFile (stripeKeyPath config) - let app = paymentServerApp key issuer db + let + origins = corsOrigins config + app = paymentServerApp origins key issuer db logger <- mkRequestLogger (def { outputFormat = Detailed True}) return $ logger app diff --git a/src/PaymentServer/Persistence.hs b/src/PaymentServer/Persistence.hs index edefd3ea9d8ba24470d689bb27d59abee665159e..324a8880e5ca8f6f2af12ee5e8435742f2c23acc 100644 --- a/src/PaymentServer/Persistence.hs +++ b/src/PaymentServer/Persistence.hs @@ -4,12 +4,20 @@ module PaymentServer.Persistence ( Voucher , Fingerprint , RedeemError(NotPaid, AlreadyRedeemed) + , PaymentError(AlreadyPaid, PaymentFailed) , VoucherDatabase(payForVoucher, redeemVoucher) , VoucherDatabaseState(MemoryDB, SQLiteDB) , memory , getDBConnection ) where +import Control.Exception + ( Exception + , throwIO + , catch + , try + ) + import Data.Text ( Text , unpack @@ -20,6 +28,7 @@ import Data.IORef ( IORef , newIORef , modifyIORef + , atomicModifyIORef' , readIORef ) import qualified Database.SQLite.Simple as Sqlite @@ -36,6 +45,16 @@ import Data.Maybe -- voucher itself. type Voucher = Text +-- | Reasons that a voucher cannot be paid for. +data PaymentError = + -- | The voucher has already been paid for. + AlreadyPaid + -- | The payment transaction has failed. + | PaymentFailed + deriving (Show, Eq) + +instance Exception PaymentError + -- | Reasons that a voucher cannot be redeemed. data RedeemError = -- | The voucher has not been paid for. @@ -61,7 +80,13 @@ class VoucherDatabase d where payForVoucher :: d -- ^ The database in which to record the change -> Voucher -- ^ A voucher which should be considered paid - -> IO () + -> IO a -- ^ An operation which completes the payment. This is + -- evaluated in the context of a database transaction so + -- that if it fails the voucher is not marked as paid in + -- the database but if it succeeds the database state is + -- not confused by a competing transaction run around the + -- same time. + -> IO a -- | Attempt to redeem a voucher. If it has not been redeemed before or it -- has been redeemed with the same fingerprint, the redemption succeeds. @@ -89,11 +114,21 @@ data VoucherDatabaseState = | SQLiteDB { conn :: Sqlite.Connection } instance VoucherDatabase VoucherDatabaseState where - payForVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher = - modifyIORef paid (Set.insert voucher) + payForVoucher MemoryDB{ paid = paidRef, redeemed = redeemed } voucher pay = do + -- Surely far from ideal... + paid <- readIORef paidRef + if Set.member voucher paid + -- Avoid processing the payment if the voucher is already paid. + then throwIO AlreadyPaid + else + do + result <- pay + -- Only modify the paid set if the payment succeeds. + modifyIORef paidRef (Set.insert voucher) + return result - payForVoucher SQLiteDB{ conn = conn } voucher = - insertVoucher conn voucher + payForVoucher SQLiteDB{ conn = conn } voucher pay = + insertVoucher conn voucher pay redeemVoucher MemoryDB{ paid = paid, redeemed = redeemed } voucher fingerprint = do unpaid <- Set.notMember voucher <$> readIORef paid @@ -157,9 +192,32 @@ getVoucherFingerprint dbConn voucher = listToMaybe <$> Sqlite.query dbConn "SELECT redeemed.fingerprint FROM vouchers INNER JOIN redeemed ON vouchers.id = redeemed.voucher_id AND vouchers.name = ?" (Sqlite.Only voucher) -- | Mark the given voucher as paid in the database. -insertVoucher :: Sqlite.Connection -> Voucher -> IO () -insertVoucher dbConn voucher = - Sqlite.execute dbConn "INSERT INTO vouchers (name) VALUES (?)" (Sqlite.Only voucher) +insertVoucher :: Sqlite.Connection -> Voucher -> IO a -> IO a +insertVoucher dbConn voucher pay = + -- Begin an immediate transaction so that it includes the IO. The first + -- thing we do is execute our one and only statement so the transaction is + -- immediate anyway but it doesn't hurt to be explicit. + Sqlite.withImmediateTransaction dbConn $ + do + -- Vouchers must be unique in this table. This might fail if someone is + -- trying to double-pay for a voucher. In this case, we won't ever + -- finalize the payment. + Sqlite.execute dbConn "INSERT INTO vouchers (name) VALUES (?)" (Sqlite.Only voucher) + `catch` handleConstraintError + -- If we managed to insert the voucher, try to finalize the payment. If + -- this succeeds, the transaction is committed and we expect the payment + -- system to actually be moving money around. If it fails, we expect the + -- payment system *not* to move money around and the voucher should not be + -- marked as paid. The transaction will be rolled back so, indeed, it + -- won't be marked thus. + pay + + where + handleConstraintError Sqlite.SQLError { Sqlite.sqlError = Sqlite.ErrorConstraint } = + throwIO AlreadyPaid + handleConstraintError e = + throwIO e + -- | Mark the given voucher as having been redeemed (with the given -- fingerprint) in the database. diff --git a/src/PaymentServer/Processors/Stripe.hs b/src/PaymentServer/Processors/Stripe.hs index 57137eda62fd1e9437607a15f47546f93b1cb603..25a63fe31ded132044a8e051177a9db2a979b278 100644 --- a/src/PaymentServer/Processors/Stripe.hs +++ b/src/PaymentServer/Processors/Stripe.hs @@ -15,6 +15,10 @@ import Control.Monad.IO.Class import Control.Monad ( mzero ) +import Control.Exception + ( try + , throwIO + ) import Data.ByteString ( ByteString ) @@ -22,6 +26,7 @@ import Data.Text ( Text , unpack ) +import qualified Data.Map as Map import Text.Read ( readMaybe ) @@ -30,14 +35,16 @@ import Data.Aeson , FromJSON(parseJSON) , Value(Object) , object + , encode , (.:) + , (.=) ) import Servant ( Server , Handler , err400 , err500 - , ServerError(errBody) + , ServerError(ServerError, errHTTPCode, errBody, errHeaders, errReasonPhrase) , throwError ) import Servant.API @@ -77,6 +84,7 @@ import Web.Stripe import PaymentServer.Persistence ( Voucher , VoucherDatabase(payForVoucher) + , PaymentError(AlreadyPaid, PaymentFailed) ) type StripeSecretKey = ByteString @@ -84,12 +92,11 @@ type StripeSecretKey = ByteString data Acknowledgement = Ok instance ToJSON Acknowledgement where - toJSON Ok = object [] - -type StripeAPI = WebhookAPI - :<|> ChargesAPI + toJSON Ok = object + [ "success" .= True + ] -type WebhookAPI = "webhook" :> ReqBody '[JSON] Event :> Post '[JSON] Acknowledgement +type StripeAPI = ChargesAPI -- | getVoucher finds the metadata item with the key `"Voucher"` and returns -- the corresponding value, or Nothing. @@ -99,31 +106,7 @@ getVoucher (MetaData (("Voucher", value):xs)) = Just value getVoucher (MetaData (x:xs)) = getVoucher (MetaData xs) stripeServer :: VoucherDatabase d => StripeSecretKey -> d -> Server StripeAPI -stripeServer key d = webhook d - :<|> charge d key - --- | Process charge succeeded events -webhook :: VoucherDatabase d => d -> Event -> Handler Acknowledgement -webhook d Event{eventId=Just (EventId eventId), eventType=ChargeSucceededEvent, eventData=(ChargeEvent charge)} = - case getVoucher $ chargeMetaData charge of - Nothing -> - -- TODO: Record the eventId somewhere. In all cases where we don't - -- associate the value of the charge with something in our system, we - -- probably need enough information to issue a refund. We're early - -- enough in the system here that refunds are possible and not even - -- particularly difficult. - return Ok - Just v -> do - -- TODO: What if it is a duplicate payment? payForVoucher should be - -- able to indicate error I guess. - () <- liftIO $ payForVoucher d v - return Ok - --- Disregard anything else - but return success so that Stripe doesn't retry. -webhook d _ = - -- TODO: Record the eventId somewhere. - return Ok - +stripeServer key d = charge d key -- | Browser facing API that takes token, voucher and a few other information -- and calls stripe charges API. If payment succeeds, then the voucher is stored @@ -151,30 +134,63 @@ instance FromJSON Charges where -- and if the Charge is okay, then set the voucher as "paid" in the database. charge :: VoucherDatabase d => d -> StripeSecretKey -> Charges -> Handler Acknowledgement charge d key (Charges token voucher amount currency) = do - let config = StripeConfig (StripeKey key) Nothing - tokenId = TokenId token currency' <- getCurrency currency - result <- liftIO $ stripe config $ - createCharge (Amount amount) currency' - -&- tokenId - -&- MetaData [("Voucher", voucher)] + result <- liftIO (try (payForVoucher d voucher (completeStripeCharge currency'))) case result of + Left AlreadyPaid -> + throwError voucherAlreadyPaid + Left PaymentFailed -> + throwError stripeChargeFailed Right Charge { chargeMetaData = metadata } -> - -- verify that we are getting the same metadata that we sent. - case metadata of - MetaData [("Voucher", v)] -> - if v == voucher - then - do - liftIO $ payForVoucher d voucher - return Ok - else - throwError err500 { errBody = "Voucher code mismatch" } - _ -> throwError err400 { errBody = "Voucher code not found" } - Left StripeError {} -> throwError err400 { errBody = "Stripe charge didn't succeed" } + checkVoucherMetadata metadata where getCurrency :: Text -> Handler Currency getCurrency maybeCurrency = case readMaybe (unpack currency) of Just currency' -> return currency' - Nothing -> throwError err400 { errBody = "Invalid currency specified" } + Nothing -> throwError unsupportedCurrency + + config = StripeConfig (StripeKey key) Nothing + tokenId = TokenId token + completeStripeCharge currency' = do + result <- stripe config $ + createCharge (Amount amount) currency' + -&- tokenId + -&- MetaData [("Voucher", voucher)] + case result of + Left StripeError {} -> throwIO PaymentFailed + Right result -> return result + + checkVoucherMetadata :: MetaData -> Handler Acknowledgement + checkVoucherMetadata metadata = + -- verify that we are getting the same metadata that we sent. + case metadata of + MetaData [("Voucher", v)] -> + if v == voucher + then return Ok + else throwError voucherCodeMismatch + _ -> throwError voucherCodeNotFound + + voucherCodeMismatch = jsonErr 500 "Voucher code mismatch" + unsupportedCurrency = jsonErr 400 "Invalid currency specified" + voucherCodeNotFound = jsonErr 400 "Voucher code not found" + stripeChargeFailed = jsonErr 400 "Stripe charge didn't succeed" + voucherAlreadyPaid = jsonErr 400 "Payment for voucher already supplied" + + jsonErr httpCode reason = ServerError + { errHTTPCode = httpCode + , errReasonPhrase = "" + , errBody = encode $ Failure reason + , errHeaders = [("content-type", "application/json")] + } + + +data Failure = Failure Text + deriving (Show, Eq) + + +instance ToJSON Failure where + toJSON (Failure reason) = object + [ "success" .= False + , "reason" .= reason + ] diff --git a/src/PaymentServer/Server.hs b/src/PaymentServer/Server.hs index feef0b051330a58568fa88ded64667b74719049a..ee07b2cbc38e2a75e6345cf033194b2b81d10405 100644 --- a/src/PaymentServer/Server.hs +++ b/src/PaymentServer/Server.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE OverloadedStrings #-} -- | This module exposes a Servant-based Network.Wai server for payment -- interactions. @@ -7,6 +8,12 @@ module PaymentServer.Server ( paymentServerApp ) where +import Network.Wai.Middleware.Cors + ( Origin + , CorsResourcePolicy(corsOrigins, corsMethods, corsRequestHeaders) + , simpleCorsResourcePolicy + , cors + ) import Servant ( Proxy(Proxy) , Server @@ -47,5 +54,22 @@ paymentServerAPI = Proxy -- | Create a Servant Application which serves the payment server API using -- the given database. -paymentServerApp :: VoucherDatabase d => StripeSecretKey -> Issuer -> d -> Application -paymentServerApp key issuer = serve paymentServerAPI . paymentServer key issuer +paymentServerApp + :: VoucherDatabase d + => [Origin] -- ^ A list of CORS Origins to accept. + -> StripeSecretKey + -> Issuer + -> d + -> Application +paymentServerApp corsOrigins key issuer = + let + app = serve paymentServerAPI . paymentServer key issuer + withCredentials = False + corsResourcePolicy = simpleCorsResourcePolicy + { corsOrigins = Just (corsOrigins, withCredentials) + , corsMethods = [ "POST" ] + , corsRequestHeaders = [ "Content-Type" ] + } + cors' = cors (const $ Just corsResourcePolicy) + in + cors' . app diff --git a/test/Persistence.hs b/test/Persistence.hs new file mode 100644 index 0000000000000000000000000000000000000000..f425e2fd06738ca71739e91fbf8870038f8bd6b9 --- /dev/null +++ b/test/Persistence.hs @@ -0,0 +1,128 @@ +{-# LANGUAGE OverloadedStrings #-} + +-- | Tests related to PaymentServer.Persistence and the persistence system in +-- general. + +module Persistence + ( tests + ) where + +import qualified Data.Text as Text + +import Control.Exception + ( Exception + , throwIO + , try + ) + +import Test.Tasty + ( TestTree + , testGroup + ) +import Test.Tasty.HUnit + ( testCase + , assertEqual + ) + +import System.IO + ( openTempFile + ) +import System.Directory + ( getTemporaryDirectory + ) + +import PaymentServer.Persistence + ( Voucher + , Fingerprint + , RedeemError(NotPaid, AlreadyRedeemed) + , PaymentError(AlreadyPaid) + , VoucherDatabase(payForVoucher, redeemVoucher) + , memory + , getDBConnection + ) + +data ArbitraryException = ArbitraryException + deriving (Show, Eq) + +instance Exception ArbitraryException + +tests :: TestTree +tests = testGroup "Persistence" + [ memoryDatabaseVoucherPaymentTests + , sqlite3DatabaseVoucherPaymentTests + ] + +-- Some dummy values that should be replaced by the use of QuickCheck. +voucher = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" +fingerprint = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" + +-- Mock a successful payment. +paySuccessfully :: IO () +paySuccessfully = return () + +-- Mock a failed payment. +failPayment :: IO () +failPayment = throwIO ArbitraryException + +-- | Create a group of tests related to voucher payment and redemption. +makeVoucherPaymentTests + :: VoucherDatabase d + => Text.Text -- ^ A distinctive identifier for this group's label. + -> IO d -- ^ An operation that creates a new, empty voucher + -- database. + -> TestTree +makeVoucherPaymentTests label makeDatabase = + testGroup ("voucher payments (" ++ Text.unpack label ++ ")") + [ testCase "not paid for" $ do + db <- makeDatabase + result <- redeemVoucher db voucher fingerprint + assertEqual "redeeming unpaid voucher" (Left NotPaid) result + , testCase "paid for" $ do + db <- makeDatabase + () <- payForVoucher db voucher paySuccessfully + result <- redeemVoucher db voucher fingerprint + assertEqual "redeeming paid voucher" (Right ()) result + , testCase "allowed double redemption" $ do + db <- makeDatabase + () <- payForVoucher db voucher paySuccessfully + let redeem = redeemVoucher db voucher fingerprint + first <- redeem + second <- redeem + assertEqual "redeeming paid voucher" (Right ()) first + assertEqual "re-redeeming paid voucher" (Right ()) second + , testCase "disallowed double redemption" $ do + db <- makeDatabase + () <- payForVoucher db voucher paySuccessfully + let redeem = redeemVoucher db voucher + first <- redeem fingerprint + second <- redeem (Text.cons 'a' $ Text.tail fingerprint) + assertEqual "redeeming paid voucher" (Right ()) first + assertEqual "re-redeeming paid voucher" (Left AlreadyRedeemed) second + , testCase "pay with exception" $ do + db <- makeDatabase + payResult <- try $ payForVoucher db voucher failPayment + assertEqual "failing a payment for a voucher" (Left ArbitraryException) payResult + result <- redeemVoucher db voucher fingerprint + assertEqual "redeeming voucher with failed payment" (Left NotPaid) result + , testCase "disallowed double payment" $ do + db <- makeDatabase + let pay = payForVoucher db voucher paySuccessfully + () <- pay + payResult <- try pay + assertEqual "double-paying for a voucher" (Left AlreadyPaid) payResult + redeemResult <- redeemVoucher db voucher fingerprint + assertEqual "redeeming double-paid voucher" (Right ()) redeemResult + ] + +-- | Instantiate the persistence tests for the memory backend. +memoryDatabaseVoucherPaymentTests :: TestTree +memoryDatabaseVoucherPaymentTests = makeVoucherPaymentTests "memory" memory + +-- | Instantiate the persistence tests for the sqlite3 backend. +sqlite3DatabaseVoucherPaymentTests :: TestTree +sqlite3DatabaseVoucherPaymentTests = + makeVoucherPaymentTests "sqlite3" $ + do + tempdir <- getTemporaryDirectory + (path, handle) <- openTempFile tempdir "voucher-.db" + getDBConnection $ Text.pack path diff --git a/test/Spec.hs b/test/Spec.hs new file mode 100644 index 0000000000000000000000000000000000000000..5821f6551290e304aabbdcad998dcc98a44c23ff --- /dev/null +++ b/test/Spec.hs @@ -0,0 +1,20 @@ +-- | Collect all of the various test groups into a single tree. + +module Main + ( main + ) where + +import Test.Tasty + ( TestTree + , testGroup + , defaultMain + ) + +import qualified Persistence + +tests :: TestTree +tests = testGroup "Tests" + [ Persistence.tests + ] + +main = defaultMain tests