diff --git a/config/config.go b/config/config.go index 0c888b3..612c33c 100644 --- a/config/config.go +++ b/config/config.go @@ -87,6 +87,12 @@ type SnapshotConfig struct { MaxNum int `toml:"max_num"` } +type TLS struct { + Enabled bool `toml:"enabled"` + Certificate string `toml:"certificate"` + Key string `toml:"key"` +} + type AuthMethod func(c *Config, password string) bool type Config struct { @@ -135,6 +141,9 @@ type Config struct { ConnKeepaliveInterval int `toml:"conn_keepalive_interval"` TTLCheckInterval int `toml:"ttl_check_interval"` + + //tls config + TLS TLS `toml:"tls"` } func NewConfigWithFile(fileName string) (*Config, error) { diff --git a/config/config.toml b/config/config.toml index d4ea545..8100ab1 100644 --- a/config/config.toml +++ b/config/config.toml @@ -161,3 +161,8 @@ path = "" # Reserve newest max_num snapshot dump files max_num = 1 + +[tls] +enabled = true +certificate = "test.crt" +key = "test.key" \ No newline at end of file diff --git a/server/app.go b/server/app.go index 7047f5c..0b7d517 100644 --- a/server/app.go +++ b/server/app.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "crypto/tls" "github.com/siddontang/goredis" "github.com/siddontang/ledisdb/config" "github.com/siddontang/ledisdb/ledis" @@ -61,6 +62,27 @@ func netType(s string) string { return "tcp" } +func tlsConfig(c *config.TLS) (*tls.Config, error) { + crt, err := tls.LoadX509KeyPair(c.Certificate, c.Key) + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{ + crt, + }, + }, nil +} + +func listen(netType, laddr string, tlsCfg *tls.Config) (net.Listener, error) { + if tlsCfg != nil { + return tls.Listen(netType, laddr, tlsCfg) + } + + return net.Listen(netType, laddr) +} + func NewApp(cfg *config.Config) (*App, error) { if len(cfg.DataDir) == 0 { println("use default datadir %s", config.DefaultDataDir) @@ -89,10 +111,18 @@ func NewApp(cfg *config.Config) (*App, error) { return nil, err } + var tlsCfg *tls.Config + if cfg.TLS.Enabled { + tlsCfg, err = tlsConfig(&cfg.TLS) + if err != nil { + return nil, err + } + } + if cfg.Addr != "" { addrNetType := netType(cfg.Addr) - if app.listener, err = net.Listen(addrNetType, cfg.Addr); err != nil { + if app.listener, err = listen(addrNetType, cfg.Addr, tlsCfg); err != nil { return nil, err } @@ -115,7 +145,7 @@ func NewApp(cfg *config.Config) (*App, error) { } if len(cfg.HttpAddr) > 0 { - if app.httpListener, err = net.Listen(netType(cfg.HttpAddr), cfg.HttpAddr); err != nil { + if app.httpListener, err = listen(netType(cfg.HttpAddr), cfg.HttpAddr, tlsCfg); err != nil { return nil, err } }