Yay, Add Update, Updates Support

This commit is contained in:
Jinzhu 2013-10-28 21:52:22 +08:00
parent 2a20e551ed
commit 84b280c0ff
5 changed files with 111 additions and 19 deletions

View File

@ -9,6 +9,8 @@ Yet Another ORM library for Go, aims for developer friendly
* Before/After Create/Save/Update/Delete Callbacks * Before/After Create/Save/Update/Delete Callbacks
* Order/Select/Limit/Offset Support * Order/Select/Limit/Offset Support
* Automatically CreatedAt, UpdatedAt * Automatically CreatedAt, UpdatedAt
* Update, Updates Like Rails's update_attribute, update_attributes
* Dynamically set table name when search, update, delete...
* Create table from struct * Create table from struct
* Prevent SQL Injection * Prevent SQL Injection
* Goroutines friendly * Goroutines friendly
@ -83,7 +85,6 @@ db.Order("age desc").Find(&users1).Order("age", true).Find(&users2)
//// users1 -> select * from users order by age desc; //// users1 -> select * from users order by age desc;
//// users2 -> select * from users order by age; //// users2 -> select * from users order by age;
// Limit // Limit
db.Limit(3).Find(&users) db.Limit(3).Find(&users)
//// users -> select * from users limit 3; //// users -> select * from users limit 3;
@ -92,7 +93,6 @@ db.Limit(10).Find(&users1).Limit(20).Find(&users2).Limit(-1).Find(&users3)
//// users2 -> select * from users limit 20; //// users2 -> select * from users limit 20;
//// users3 -> select * from users; //// users3 -> select * from users;
// Offset // Offset
//// select * from users offset 3; //// select * from users offset 3;
db.Offset(3).Find(&users) db.Offset(3).Find(&users)
@ -159,6 +159,18 @@ db.Table("deleted_users").Find(&deleted_users)
db.Table("deleted_users").Find(&deleted_user) db.Table("deleted_users").Find(&deleted_user)
//// deleted_user -> select * from deleted_users limit 1; //// deleted_user -> select * from deleted_users limit 1;
// Update
db.Table("users").Where(10).Update("name", "hello")
//// update users set name='hello' where id = 10;
db.Table("users").Update("name", "hello")
//// update users set name='hello';
// Updates
db.Table("users").Where(10).Updates(map[string]interface{}{"name": "hello", "age": 18})
//// update users set name='hello', age=18 where id = 10;
db.Table("users").Updates(map[string]interface{}{"name": "hello", "age": 18})
//// update users set name='hello', age=18;
// Run Raw SQL // Run Raw SQL
db.Exec("drop table users;") db.Exec("drop table users;")
@ -203,7 +215,6 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us
``` ```
## TODO ## TODO
* Update, Updates like rails's update_attribute, update_attributes
* Soft Delete * Soft Delete
* Query with map or struct * Query with map or struct
* FindOrInitialize / FindOrCreate * FindOrInitialize / FindOrCreate

View File

@ -132,11 +132,15 @@ func (s *Chain) Delete(value interface{}) *Chain {
return s return s
} }
func (s *Chain) Update(column string, value string) *Chain { func (s *Chain) Update(column string, value interface{}) *Chain {
return s return s.Updates(map[string]interface{}{column: value}, true)
} }
func (s *Chain) Updates(values map[string]string) *Chain { func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain {
do := s.do(s.value)
do.updateAttrs = values
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0
do.update()
return s return s
} }

23
do.go
View File

