diff --git a/config/config.go b/config/config.go index b90182e..a90e40c 100644 --- a/config/config.go +++ b/config/config.go @@ -2,13 +2,16 @@ package config import ( "bytes" + "errors" "github.com/BurntSushi/toml" "github.com/siddontang/go/ioutil2" "io" "io/ioutil" ) -type Size int +var ( + ErrNoConfigFile = errors.New("Running without a config file") +) const ( DefaultAddr string = "127.0.0.1:6380" @@ -42,6 +45,8 @@ type ReplicationConfig struct { } type Config struct { + FileName string `toml:"-"` + Addr string `toml:"addr"` HttpAddr string `toml:"http_addr"` @@ -70,7 +75,12 @@ func NewConfigWithFile(fileName string) (*Config, error) { return nil, err } - return NewConfigWithData(data) + if cfg, err := NewConfigWithData(data); err != nil { + return nil, err + } else { + cfg.FileName = fileName + return cfg, nil + } } func NewConfigWithData(data []byte) (*Config, error) { @@ -142,3 +152,11 @@ func (cfg *Config) DumpFile(fileName string) error { return ioutil2.WriteFileAtomic(fileName, b.Bytes(), 0644) } + +func (cfg *Config) Rewrite() error { + if len(cfg.FileName) == 0 { + return ErrNoConfigFile + } + + return cfg.DumpFile(cfg.FileName) +} diff --git a/config/config_test.go b/config/config_test.go index 077451c..6bde8bf 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -12,14 +12,31 @@ func TestConfig(t *testing.T) { t.Fatal(err) } - defer os.Remove("./config.toml.bak") - if err := cfg.DumpFile("./config.toml.bak"); err != nil { + bakFile := "./config.toml.bak" + + defer os.Remove(bakFile) + if err := cfg.DumpFile(bakFile); err != nil { t.Fatal(err) } - if c, err := NewConfigWithFile("./config.toml.bak"); err != nil { + if c, err := NewConfigWithFile(bakFile); err != nil { t.Fatal(err) - } else if !reflect.DeepEqual(cfg, c) { - t.Fatal("must equal") + } else { + c.FileName = cfg.FileName + if !reflect.DeepEqual(cfg, c) { + t.Fatal("must equal") + } + + c.FileName = bakFile + c.SlaveOf = "127.0.0.1:6381" + if err := c.Rewrite(); err != nil { + t.Fatal(err) + } + + if c1, err := NewConfigWithFile(bakFile); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(c, c1) { + t.Fatalf("must equal %v != %v", c, c1) + } } }