From 1e99d1bad764f0bd4b9ea1035a8939efdb9dd1b3 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Tue, 29 Oct 2019 19:07:25 -0400
Subject: [PATCH] Support either HTTP or HTTPS

---
 src/PaymentServer/Main.hs | 102 +++++++++++++++++++++++++++-----------
 1 file changed, 74 insertions(+), 28 deletions(-)

diff --git a/src/PaymentServer/Main.hs b/src/PaymentServer/Main.hs
index eb6db2c..eb988b7 100644
--- a/src/PaymentServer/Main.hs
+++ b/src/PaymentServer/Main.hs
@@ -19,12 +19,13 @@ import Data.Default
   ( def
   )
 import Network.Wai.Handler.Warp
-  ( defaultSettings
+  ( Port
+  , defaultSettings
   , setPort
+  , run
   )
 import Network.Wai.Handler.WarpTLS
-  ( TLSSettings
-  , runTLS
+  ( runTLS
   , tlsSettingsChain
   )
 import Network.Wai
@@ -66,6 +67,7 @@ import Options.Applicative
   , progDesc
   , header
   , (<**>)
+  , (<|>)
   )
 import System.Exit
   ( exitFailure
@@ -87,13 +89,55 @@ data ServerConfig = ServerConfig
   , signingKey      :: Maybe Text
   , database        :: Database
   , databasePath    :: Maybe Text
-  , httpPortNumber  :: Int
-  , certificatePath :: String
-  , chainPath       :: Maybe String
-  , keyPath         :: String
+  , endpoint        :: Endpoint
   }
   deriving (Show, Eq)
 
+-- | An Endpoint represents the configuration for a socket's IP address.
+-- There are some layering violations here.  I'm just copying Twisted
+-- endpoints at the moment.  At some point it would be great to implement a
+-- general purpose endpoint library outside of PaymentServer and without the
+-- layering violations.
+data Endpoint =
+  -- | A TCPEndpoint represents a bare TCP/IP socket address.
+  TCPEndpoint
+  { portNumber :: Port
+  }
+  |
+  -- | A TLSEndpoint represents a TCP/IP socket address which will have TLS
+  -- used over it.
+  TLSEndpoint
+  { portNumber      :: Port
+  , certificatePath :: FilePath
+  , chainPath       :: Maybe FilePath
+  , keyPath         :: FilePath
+  }
+  deriving (Show, Eq)
+
+http :: Parser Endpoint
+http = TCPEndpoint
+  <$> option auto
+  ( long "http-port"
+    <> help "Port number on which to accept HTTP connections."
+  )
+
+https :: Parser Endpoint
+https = TLSEndpoint
+  <$> option auto
+  ( long "https-port"
+    <> help "Port number on which to accept HTTPS connections." )
+  <*> strOption
+  ( long "https-certificate-path"
+    <> help "Filesystem path to the TLS certificate to use for HTTPS." )
+  <*> optional
+  ( strOption
+    ( long "https-certificate-chain-path"
+      <> help "Filesystem path to the TLS certificate chain to use for HTTPS." ) )
+  <*> strOption
+  ( long "https-key-path"
+    <> help "Filesystem path to the TLS private key to use for HTTPS." )
+
+
 sample :: Parser ServerConfig
 sample = ServerConfig
   <$> option auto
@@ -114,20 +158,7 @@ sample = ServerConfig
   ( long "database-path"
     <> help "Path to on-disk database (sqlite3 only)"
     <> showDefault ) )
-  <*> option auto
-  ( long "https-port"
-    <> help "Port number on which to accept HTTPS connections."
-    <> showDefault
-    <> value 443 )
-  <*> strOption
-  ( long "https-certificate-path"
-    <> help "Filesystem path to the TLS certificate to use for HTTPS." )
-  <*> optional ( strOption
-  ( long "https-certificate-chain-path"
-    <> help "Filesystem path to the TLS certificate chain to use for HTTPS." ) )
-  <*> strOption
-  ( long "https-key-path"
-    <> help "Filesystem path to the TLS private key to use for HTTPS." )
+  <*> (http <|> https)
 
 
 opts :: ParserInfo ServerConfig
@@ -140,15 +171,30 @@ opts = info (sample <**> helper)
 main :: IO ()
 main = do
     config <- execParser opts
-    let port = httpPortNumber config
     app <- getApp config
-    tlsSettings <- getTlsSettings config
-    putStrLn (printf "Accepting HTTPS connections on %d" port :: String)
-    runTLS tlsSettings (setPort port defaultSettings) app
+    let run = getRunner (endpoint config)
+    logEndpoint (endpoint config)
+    run app
+
+getRunner :: Endpoint -> (Application -> IO ())
+getRunner endpoint =
+  case endpoint of
+    (TCPEndpoint portNumber) ->
+      run portNumber
+    (TLSEndpoint portNumber certificatePath chainPath keyPath) ->
+      let
+        tlsSettings = tlsSettingsChain certificatePath (maybeToList chainPath) keyPath
+        settings = setPort portNumber defaultSettings
+      in
+        runTLS tlsSettings settings
 
-getTlsSettings :: ServerConfig -> IO TLSSettings
-getTlsSettings ServerConfig{ certificatePath, chainPath, keyPath } =
-  return $ tlsSettingsChain certificatePath (maybeToList chainPath) keyPath
+logEndpoint :: Endpoint -> IO ()
+logEndpoint endpoint =
+  case endpoint of
+    TCPEndpoint { portNumber } ->
+      putStrLn (printf "Accepting HTTP connections on %d" portNumber :: String)
+    TLSEndpoint { portNumber } ->
+      putStrLn (printf "Accepting HTTPS connections on %d" portNumber :: String)
 
 getApp :: ServerConfig -> IO Application
 getApp config =
-- 
GitLab