@ -33,6 +33,9 @@ type Do struct {
offsetStr string offsetStr string
limitStr string limitStr string
operation string operation string
updateAttrs map[string]interface{}
ignoreProtectedAttrs bool
} }
func (s *Do) tableName() string { func (s *Do) tableName() string {
@ -93,10 +96,10 @@ func (s *Do) save() *Do {
} }
func (s *Do) prepareCreateSql() *Do { func (s *Do) prepareCreateSql() *Do {
columns, values := s.model.columnsAndValues("create") var sqls, columns []string
var sqls []string for key, value := range s.model.columnsAndValues("create") {
for _, value := range values { columns = append(columns, key)
sqls = append(sqls, s.addToVars(value)) sqls = append(sqls, s.addToVars(value))
} }
@ -138,16 +141,20 @@ func (s *Do) create() *Do {
} }
func (s *Do) prepareUpdateSql() *Do { func (s *Do) prepareUpdateSql() *Do {
columns, values := s.model.columnsAndValues("update") update_attrs := s.updateAttrs
var sets []string if len(update_attrs) == 0 {
for index, column := range columns { update_attrs = s.model.columnsAndValues("update")
sets = append(sets, fmt.Sprintf("%v = %v", s.quote(column), s.addToVars(values[index]))) }
var sqls []string
for key, value := range update_attrs {
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
} }
s.Sql = fmt.Sprintf( s.Sql = fmt.Sprintf(
"UPDATE %v SET %v %v", "UPDATE %v SET %v %v",
s.tableName(), s.tableName(),
strings.Join(sets, ", "), strings.Join(sqls, ", "),
s.combinedSql(), s.combinedSql(),
) )
return s return s

View File

@ -116,7 +116,7 @@ func TestSaveAndFind(t *testing.T) {
db.Find(&users) db.Find(&users)
} }
func TestUpdate(t *testing.T) { func TestSaveAndUpdate(t *testing.T) {
name, name2, new_name := "update", "update2", "new_update" name, name2, new_name := "update", "update2", "new_update"
user := User{Name: name, Age: 1} user := User{Name: name, Age: 1}
db.Save(&user) db.Save(&user)
@ -172,6 +172,10 @@ func TestWhere(t *testing.T) {
t.Errorf("Should found out user with name '%v'", name) t.Errorf("Should found out user with name '%v'", name)
} }
if db.Where(user.Id).First(&User{}).Error != nil {
t.Errorf("Should found out users only with id")
}
user = &User{} user = &User{}
orm := db.Where("name LIKE ?", "%nonono%").First(user) orm := db.Where("name LIKE ?", "%nonono%").First(user)
if orm.Error == nil { if orm.Error == nil {
@ -616,3 +620,65 @@ func TestSetTableDirectly(t *testing.T) {
t.Errorf("Set Table Chain Should works well") t.Errorf("Set Table Chain Should works well")
} }
} }
func TestUpdate(t *testing.T) {
product1 := Product{Code: "123"}
product2 := Product{Code: "234"}
db.Save(&product1).Save(&product2).Update("code", "456")
if db.First(&Product{}, "code = '123'").Error != nil {
t.Errorf("Product 123's code should not be changed!")
}
if db.First(&Product{}, "code = '234'").Error == nil {
t.Errorf("Product 234's code should be changed to 456!")
}
if db.First(&Product{}, "code = '456'").Error != nil {
t.Errorf("Product 234's code should be 456 now!")
}
db.Table("products").Where("code in (?)", []string{"123"}).Update("code", "789")
if db.First(&Product{}, "code = '123'").Error == nil {
t.Errorf("Product 123's code should be changed to 789")
}
if db.First(&Product{}, "code = '456'").Error != nil {
t.Errorf("Product 456's code should not be changed to 789")
}
if db.First(&Product{}, "code = '789'").Error != nil {
t.Errorf("We should have Product 789")
}
}
func TestUpdates(t *testing.T) {
product1 := Product{Code: "abc", Price: 10}
product2 := Product{Code: "cde", Price: 20}
db.Save(&product1).Save(&product2).Updates(map[string]interface{}{"code": "edf", "price": 100})
if db.First(&Product{}, "code = 'abc' and price = 10").Error != nil {
t.Errorf("Product abc should not be updated!")
}
if db.First(&Product{}, "code = 'cde'").Error == nil {
t.Errorf("Product cde should be renamed to edf!")
}
if db.First(&Product{}, "code = 'edf' and price = 100").Error != nil {
t.Errorf("We should have product edf!")
}
db.Table("products").Where("code in (?)", []string{"abc"}).Updates(map[string]interface{}{"code": "fgh", "price": 200})
if db.First(&Product{}, "code = 'abc'").Error == nil {
t.Errorf("Product abc's code should be changed to fgh")
}
if db.First(&Product{}, "code = 'edf' and price = ?", 100).Error != nil {
t.Errorf("Product cde's code should not be changed to fgh")
}
if db.First(&Product{}, "code = 'fgh' and price = 200").Error != nil {
t.Errorf("We should have Product fgh")
}
}

View File

@ -96,14 +96,14 @@ func (m *Model) fields(operation string) (fields []Field) {
return return
} }
func (m *Model) columnsAndValues(operation string) (columns []string, values []interface{}) { func (m *Model) columnsAndValues(operation string) map[string]interface{} {
results := map[string]interface{}{}
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsPrimaryKey { if !field.IsPrimaryKey {
columns = append(columns, field.DbName) results[field.DbName] = field.Value
values = append(values, field.Value)
} }
} }
return return results
} }
func (m *Model) tableName() (str string, err error) { func (m *Model) tableName() (str string, err error) {
@ -138,6 +138,10 @@ func (m *Model) tableName() (str string, err error) {
} }
func (m *Model) callMethod(method string) error { func (m *Model) callMethod(method string) error {
if m.Data == nil {
return nil
}
fm := reflect.ValueOf(m.Data).MethodByName(method) fm := reflect.ValueOf(m.Data).MethodByName(method)
if fm.IsValid() { if fm.IsValid() {
v := fm.Call([]reflect.Value{}) v := fm.Call([]reflect.Value{})