config add rewrite

This commit is contained in:
siddontang 2014-10-08 16:20:09 +08:00
parent d6f1959da1
commit c09c02c82b
2 changed files with 42 additions and 7 deletions

View File

@ -2,13 +2,16 @@ package config
import ( import (
"bytes" "bytes"
"errors"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/siddontang/go/ioutil2" "github.com/siddontang/go/ioutil2"
"io" "io"
"io/ioutil" "io/ioutil"
) )
type Size int var (
ErrNoConfigFile = errors.New("Running without a config file")
)
const ( const (
DefaultAddr string = "127.0.0.1:6380" DefaultAddr string = "127.0.0.1:6380"
@ -42,6 +45,8 @@ type ReplicationConfig struct {
} }
type Config struct { type Config struct {
FileName string `toml:"-"`
Addr string `toml:"addr"` Addr string `toml:"addr"`
HttpAddr string `toml:"http_addr"` HttpAddr string `toml:"http_addr"`
@ -70,7 +75,12 @@ func NewConfigWithFile(fileName string) (*Config, error) {
return nil, err return nil, err
} }
return NewConfigWithData(data) if cfg, err := NewConfigWithData(data); err != nil {
return nil, err
} else {
cfg.FileName = fileName
return cfg, nil
}
} }
func NewConfigWithData(data []byte) (*Config, error) { func NewConfigWithData(data []byte) (*Config, error) {
@ -142,3 +152,11 @@ func (cfg *Config) DumpFile(fileName string) error {
return ioutil2.WriteFileAtomic(fileName, b.Bytes(), 0644) return ioutil2.WriteFileAtomic(fileName, b.Bytes(), 0644)
} }
func (cfg *Config) Rewrite() error {
if len(cfg.FileName) == 0 {
return ErrNoConfigFile
}
return cfg.DumpFile(cfg.FileName)
}

View File

@ -12,14 +12,31 @@ func TestConfig(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer os.Remove("./config.toml.bak") bakFile := "./config.toml.bak"
if err := cfg.DumpFile("./config.toml.bak"); err != nil {
defer os.Remove(bakFile)
if err := cfg.DumpFile(bakFile); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if c, err := NewConfigWithFile("./config.toml.bak"); err != nil { if c, err := NewConfigWithFile(bakFile); err != nil {
t.Fatal(err) t.Fatal(err)
} else if !reflect.DeepEqual(cfg, c) { } else {
t.Fatal("must equal") c.FileName = cfg.FileName
if !reflect.DeepEqual(cfg, c) {
t.Fatal("must equal")
}
c.FileName = bakFile
c.SlaveOf = "127.0.0.1:6381"
if err := c.Rewrite(); err != nil {
t.Fatal(err)
}
if c1, err := NewConfigWithFile(bakFile); err != nil {
t.Fatal(err)
} else if !reflect.DeepEqual(c, c1) {
t.Fatalf("must equal %v != %v", c, c1)
}
} }
} }