diff --git a/README.md b/README.md index 975263ca..1f9ab8f4 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ Yet Another ORM library for Go, aims for developer friendly * Before/After Create/Save/Update/Delete Callbacks * Order/Select/Limit/Offset Support * Automatically CreatedAt, UpdatedAt +* Update, Updates Like Rails's update_attribute, update_attributes +* Dynamically set table name when search, update, delete... * Create table from struct * Prevent SQL Injection * Goroutines friendly @@ -83,7 +85,6 @@ db.Order("age desc").Find(&users1).Order("age", true).Find(&users2) //// users1 -> select * from users order by age desc; //// users2 -> select * from users order by age; - // Limit db.Limit(3).Find(&users) //// users -> select * from users limit 3; @@ -92,7 +93,6 @@ db.Limit(10).Find(&users1).Limit(20).Find(&users2).Limit(-1).Find(&users3) //// users2 -> select * from users limit 20; //// users3 -> select * from users; - // Offset //// select * from users offset 3; db.Offset(3).Find(&users) @@ -159,6 +159,18 @@ db.Table("deleted_users").Find(&deleted_users) db.Table("deleted_users").Find(&deleted_user) //// deleted_user -> select * from deleted_users limit 1; +// Update +db.Table("users").Where(10).Update("name", "hello") +//// update users set name='hello' where id = 10; +db.Table("users").Update("name", "hello") +//// update users set name='hello'; + +// Updates +db.Table("users").Where(10).Updates(map[string]interface{}{"name": "hello", "age": 18}) +//// update users set name='hello', age=18 where id = 10; +db.Table("users").Updates(map[string]interface{}{"name": "hello", "age": 18}) +//// update users set name='hello', age=18; + // Run Raw SQL db.Exec("drop table users;") @@ -203,7 +215,6 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us ``` ## TODO -* Update, Updates like rails's update_attribute, update_attributes * Soft Delete * Query with map or struct * FindOrInitialize / FindOrCreate diff --git a/chain.go b/chain.go index dba71b71..59bb33e3 100644 --- a/chain.go +++ b/chain.go @@ -132,11 +132,15 @@ func (s *Chain) Delete(value interface{}) *Chain { return s } -func (s *Chain) Update(column string, value string) *Chain { - return s +func (s *Chain) Update(column string, value interface{}) *Chain { + return s.Updates(map[string]interface{}{column: value}, true) } -func (s *Chain) Updates(values map[string]string) *Chain { +func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain { + do := s.do(s.value) + do.updateAttrs = values + do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 + do.update() return s } diff --git a/do.go b/do.go index 6f89564b..21b37bed 100644 --- a/do.go +++ b/do.go @@ -33,6 +33,9 @@ type Do struct { offsetStr string limitStr string operation string + + updateAttrs map[string]interface{} + ignoreProtectedAttrs bool } func (s *Do) tableName() string { @@ -93,10 +96,10 @@ func (s *Do) save() *Do { } func (s *Do) prepareCreateSql() *Do { - columns, values := s.model.columnsAndValues("create") + var sqls, columns []string - var sqls []string - for _, value := range values { + for key, value := range s.model.columnsAndValues("create") { + columns = append(columns, key) sqls = append(sqls, s.addToVars(value)) } @@ -138,16 +141,20 @@ func (s *Do) create() *Do { } func (s *Do) prepareUpdateSql() *Do { - columns, values := s.model.columnsAndValues("update") - var sets []string - for index, column := range columns { - sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index]))) + update_attrs := s.updateAttrs + if len(update_attrs) == 0 { + update_attrs = s.model.columnsAndValues("update") + } + + var sqls []string + for key, value := range update_attrs { + sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value))) } s.Sql = fmt.Sprintf( "UPDATE %v SET %v %v", s.tableName(), - strings.Join(sets, ", "), + strings.Join(sqls, ", "), s.combinedSql(), ) return s diff --git a/gorm_test.go b/gorm_test.go index 20353e60..b9b655b4 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -116,7 +116,7 @@ func TestSaveAndFind(t *testing.T) { db.Find(&users) } -func TestUpdate(t *testing.T) { +func TestSaveAndUpdate(t *testing.T) { name, name2, new_name := "update", "update2", "new_update" user := User{Name: name, Age: 1} db.Save(&user) @@ -172,6 +172,10 @@ func TestWhere(t *testing.T) { t.Errorf("Should found out user with name '%v'", name) } + if db.Where(user.Id).First(&User{}).Error != nil { + t.Errorf("Should found out users only with id") + } + user = &User{} orm := db.Where("name LIKE ?", "%nonono%").First(user) if orm.Error == nil { @@ -616,3 +620,65 @@ func TestSetTableDirectly(t *testing.T) { t.Errorf("Set Table Chain Should works well") } } + +func TestUpdate(t *testing.T) { + product1 := Product{Code: "123"} + product2 := Product{Code: "234"} + db.Save(&product1).Save(&product2).Update("code", "456") + + if db.First(&Product{}, "code = '123'").Error != nil { + t.Errorf("Product 123's code should not be changed!") + } + + if db.First(&Product{}, "code = '234'").Error == nil { + t.Errorf("Product 234's code should be changed to 456!") + } + + if db.First(&Product{}, "code = '456'").Error != nil { + t.Errorf("Product 234's code should be 456 now!") + } + + db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789") + if db.First(&Product{}, "code = '123'").Error == nil { + t.Errorf("Product 123's code should be changed to 789") + } + + if db.First(&Product{}, "code = '456'").Error != nil { + t.Errorf("Product 456's code should not be changed to 789") + } + + if db.First(&Product{}, "code = '789'").Error != nil { + t.Errorf("We should have Product 789") + } +} + +func TestUpdates(t *testing.T) { + product1 := Product{Code: "abc", Price: 10} + product2 := Product{Code: "cde", Price: 20} + db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100}) + + if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil { + t.Errorf("Product abc should not be updated!") + } + + if db.First(&Product{}, "code = 'cde'").Error == nil { + t.Errorf("Product cde should be renamed to edf!") + } + + if db.First(&Product{}, "code = 'edf' and price = 100").Error != nil { + t.Errorf("We should have product edf!") + } + + db.Table("products").Where("code in (?)", []string{"abc"}).Updates(map[string]interface{}{"code": "fgh", "price": 200}) + if db.First(&Product{}, "code = 'abc'").Error == nil { + t.Errorf("Product abc's code should be changed to fgh") + } + + if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil { + t.Errorf("Product cde's code should not be changed to fgh") + } + + if db.First(&Product{}, "code = 'fgh' and price = 200").Error != nil { + t.Errorf("We should have Product fgh") + } +} diff --git a/model.go b/model.go index 52cd681f..244a001d 100644 --- a/model.go +++ b/model.go @@ -96,14 +96,14 @@ func (m *Model) fields(operation string) (fields []Field) { return } -func (m *Model) columnsAndValues(operation string) (columns []string, values []interface{}) { +func (m *Model) columnsAndValues(operation string) map[string]interface{} { + results := map[string]interface{}{} for _, field := range m.fields(operation) { if !field.IsPrimaryKey { - columns = append(columns, field.DbName) - values = append(values, field.Value) + results[field.DbName] = field.Value } } - return + return results } func (m *Model) tableName() (str string, err error) { @@ -138,6 +138,10 @@ func (m *Model) tableName() (str string, err error) { } func (m *Model) callMethod(method string) error { + if m.Data == nil { + return nil + } + fm := reflect.ValueOf(m.Data).MethodByName(method) if fm.IsValid() { v := fm.Call([]reflect.Value{})