// Mirror of https://github.com/chylex/SMTP-Relay.git
// (synced 2024-11-21 12:42:45 +01:00; 179 lines, 5.8 KiB, Go).

package config
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type Config struct {
|
|
LogFile *string
|
|
LogFormat *string
|
|
LogLevel *string
|
|
Accounts map[string]Account
|
|
AllowedNets []*net.IPNet
|
|
ListenAddrs []ProtoAddr
|
|
Hostname *string
|
|
WelcomeMsg *string
|
|
LocalCert *string
|
|
LocalKey *string
|
|
LocalForceTLS bool
|
|
ReadTimeout time.Duration
|
|
WriteTimeout time.Duration
|
|
DataTimeout time.Duration
|
|
MaxConnections *int
|
|
MaxMessageSize *int
|
|
MaxRecipients *int
|
|
Command *string
|
|
}
|
|
|
|
// flagset registers every supported configuration key together with its
// default value and help text. Load hands flagset.Set to the ini parser, so
// the keys registered here are the ones a configuration file can set.
var flagset = flag.NewFlagSet("smtprelay", flag.ContinueOnError)

// Configuration flags. Each variable points at the storage owned by flagset,
// so it reflects whatever value was last applied through flagset.Set.
var (
	// Logging.
	logFile   = flagset.String("logfile", "", "Path to logfile")
	logFormat = flagset.String("log_format", "default", "Log output format")
	logLevel  = flagset.String("log_level", "info", "Minimum log level to output")

	// Access control.
	accountFile    = flagset.String("account_file", "", "Path to file with user accounts")
	allowedNetsStr = flagset.String("allowed_nets", "127.0.0.0/8 ::1/128", "Networks allowed to send mails")

	// Listener and session identity.
	listenStr  = flagset.String("listen", "127.0.0.1:25 [::1]:25", "Address and port to listen for incoming SMTP")
	hostName   = flagset.String("hostname", "localhost.localdomain", "Server hostname")
	welcomeMsg = flagset.String("welcome_msg", "", "Welcome message for SMTP session")

	// Local TLS.
	localCert     = flagset.String("local_cert", "", "SSL certificate for STARTTLS/TLS")
	localKey      = flagset.String("local_key", "", "SSL private key for STARTTLS/TLS")
	localForceTLS = flagset.Bool("local_forcetls", false, "Force STARTTLS (needs local_cert and local_key)")

	// Timeouts are kept as strings here and converted by parseDuration in Load.
	readTimeoutStr  = flagset.String("read_timeout", "60s", "Socket timeout for read operations")
	writeTimeoutStr = flagset.String("write_timeout", "60s", "Socket timeout for write operations")
	dataTimeoutStr  = flagset.String("data_timeout", "5m", "Socket timeout for DATA command")

	// Limits.
	maxConnections = flagset.Int("max_connections", 100, "Max concurrent connections, use -1 to disable")
	maxMessageSize = flagset.Int("max_message_size", 10240000, "Max message size in bytes")
	maxRecipients  = flagset.Int("max_recipients", 100, "Max RCPT TO calls for each envelope")

	// Delivery.
	command = flagset.String("command", "", "Path to pipe command")

	// Additional flags: "config" is registered so the key is known to the
	// flag set, but its value is discarded in this file.
	_ = flagset.String("config", "", "Path to config file (ini format)")
)
|
|
|
|
func Load(configFile string) *Config {
|
|
reader, err := os.Open(configFile)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Could not open configuration file: %s\n", configFile)
|
|
fmt.Fprintf(os.Stderr, "Error: %s\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
err = IniParser(reader, flagset.Set)
|
|
_ = reader.Close()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Could not parse configuration file: %s\n", configFile)
|
|
fmt.Fprintf(os.Stderr, "Error: %s\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
allowedNets, err := parseAllowedNetworks(*allowedNetsStr)
|
|
handleInvalidConfiguration(err)
|
|
|
|
listenAddrs, err := parseListeners(*listenStr)
|
|
handleInvalidConfiguration(err)
|
|
|
|
readTimeout, err := parseDuration("read_timeout", *readTimeoutStr)
|
|
handleInvalidConfiguration(err)
|
|
|
|
writeTimeout, err := parseDuration("write_timeout", *writeTimeoutStr)
|
|
handleInvalidConfiguration(err)
|
|
|
|
dataTimeout, err := parseDuration("data_timeout", *dataTimeoutStr)
|
|
handleInvalidConfiguration(err)
|
|
|
|
accounts, err := ReadAccountsFromFile(*accountFile)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Could not load account file: %s\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
return &Config{
|
|
LogFile: logFile,
|
|
LogFormat: logFormat,
|
|
LogLevel: logLevel,
|
|
Accounts: accounts,
|
|
AllowedNets: allowedNets,
|
|
ListenAddrs: listenAddrs,
|
|
Hostname: hostName,
|
|
WelcomeMsg: welcomeMsg,
|
|
LocalCert: localCert,
|
|
LocalKey: localKey,
|
|
LocalForceTLS: *localForceTLS,
|
|
ReadTimeout: readTimeout,
|
|
WriteTimeout: writeTimeout,
|
|
DataTimeout: dataTimeout,
|
|
MaxConnections: maxConnections,
|
|
MaxMessageSize: maxMessageSize,
|
|
MaxRecipients: maxRecipients,
|
|
Command: command,
|
|
}
|
|
}
|
|
|
|
// handleInvalidConfiguration is a no-op for a nil error. For any other error
// it prints "Invalid configuration: <err>" to stderr and terminates the
// process with exit status 1.
func handleInvalidConfiguration(err error) {
	if err == nil {
		return
	}

	fmt.Fprintf(os.Stderr, "Invalid configuration: %v\n", err)
	os.Exit(1)
}
|
|
|
|
// splitstr cuts s at every occurrence of sep and drops empty pieces, so
// leading, trailing and repeated separators never yield "" entries.
// https://stackoverflow.com/a/46798310/119527
func splitstr(s string, sep rune) []string {
	isSeparator := func(r rune) bool {
		return r == sep
	}

	return strings.FieldsFunc(s, isSeparator)
}
|
|
|
|
func parseAllowedNetworks(allowedNetsStr string) ([]*net.IPNet, error) {
|
|
var allowedNets []*net.IPNet
|
|
|
|
for _, netstr := range splitstr(allowedNetsStr, ' ') {
|
|
baseIP, allowedNet, err := net.ParseCIDR(netstr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid CIDR notation in allowed_nets: %s\nError: %s", netstr, err)
|
|
}
|
|
|
|
// Reject any network specification where any host bits are set,
|
|
// meaning the address refers to a host and not a network.
|
|
if !allowedNet.IP.Equal(baseIP) {
|
|
return nil, fmt.Errorf("invalid network in allowed_nets (host bits set): %s\nAllowed net: %s", netstr, allowedNet)
|
|
}
|
|
|
|
allowedNets = append(allowedNets, allowedNet)
|
|
}
|
|
|
|
return allowedNets, nil
|
|
}
|
|
|
|
func parseListeners(listenStr string) ([]ProtoAddr, error) {
|
|
var listenAddrs []ProtoAddr
|
|
|
|
for _, listenAddr := range strings.Split(listenStr, " ") {
|
|
pa := splitProto(listenAddr)
|
|
|
|
if pa.Protocol == "" {
|
|
return nil, fmt.Errorf("local authentication not allowed with non-TLS listener")
|
|
}
|
|
|
|
listenAddrs = append(listenAddrs, pa)
|
|
}
|
|
|
|
return listenAddrs, nil
|
|
}
|
|
|
|
// parseDuration parses timeoutStr, the value of the option named
// configurationKey, as a Go duration string such as "60s" or "5m".
//
// It returns the parsed duration when it is at least one second. Otherwise it
// returns a zero duration and an error: a malformed string yields an error
// wrapping the one from time.ParseDuration, and a duration shorter than one
// second (including negative values) yields a "must be at least one second"
// error. configurationKey is only used to build the error messages.
func parseDuration(configurationKey string, timeoutStr string) (time.Duration, error) {
	timeout, err := time.ParseDuration(timeoutStr)
	if err != nil {
		return 0, fmt.Errorf("invalid duration string in %s: %s\nError: %w", configurationKey, timeoutStr, err)
	}

	// Compare the integer duration directly instead of going through
	// floating-point seconds.
	if timeout < time.Second {
		return 0, fmt.Errorf("duration of %s must be at least one second", configurationKey)
	}

	return timeout, nil
}
|