TLS support (#302)

This commit is contained in:
Muhamad Azmy 2017-06-12 02:05:42 +02:00 committed by siddontang
parent 74d1db9b35
commit bc27c99e3c
3 changed files with 46 additions and 2 deletions

View File

@ -87,6 +87,12 @@ type SnapshotConfig struct {
MaxNum int `toml:"max_num"` MaxNum int `toml:"max_num"`
} }
type TLS struct {
Enabled bool `toml:"enabled"`
Certificate string `toml:"certificate"`
Key string `toml:"key"`
}
type AuthMethod func(c *Config, password string) bool type AuthMethod func(c *Config, password string) bool
type Config struct { type Config struct {
@ -135,6 +141,9 @@ type Config struct {
ConnKeepaliveInterval int `toml:"conn_keepalive_interval"` ConnKeepaliveInterval int `toml:"conn_keepalive_interval"`
TTLCheckInterval int `toml:"ttl_check_interval"` TTLCheckInterval int `toml:"ttl_check_interval"`
//tls config
TLS TLS `toml:"tls"`
} }
func NewConfigWithFile(fileName string) (*Config, error) { func NewConfigWithFile(fileName string) (*Config, error) {

View File

@ -161,3 +161,8 @@ path = ""
# Reserve newest max_num snapshot dump files # Reserve newest max_num snapshot dump files
max_num = 1 max_num = 1
[tls]
enabled = true
certificate = "test.crt"
key = "test.key"

View File

@ -10,6 +10,7 @@ import (
"strings" "strings"
"sync" "sync"
"crypto/tls"
"github.com/siddontang/goredis" "github.com/siddontang/goredis"
"github.com/siddontang/ledisdb/config" "github.com/siddontang/ledisdb/config"
"github.com/siddontang/ledisdb/ledis" "github.com/siddontang/ledisdb/ledis"
@ -61,6 +62,27 @@ func netType(s string) string {
return "tcp" return "tcp"
} }
func tlsConfig(c *config.TLS) (*tls.Config, error) {
crt, err := tls.LoadX509KeyPair(c.Certificate, c.Key)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{
crt,
},
}, nil
}
func listen(netType, laddr string, tlsCfg *tls.Config) (net.Listener, error) {
if tlsCfg != nil {
return tls.Listen(netType, laddr, tlsCfg)
}
return net.Listen(netType, laddr)
}
func NewApp(cfg *config.Config) (*App, error) { func NewApp(cfg *config.Config) (*App, error) {
if len(cfg.DataDir) == 0 { if len(cfg.DataDir) == 0 {
println("use default datadir %s", config.DefaultDataDir) println("use default datadir %s", config.DefaultDataDir)
@ -89,10 +111,18 @@ func NewApp(cfg *config.Config) (*App, error) {
return nil, err return nil, err
} }
var tlsCfg *tls.Config
if cfg.TLS.Enabled {
tlsCfg, err = tlsConfig(&cfg.TLS)
if err != nil {
return nil, err
}
}
if cfg.Addr != "" { if cfg.Addr != "" {
addrNetType := netType(cfg.Addr) addrNetType := netType(cfg.Addr)
if app.listener, err = net.Listen(addrNetType, cfg.Addr); err != nil { if app.listener, err = listen(addrNetType, cfg.Addr, tlsCfg); err != nil {
return nil, err return nil, err
} }
@ -115,7 +145,7 @@ func NewApp(cfg *config.Config) (*App, error) {
} }
if len(cfg.HttpAddr) > 0 { if len(cfg.HttpAddr) > 0 {
if app.httpListener, err = net.Listen(netType(cfg.HttpAddr), cfg.HttpAddr); err != nil { if app.httpListener, err = listen(netType(cfg.HttpAddr), cfg.HttpAddr, tlsCfg); err != nil {
return nil, err return nil, err
} }
} }