forked from mirror/ledisdb
Attempt at Auth :)
This commit is contained in:
parent
8f582ee5fb
commit
623794feb2
|
@ -91,6 +91,9 @@ type SnapshotConfig struct {
|
|||
type Config struct {
|
||||
m sync.RWMutex `toml:"-"`
|
||||
|
||||
AuthEnabled bool `toml:"auth_enabled"`
|
||||
AuthPassword string `toml:"auth_password"`
|
||||
|
||||
FileName string `toml:"-"`
|
||||
|
||||
Addr string `toml:"addr"`
|
||||
|
@ -168,6 +171,10 @@ func NewConfigDefault() *Config {
|
|||
cfg.SlaveOf = ""
|
||||
cfg.Readonly = false
|
||||
|
||||
// Disable Auth by default
|
||||
cfg.AuthEnabled = false
|
||||
cfg.AuthPassword = ""
|
||||
|
||||
// default databases number
|
||||
cfg.Databases = 16
|
||||
|
||||
|
|
|
@ -10,9 +10,53 @@ import (
|
|||
)
|
||||
|
||||
var testAppOnce sync.Once
|
||||
var testAppAuthOnce sync.Once
|
||||
var testApp *App
|
||||
|
||||
var testLedisClient *goredis.Client
|
||||
var testLedisClientAuth *goredis.Client
|
||||
|
||||
func getTestConnAuth(password string) *goredis.PoolConn {
|
||||
startTestAppAuth(password)
|
||||
conn, _ := testLedisClientAuth.Get()
|
||||
return conn
|
||||
}
|
||||
|
||||
func newTestLedisClientAuth() {
|
||||
testLedisClientAuth = goredis.NewClient("127.0.0.1:20000", "")
|
||||
testLedisClientAuth.SetMaxIdleConns(4)
|
||||
}
|
||||
|
||||
func startTestAppAuth(password string) {
|
||||
f := func() {
|
||||
newTestLedisClientAuth()
|
||||
|
||||
cfg := config.NewConfigDefault()
|
||||
cfg.DataDir = "/tmp/testdb_auth"
|
||||
os.RemoveAll(cfg.DataDir)
|
||||
|
||||
cfg.Addr = "127.0.0.1:20000"
|
||||
cfg.HttpAddr = "127.0.0.1:20001"
|
||||
|
||||
if password != "" {
|
||||
cfg.AuthPassword = password
|
||||
cfg.AuthEnabled = true
|
||||
}
|
||||
|
||||
os.RemoveAll(cfg.DataDir)
|
||||
|
||||
var err error
|
||||
testApp, err = NewApp(cfg)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go testApp.Run()
|
||||
}
|
||||
|
||||
testAppAuthOnce.Do(f)
|
||||
}
|
||||
|
||||
func newTestLedisClient() {
|
||||
testLedisClient = goredis.NewClient("127.0.0.1:16380", "")
|
||||
|
@ -36,7 +80,7 @@ func startTestApp() {
|
|||
cfg.Addr = "127.0.0.1:16380"
|
||||
cfg.HttpAddr = "127.0.0.1:21181"
|
||||
|
||||
os.RemoveAll("/tmp/testdb")
|
||||
os.RemoveAll(cfg.DataDir)
|
||||
|
||||
var err error
|
||||
testApp, err = NewApp(cfg)
|
||||
|
|
|
@ -64,6 +64,8 @@ type client struct {
|
|||
cmd string
|
||||
args [][]byte
|
||||
|
||||
is_authed bool
|
||||
|
||||
resp responseWriter
|
||||
|
||||
syncBuf bytes.Buffer
|
||||
|
@ -86,6 +88,7 @@ func newClient(app *App) *client {
|
|||
|
||||
c.app = app
|
||||
c.ldb = app.ldb
|
||||
c.is_authed = false || !app.cfg.AuthEnabled
|
||||
c.db, _ = app.ldb.Select(0) //use default db
|
||||
|
||||
return c
|
||||
|
@ -104,6 +107,8 @@ func (c *client) perform() {
|
|||
err = ErrEmptyCommand
|
||||
} else if exeCmd, ok := regCmds[c.cmd]; !ok {
|
||||
err = ErrNotFound
|
||||
} else if c.app.cfg.AuthEnabled && !c.is_authed && c.cmd != "auth" {
|
||||
err = ErrNotAuthenticated
|
||||
} else {
|
||||
// if c.db.IsTransaction() {
|
||||
// if _, ok := txUnsupportedCmds[c.cmd]; ok {
|
||||
|
@ -120,7 +125,6 @@ func (c *client) perform() {
|
|||
// }
|
||||
|
||||
err = exeCmd(c)
|
||||
|
||||
}
|
||||
|
||||
if c.app.access != nil {
|
||||
|
|
|
@ -14,6 +14,21 @@ func pingCommand(c *client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func authCommand(c *client) error {
|
||||
if len(c.args) != 1 {
|
||||
return ErrCmdParams
|
||||
}
|
||||
|
||||
if c.app.cfg.AuthPassword == string(c.args[0]) {
|
||||
c.is_authed = true
|
||||
c.resp.writeStatus(OK)
|
||||
return nil
|
||||
} else {
|
||||
c.is_authed = false
|
||||
return ErrAuthenticationFailure
|
||||
}
|
||||
}
|
||||
|
||||
func echoCommand(c *client) error {
|
||||
if len(c.args) != 1 {
|
||||
return ErrCmdParams
|
||||
|
@ -156,6 +171,7 @@ func configCommand(c *client) error {
|
|||
}
|
||||
|
||||
func init() {
|
||||
register("auth", authCommand)
|
||||
register("ping", pingCommand)
|
||||
register("echo", echoCommand)
|
||||
register("select", selectCommand)
|
||||
|
|
|
@ -6,6 +6,50 @@ import (
|
|||
"github.com/siddontang/goredis"
|
||||
)
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
c1 := getTestConn()
|
||||
defer c1.Close()
|
||||
|
||||
// Should error, no params
|
||||
_, err := c1.Do("AUTH")
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should error, invalid pass
|
||||
_, err = c1.Do("AUTH", "password")
|
||||
if err.Error() != " authentication failure" {
|
||||
t.Fatal("Expected authentication error:", err)
|
||||
}
|
||||
|
||||
c2 := getTestConnAuth("password")
|
||||
defer c2.Close()
|
||||
|
||||
// Login
|
||||
_, err = c2.Do("AUTH", "password")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should be ok doing a command
|
||||
_, err = c2.Do("GET", "tmp_select_key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Log out by sending wrong pass
|
||||
_, err = c2.Do("AUTH", "wrong password")
|
||||
if err.Error() != " authentication failure" {
|
||||
t.Fatal("Expected authentication error:", err)
|
||||
}
|
||||
|
||||
// Should fail doing a command as we're logged out
|
||||
_, err = c2.Do("GET", "tmp_select_key")
|
||||
if err.Error() != " not authenticated" {
|
||||
t.Fatal("Expected authentication error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXSelect(t *testing.T) {
|
||||
c1 := getTestConn()
|
||||
defer c1.Close()
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
var (
|
||||
ErrEmptyCommand = errors.New("empty command")
|
||||
ErrNotFound = errors.New("command not found")
|
||||
ErrNotAuthenticated = errors.New("not authenticated")
|
||||
ErrAuthenticationFailure = errors.New("authentication failure")
|
||||
ErrCmdParams = errors.New("invalid command param")
|
||||
ErrValue = errors.New("value is not an integer or out of range")
|
||||
ErrSyntax = errors.New("syntax error")
|
||||
|
|
Loading…
Reference in New Issue