mirror of https://github.com/ledisdb/ledisdb.git
Attempt at Auth :)
This commit is contained in:
parent
8f582ee5fb
commit
623794feb2
|
@ -91,6 +91,9 @@ type SnapshotConfig struct {
|
||||||
type Config struct {
|
type Config struct {
|
||||||
m sync.RWMutex `toml:"-"`
|
m sync.RWMutex `toml:"-"`
|
||||||
|
|
||||||
|
AuthEnabled bool `toml:"auth_enabled"`
|
||||||
|
AuthPassword string `toml:"auth_password"`
|
||||||
|
|
||||||
FileName string `toml:"-"`
|
FileName string `toml:"-"`
|
||||||
|
|
||||||
Addr string `toml:"addr"`
|
Addr string `toml:"addr"`
|
||||||
|
@ -168,6 +171,10 @@ func NewConfigDefault() *Config {
|
||||||
cfg.SlaveOf = ""
|
cfg.SlaveOf = ""
|
||||||
cfg.Readonly = false
|
cfg.Readonly = false
|
||||||
|
|
||||||
|
// Disable Auth by default
|
||||||
|
cfg.AuthEnabled = false
|
||||||
|
cfg.AuthPassword = ""
|
||||||
|
|
||||||
// default databases number
|
// default databases number
|
||||||
cfg.Databases = 16
|
cfg.Databases = 16
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,53 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var testAppOnce sync.Once
|
var testAppOnce sync.Once
|
||||||
|
var testAppAuthOnce sync.Once
|
||||||
var testApp *App
|
var testApp *App
|
||||||
|
|
||||||
var testLedisClient *goredis.Client
|
var testLedisClient *goredis.Client
|
||||||
|
var testLedisClientAuth *goredis.Client
|
||||||
|
|
||||||
|
func getTestConnAuth(password string) *goredis.PoolConn {
|
||||||
|
startTestAppAuth(password)
|
||||||
|
conn, _ := testLedisClientAuth.Get()
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestLedisClientAuth() {
|
||||||
|
testLedisClientAuth = goredis.NewClient("127.0.0.1:20000", "")
|
||||||
|
testLedisClientAuth.SetMaxIdleConns(4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startTestAppAuth(password string) {
|
||||||
|
f := func() {
|
||||||
|
newTestLedisClientAuth()
|
||||||
|
|
||||||
|
cfg := config.NewConfigDefault()
|
||||||
|
cfg.DataDir = "/tmp/testdb_auth"
|
||||||
|
os.RemoveAll(cfg.DataDir)
|
||||||
|
|
||||||
|
cfg.Addr = "127.0.0.1:20000"
|
||||||
|
cfg.HttpAddr = "127.0.0.1:20001"
|
||||||
|
|
||||||
|
if password != "" {
|
||||||
|
cfg.AuthPassword = password
|
||||||
|
cfg.AuthEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
os.RemoveAll(cfg.DataDir)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
testApp, err = NewApp(cfg)
|
||||||
|
if err != nil {
|
||||||
|
println(err.Error())
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go testApp.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
testAppAuthOnce.Do(f)
|
||||||
|
}
|
||||||
|
|
||||||
func newTestLedisClient() {
|
func newTestLedisClient() {
|
||||||
testLedisClient = goredis.NewClient("127.0.0.1:16380", "")
|
testLedisClient = goredis.NewClient("127.0.0.1:16380", "")
|
||||||
|
@ -36,7 +80,7 @@ func startTestApp() {
|
||||||
cfg.Addr = "127.0.0.1:16380"
|
cfg.Addr = "127.0.0.1:16380"
|
||||||
cfg.HttpAddr = "127.0.0.1:21181"
|
cfg.HttpAddr = "127.0.0.1:21181"
|
||||||
|
|
||||||
os.RemoveAll("/tmp/testdb")
|
os.RemoveAll(cfg.DataDir)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
testApp, err = NewApp(cfg)
|
testApp, err = NewApp(cfg)
|
||||||
|
|
|
@ -64,6 +64,8 @@ type client struct {
|
||||||
cmd string
|
cmd string
|
||||||
args [][]byte
|
args [][]byte
|
||||||
|
|
||||||
|
is_authed bool
|
||||||
|
|
||||||
resp responseWriter
|
resp responseWriter
|
||||||
|
|
||||||
syncBuf bytes.Buffer
|
syncBuf bytes.Buffer
|
||||||
|
@ -86,6 +88,7 @@ func newClient(app *App) *client {
|
||||||
|
|
||||||
c.app = app
|
c.app = app
|
||||||
c.ldb = app.ldb
|
c.ldb = app.ldb
|
||||||
|
c.is_authed = false || !app.cfg.AuthEnabled
|
||||||
c.db, _ = app.ldb.Select(0) //use default db
|
c.db, _ = app.ldb.Select(0) //use default db
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
@ -104,6 +107,8 @@ func (c *client) perform() {
|
||||||
err = ErrEmptyCommand
|
err = ErrEmptyCommand
|
||||||
} else if exeCmd, ok := regCmds[c.cmd]; !ok {
|
} else if exeCmd, ok := regCmds[c.cmd]; !ok {
|
||||||
err = ErrNotFound
|
err = ErrNotFound
|
||||||
|
} else if c.app.cfg.AuthEnabled && !c.is_authed && c.cmd != "auth" {
|
||||||
|
err = ErrNotAuthenticated
|
||||||
} else {
|
} else {
|
||||||
// if c.db.IsTransaction() {
|
// if c.db.IsTransaction() {
|
||||||
// if _, ok := txUnsupportedCmds[c.cmd]; ok {
|
// if _, ok := txUnsupportedCmds[c.cmd]; ok {
|
||||||
|
@ -120,7 +125,6 @@ func (c *client) perform() {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
err = exeCmd(c)
|
err = exeCmd(c)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.app.access != nil {
|
if c.app.access != nil {
|
||||||
|
|
|
@ -14,6 +14,21 @@ func pingCommand(c *client) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func authCommand(c *client) error {
|
||||||
|
if len(c.args) != 1 {
|
||||||
|
return ErrCmdParams
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.app.cfg.AuthPassword == string(c.args[0]) {
|
||||||
|
c.is_authed = true
|
||||||
|
c.resp.writeStatus(OK)
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
c.is_authed = false
|
||||||
|
return ErrAuthenticationFailure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func echoCommand(c *client) error {
|
func echoCommand(c *client) error {
|
||||||
if len(c.args) != 1 {
|
if len(c.args) != 1 {
|
||||||
return ErrCmdParams
|
return ErrCmdParams
|
||||||
|
@ -156,6 +171,7 @@ func configCommand(c *client) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
register("auth", authCommand)
|
||||||
register("ping", pingCommand)
|
register("ping", pingCommand)
|
||||||
register("echo", echoCommand)
|
register("echo", echoCommand)
|
||||||
register("select", selectCommand)
|
register("select", selectCommand)
|
||||||
|
|
|
@ -6,6 +6,50 @@ import (
|
||||||
"github.com/siddontang/goredis"
|
"github.com/siddontang/goredis"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestAuth(t *testing.T) {
|
||||||
|
c1 := getTestConn()
|
||||||
|
defer c1.Close()
|
||||||
|
|
||||||
|
// Should error, no params
|
||||||
|
_, err := c1.Do("AUTH")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should error, invalid pass
|
||||||
|
_, err = c1.Do("AUTH", "password")
|
||||||
|
if err.Error() != " authentication failure" {
|
||||||
|
t.Fatal("Expected authentication error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c2 := getTestConnAuth("password")
|
||||||
|
defer c2.Close()
|
||||||
|
|
||||||
|
// Login
|
||||||
|
_, err = c2.Do("AUTH", "password")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be ok doing a command
|
||||||
|
_, err = c2.Do("GET", "tmp_select_key")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log out by sending wrong pass
|
||||||
|
_, err = c2.Do("AUTH", "wrong password")
|
||||||
|
if err.Error() != " authentication failure" {
|
||||||
|
t.Fatal("Expected authentication error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should fail doing a command as we're logged out
|
||||||
|
_, err = c2.Do("GET", "tmp_select_key")
|
||||||
|
if err.Error() != " not authenticated" {
|
||||||
|
t.Fatal("Expected authentication error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestXSelect(t *testing.T) {
|
func TestXSelect(t *testing.T) {
|
||||||
c1 := getTestConn()
|
c1 := getTestConn()
|
||||||
defer c1.Close()
|
defer c1.Close()
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
var (
|
var (
|
||||||
ErrEmptyCommand = errors.New("empty command")
|
ErrEmptyCommand = errors.New("empty command")
|
||||||
ErrNotFound = errors.New("command not found")
|
ErrNotFound = errors.New("command not found")
|
||||||
|
ErrNotAuthenticated = errors.New("not authenticated")
|
||||||
|
ErrAuthenticationFailure = errors.New("authentication failure")
|
||||||
ErrCmdParams = errors.New("invalid command param")
|
ErrCmdParams = errors.New("invalid command param")
|
||||||
ErrValue = errors.New("value is not an integer or out of range")
|
ErrValue = errors.New("value is not an integer or out of range")
|
||||||
ErrSyntax = errors.New("syntax error")
|
ErrSyntax = errors.New("syntax error")
|
||||||
|
|
Loading…
Reference in New Issue