diff --git a/config/config.go b/config/config.go index 390b96c..5d1c539 100644 --- a/config/config.go +++ b/config/config.go @@ -91,6 +91,9 @@ type SnapshotConfig struct { type Config struct { m sync.RWMutex `toml:"-"` + AuthEnabled bool `toml:"auth_enabled"` + AuthPassword string `toml:"auth_password"` + FileName string `toml:"-"` Addr string `toml:"addr"` @@ -168,6 +171,10 @@ func NewConfigDefault() *Config { cfg.SlaveOf = "" cfg.Readonly = false + // Disable Auth by default + cfg.AuthEnabled = false + cfg.AuthPassword = "" + // default databases number cfg.Databases = 16 diff --git a/server/app_test.go b/server/app_test.go index f182e8e..4dbdaa1 100644 --- a/server/app_test.go +++ b/server/app_test.go @@ -10,9 +10,53 @@ import ( ) var testAppOnce sync.Once +var testAppAuthOnce sync.Once var testApp *App var testLedisClient *goredis.Client +var testLedisClientAuth *goredis.Client + +func getTestConnAuth(password string) *goredis.PoolConn { + startTestAppAuth(password) + conn, _ := testLedisClientAuth.Get() + return conn +} + +func newTestLedisClientAuth() { + testLedisClientAuth = goredis.NewClient("127.0.0.1:20000", "") + testLedisClientAuth.SetMaxIdleConns(4) +} + +func startTestAppAuth(password string) { + f := func() { + newTestLedisClientAuth() + + cfg := config.NewConfigDefault() + cfg.DataDir = "/tmp/testdb_auth" + os.RemoveAll(cfg.DataDir) + + cfg.Addr = "127.0.0.1:20000" + cfg.HttpAddr = "127.0.0.1:20001" + + if password != "" { + cfg.AuthPassword = password + cfg.AuthEnabled = true + } + + os.RemoveAll(cfg.DataDir) + + var err error + testApp, err = NewApp(cfg) + if err != nil { + println(err.Error()) + panic(err) + } + + go testApp.Run() + } + + testAppAuthOnce.Do(f) +} func newTestLedisClient() { testLedisClient = goredis.NewClient("127.0.0.1:16380", "") @@ -36,7 +80,7 @@ func startTestApp() { cfg.Addr = "127.0.0.1:16380" cfg.HttpAddr = "127.0.0.1:21181" - os.RemoveAll("/tmp/testdb") + os.RemoveAll(cfg.DataDir) var err error testApp, err = NewApp(cfg) diff --git a/server/client.go b/server/client.go index 0156d16..23afe08 100644 --- a/server/client.go +++ b/server/client.go @@ -64,6 +64,8 @@ type client struct { cmd string args [][]byte + is_authed bool + resp responseWriter syncBuf bytes.Buffer @@ -86,6 +88,7 @@ func newClient(app *App) *client { c.app = app c.ldb = app.ldb + c.is_authed = false || !app.cfg.AuthEnabled c.db, _ = app.ldb.Select(0) //use default db return c @@ -104,6 +107,8 @@ func (c *client) perform() { err = ErrEmptyCommand } else if exeCmd, ok := regCmds[c.cmd]; !ok { err = ErrNotFound + } else if c.app.cfg.AuthEnabled && !c.is_authed && c.cmd != "auth" { + err = ErrNotAuthenticated } else { // if c.db.IsTransaction() { // if _, ok := txUnsupportedCmds[c.cmd]; ok { @@ -120,7 +125,6 @@ func (c *client) perform() { // } err = exeCmd(c) - } if c.app.access != nil { diff --git a/server/cmd_server.go b/server/cmd_server.go index 0530e33..04d1e04 100644 --- a/server/cmd_server.go +++ b/server/cmd_server.go @@ -14,6 +14,21 @@ func pingCommand(c *client) error { return nil } +func authCommand(c *client) error { + if len(c.args) != 1 { + return ErrCmdParams + } + + if c.app.cfg.AuthPassword == string(c.args[0]) { + c.is_authed = true + c.resp.writeStatus(OK) + return nil + } else { + c.is_authed = false + return ErrAuthenticationFailure + } +} + func echoCommand(c *client) error { if len(c.args) != 1 { return ErrCmdParams @@ -156,6 +171,7 @@ func configCommand(c *client) error { } func init() { + register("auth", authCommand) register("ping", pingCommand) register("echo", echoCommand) register("select", selectCommand) diff --git a/server/cmd_server_test.go b/server/cmd_server_test.go index 325d602..e44d62d 100644 --- a/server/cmd_server_test.go +++ b/server/cmd_server_test.go @@ -6,6 +6,50 @@ import ( "github.com/siddontang/goredis" ) +func TestAuth(t *testing.T) { + c1 := getTestConn() + defer c1.Close() + + // Should error, no params + _, err := c1.Do("AUTH") + if err == nil { + t.Fatal(err) + } + + // Should error, invalid pass + _, err = c1.Do("AUTH", "password") + if err.Error() != " authentication failure" { + t.Fatal("Expected authentication error:", err) + } + + c2 := getTestConnAuth("password") + defer c2.Close() + + // Login + _, err = c2.Do("AUTH", "password") + if err != nil { + t.Fatal(err) + } + + // Should be ok doing a command + _, err = c2.Do("GET", "tmp_select_key") + if err != nil { + t.Fatal(err) + } + + // Log out by sending wrong pass + _, err = c2.Do("AUTH", "wrong password") + if err.Error() != " authentication failure" { + t.Fatal("Expected authentication error:", err) + } + + // Should fail doing a command as we're logged out + _, err = c2.Do("GET", "tmp_select_key") + if err.Error() != " not authenticated" { + t.Fatal("Expected authentication error:", err) + } +} + func TestXSelect(t *testing.T) { c1 := getTestConn() defer c1.Close() diff --git a/server/const.go b/server/const.go index 73fc6a0..df56e12 100644 --- a/server/const.go +++ b/server/const.go @@ -7,13 +7,15 @@ import ( ) var ( - ErrEmptyCommand = errors.New("empty command") - ErrNotFound = errors.New("command not found") - ErrCmdParams = errors.New("invalid command param") - ErrValue = errors.New("value is not an integer or out of range") - ErrSyntax = errors.New("syntax error") - ErrOffset = errors.New("offset bit is not an natural number") - ErrBool = errors.New("value is not 0 or 1") + ErrEmptyCommand = errors.New("empty command") + ErrNotFound = errors.New("command not found") + ErrNotAuthenticated = errors.New("not authenticated") + ErrAuthenticationFailure = errors.New("authentication failure") + ErrCmdParams = errors.New("invalid command param") + ErrValue = errors.New("value is not an integer or out of range") + ErrSyntax = errors.New("syntax error") + ErrOffset = errors.New("offset bit is not an natural number") + ErrBool = errors.New("value is not 0 or 1") ) var (