diff --git a/PaymentServer.cabal b/PaymentServer.cabal index 2746523024ff2852dc5308d95cf2c1b690f4e45d..6359ffb5741e451e2760100e14bb5b664da572ed 100644 --- a/PaymentServer.cabal +++ b/PaymentServer.cabal @@ -31,6 +31,7 @@ library , wai-extra , data-default , warp + , warp-tls , stripe-core , text , containers diff --git a/src/PaymentServer/Main.hs b/src/PaymentServer/Main.hs index 93f9ea8b9115aea0ddf786de6277fadf151793b0..eb988b7160e691c4db037cefbad0b0f0709a25b7 100644 --- a/src/PaymentServer/Main.hs +++ b/src/PaymentServer/Main.hs @@ -9,6 +9,9 @@ module PaymentServer.Main import Text.Printf ( printf ) +import Data.Maybe + ( maybeToList + ) import Data.Text ( Text ) @@ -16,7 +19,17 @@ import Data.Default ( def ) import Network.Wai.Handler.Warp - ( run + ( Port + , defaultSettings + , setPort + , run + ) +import Network.Wai.Handler.WarpTLS + ( runTLS + , tlsSettingsChain + ) +import Network.Wai + ( Application ) import Network.Wai.Middleware.RequestLogger ( OutputFormat(Detailed) @@ -38,6 +51,7 @@ import PaymentServer.Server import Options.Applicative ( Parser , ParserInfo + , strOption , option , auto , str @@ -53,6 +67,7 @@ import Options.Applicative , progDesc , header , (<**>) + , (<|>) ) import System.Exit ( exitFailure @@ -70,13 +85,59 @@ data Database = deriving (Show, Eq, Ord, Read) data ServerConfig = ServerConfig - { issuer :: Issuer - , signingKey :: Maybe Text - , database :: Database - , databasePath :: Maybe Text + { issuer :: Issuer + , signingKey :: Maybe Text + , database :: Database + , databasePath :: Maybe Text + , endpoint :: Endpoint + } + deriving (Show, Eq) + +-- | An Endpoint represents the configuration for a socket's IP address. +-- There are some layering violations here. I'm just copying Twisted +-- endpoints at the moment. At some point it would be great to implement a +-- general purpose endpoint library outside of PaymentServer and without the +-- layering violations. +data Endpoint = + -- | A TCPEndpoint represents a bare TCP/IP socket address. + TCPEndpoint + { portNumber :: Port + } + | + -- | A TLSEndpoint represents a TCP/IP socket address which will have TLS + -- used over it. + TLSEndpoint + { portNumber :: Port + , certificatePath :: FilePath + , chainPath :: Maybe FilePath + , keyPath :: FilePath } deriving (Show, Eq) +http :: Parser Endpoint +http = TCPEndpoint + <$> option auto + ( long "http-port" + <> help "Port number on which to accept HTTP connections." + ) + +https :: Parser Endpoint +https = TLSEndpoint + <$> option auto + ( long "https-port" + <> help "Port number on which to accept HTTPS connections." ) + <*> strOption + ( long "https-certificate-path" + <> help "Filesystem path to the TLS certificate to use for HTTPS." ) + <*> optional + ( strOption + ( long "https-certificate-chain-path" + <> help "Filesystem path to the TLS certificate chain to use for HTTPS." ) ) + <*> strOption + ( long "https-key-path" + <> help "Filesystem path to the TLS private key to use for HTTPS." ) + + sample :: Parser ServerConfig sample = ServerConfig <$> option auto @@ -97,6 +158,8 @@ sample = ServerConfig ( long "database-path" <> help "Path to on-disk database (sqlite3 only)" <> showDefault ) ) + <*> (http <|> https) + opts :: ParserInfo ServerConfig opts = info (sample <**> helper) @@ -106,7 +169,35 @@ opts = info (sample <**> helper) ) main :: IO () -main = +main = do + config <- execParser opts + app <- getApp config + let run = getRunner (endpoint config) + logEndpoint (endpoint config) + run app + +getRunner :: Endpoint -> (Application -> IO ()) +getRunner endpoint = + case endpoint of + (TCPEndpoint portNumber) -> + run portNumber + (TLSEndpoint portNumber certificatePath chainPath keyPath) -> + let + tlsSettings = tlsSettingsChain certificatePath (maybeToList chainPath) keyPath + settings = setPort portNumber defaultSettings + in + runTLS tlsSettings settings + +logEndpoint :: Endpoint -> IO () +logEndpoint endpoint = + case endpoint of + TCPEndpoint { portNumber } -> + putStrLn (printf "Accepting HTTP connections on %d" portNumber :: String) + TLSEndpoint { portNumber } -> + putStrLn (printf "Accepting HTTPS connections on %d" portNumber :: String) + +getApp :: ServerConfig -> IO Application +getApp config = let getIssuer ServerConfig{ issuer, signingKey } = case (issuer, signingKey) of @@ -119,20 +210,17 @@ main = (SQLite3, Just path) -> Right (getDBConnection path) _ -> Left "invalid options" in do - config <- execParser opts case getIssuer config of Left err -> do print err exitFailure Right issuer -> case getDatabase config of - Left err ->do + Left err -> do print err exitFailure Right getDB -> do db <- getDB - let port = 8081 let app = paymentServerApp issuer db logger <- mkRequestLogger (def { outputFormat = Detailed True}) - putStrLn (printf "Listening on %d" port :: String) - run port $ logger app + return $ logger app