From fe3b43e7732741f01052491b298058c4beece451 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Fri, 22 Nov 2019 09:29:12 -0500
Subject: [PATCH] Use wai-cors to apply a CORS policy across the whole API

---
 src/PaymentServer/Main.hs   | 12 +++++++++++-
 src/PaymentServer/Server.hs | 26 +++++++++++++++++++++-----
 2 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/src/PaymentServer/Main.hs b/src/PaymentServer/Main.hs
index 941d4c2..0c0dd73 100644
--- a/src/PaymentServer/Main.hs
+++ b/src/PaymentServer/Main.hs
@@ -31,6 +31,9 @@ import Network.Wai.Handler.WarpTLS
 import Network.Wai
   ( Application
   )
+import Network.Wai.Middleware.Cors
+  ( Origin
+  )
 import Network.Wai.Middleware.RequestLogger
   ( OutputFormat(Detailed)
   , outputFormat
@@ -55,6 +58,7 @@ import Options.Applicative
   , option
   , auto
   , str
+  , many
   , optional
   , long
   , help
@@ -93,6 +97,7 @@ data ServerConfig = ServerConfig
   , databasePath    :: Maybe Text
   , endpoint        :: Endpoint
   , stripeKeyPath   :: FilePath
+  , corsOrigins     :: [Origin]
   }
   deriving (Show, Eq)
 
@@ -165,6 +170,9 @@ sample = ServerConfig
   <*> option str
   ( long "stripe-key-path"
     <> help "Path to Stripe Secret key" )
+  <*> many ( option str
+             ( long "cors-origin"
+             <> help "An allowed `Origin` for the purposes of CORS (zero or more)." ) )
 
 opts :: ParserInfo ServerConfig
 opts = info (sample <**> helper)
@@ -230,6 +238,8 @@ getApp config =
           Right getDB -> do
             db <- getDB
             key <- B.readFile (stripeKeyPath config)
-            let app = paymentServerApp key issuer db
+            let
+              origins = corsOrigins config
+              app = paymentServerApp origins key issuer db
             logger <- mkRequestLogger (def { outputFormat = Detailed True})
             return $ logger app
diff --git a/src/PaymentServer/Server.hs b/src/PaymentServer/Server.hs
index 4478a35..ee07b2c 100644
--- a/src/PaymentServer/Server.hs
+++ b/src/PaymentServer/Server.hs
@@ -1,5 +1,6 @@
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE OverloadedStrings #-}
 
 -- | This module exposes a Servant-based Network.Wai server for payment
 -- interactions.
@@ -8,7 +9,10 @@ module PaymentServer.Server
   ) where
 
 import Network.Wai.Middleware.Cors
-  ( simpleCors
+  ( Origin
+  , CorsResourcePolicy(corsOrigins, corsMethods, corsRequestHeaders)
+  , simpleCorsResourcePolicy
+  , cors
   )
 import Servant
   ( Proxy(Proxy)
@@ -50,10 +54,22 @@ paymentServerAPI = Proxy
 
 -- | Create a Servant Application which serves the payment server API using
 -- the given database.
-paymentServerApp :: VoucherDatabase d => StripeSecretKey -> Issuer -> d -> Application
-paymentServerApp key issuer =
+paymentServerApp
+  :: VoucherDatabase d
+  => [Origin]              -- ^ A list of CORS Origins to accept.
+  -> StripeSecretKey
+  -> Issuer
+  -> d
+  -> Application
+paymentServerApp corsOrigins key issuer =
   let
     app = serve paymentServerAPI . paymentServer key issuer
-    cors = simpleCors
+    withCredentials = False
+    corsResourcePolicy = simpleCorsResourcePolicy
+                         { corsOrigins = Just (corsOrigins, withCredentials)
+                         , corsMethods = [ "POST" ]
+                         , corsRequestHeaders = [ "Content-Type" ]
+                         }
+    cors' = cors (const $ Just corsResourcePolicy)
   in
-    cors . app
+    cors' . app
-- 
GitLab