From 549c7450ea3e76e5265433a2dc7aecc9e7dcac33 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 31 Oct 2013 12:59:04 +0800 Subject: [PATCH] Make it easy when only change one field with Attr, Assign --- README.md | 6 ++++++ chain.go | 12 ++++++------ do.go | 28 ++++++++++++++++------------ gorm_test.go | 8 ++++---- main.go | 8 ++++---- model.go | 16 +++++++++++----- utils.go | 17 +++++++++++++++++ 7 files changed, 64 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 913d307b..8f61f20d 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,8 @@ db.Where(User{Name: "noexisting_user"}).Attrs(User{Age: 20}).FirstOrInit(&user) //// user -> select * from users where name = 'noexisting_user'; //// If no record found, will assign the attrs to user, so user become: //// User{Name: "noexisting_user", Age: 20} +db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user) +// Same as above db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 20}).FirstOrInit(&user) //// user -> select * from users where name = 'jinzhu'; //// If found the user, will ingore the attrs: @@ -108,6 +110,8 @@ db.Where(User{Name: "noexisting_user"}).Assign(User{Age: 20}).FirstOrInit(&user) //// user -> select * from users where name = 'noexisting_user'; //// If no record found, will assign the value to user, so user become: //// User{Name: "noexisting_user", Age: 20} (same as FirstOrInit With Attrs) +db.Where(User{Name: "noexisting_user"}).Assign("age", 20).FirstOrInit(&user) +// Same as above //// user -> User{Name: "noexisting_user", Age: 20} db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 20}).FirstOrInit(&user) //// user -> select * from users where name = 'jinzhu'; @@ -127,6 +131,8 @@ db.Where(User{Name: "noexisting_user"}).Attrs(User{Age: 20}).FirstOrCreate(&user //// user -> select * from users where name = 'noexisting_user'; //// If not record found, will assing the attrs to the user first, then create it //// Same as db.Where(User{Name: "noexisting_user"}).FirstOrCreate(&user).Update("age": 20), but one less sql +db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrCreate(&user) +// Save as above //// user -> User{Id: 112, Name: "noexisting_user", Age: 20} db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 20}).FirstOrCreate(&user) //// user -> select * from users where name = 'jinzhu'; diff --git a/chain.go b/chain.go index 8a357df8..fe99b862 100644 --- a/chain.go +++ b/chain.go @@ -153,8 +153,8 @@ func (s *Chain) Delete(value interface{}) *Chain { return s } -func (s *Chain) Update(column string, value interface{}) *Chain { - return s.Updates(map[string]interface{}{column: value}, true) +func (s *Chain) Update(attrs ...interface{}) *Chain { + return s.Updates(toSearchableMap(attrs...), true) } func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain { @@ -174,13 +174,13 @@ func (s *Chain) First(out interface{}, where ...interface{}) *Chain { return s } -func (s *Chain) Attrs(attrs interface{}) *Chain { - s.initAttrs = append(s.initAttrs, attrs) +func (s *Chain) Attrs(attrs ...interface{}) *Chain { + s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) return s } -func (s *Chain) Assign(attrs interface{}) *Chain { - s.assignAttrs = append(s.assignAttrs, attrs) +func (s *Chain) Assign(attrs ...interface{}) *Chain { + s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) return s } diff --git a/do.go b/do.go index 402b26f1..98126026 100644 --- a/do.go +++ b/do.go @@ -133,10 +133,7 @@ func (s *Do) create() { if !s.hasError() { result := reflect.ValueOf(s.value).Elem() - primary_key := result.FieldByName(s.model.primaryKey()) - if primary_key.IsValid() { - primary_key.SetInt(id) - } + setFieldValue(result.FieldByName(s.model.primaryKey()), id) s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterSave")) @@ -323,7 +320,7 @@ func (s *Do) count(value interface{}) { for rows.Next() { var dest int64 if s.err(rows.Scan(&dest)) == nil { - dest_out.SetInt(dest) + setFieldValue(dest_out, dest) } } } @@ -523,17 +520,24 @@ func (s *Do) initializeWithSearchCondition() { for _, clause := range s.whereClause { query := clause["query"] switch query.(type) { - case []interface{}: - for _, obj := range query.([]interface{}) { - m := &Model{data: obj, driver: s.driver} - for _, field := range m.columnsHasValue("") { - m.setValueByColumn(field.DbName, field.Value, s.value) - } - } case map[string]interface{}: for key, value := range query.(map[string]interface{}) { m.setValueByColumn(key, value, s.value) } + case []interface{}: + for _, obj := range query.([]interface{}) { + switch reflect.ValueOf(obj).Kind() { + case reflect.Struct: + m := &Model{data: obj, driver: s.driver} + for _, field := range m.columnsHasValue("") { + m.setValueByColumn(field.DbName, field.Value, s.value) + } + case reflect.Map: + for key, value := range obj.(map[string]interface{}) { + m.setValueByColumn(key, value, s.value) + } + } + } case interface{}: m := &Model{data: query, driver: s.driver} for _, field := range m.columnsHasValue("") { diff --git a/gorm_test.go b/gorm_test.go index c57d028d..e54a5ec5 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -870,14 +870,14 @@ func TestFindOrInitialize(t *testing.T) { t.Errorf("user should be initialized with search value and attrs") } - db.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user4) + db.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { t.Errorf("user should be initialized with search value and assigned attrs") } db.Save(&User{Name: "find or init", Age: 33}) - db.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user5) + db.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { t.Errorf("user should be found and not initialized by Attrs") } @@ -910,7 +910,7 @@ func TestFindOrCreate(t *testing.T) { t.Errorf("user should be created with inline search value") } - db.Where(&User{Name: "find or create 3"}).Attrs(User{Age: 44}).FirstOrCreate(&user4) + db.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { t.Errorf("user should be created with search value and attrs") } @@ -920,7 +920,7 @@ func TestFindOrCreate(t *testing.T) { t.Errorf("user should be created with search value and assigned attrs") } - db.Where(&User{Name: "find or create"}).Attrs(User{Age: 44}).FirstOrInit(&user5) + db.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { t.Errorf("user should be found and not initialized by Attrs") } diff --git a/main.go b/main.go index f7ca92c2..ea56b7bc 100644 --- a/main.go +++ b/main.go @@ -30,12 +30,12 @@ func (s *DB) First(out interface{}, where ...interface{}) *Chain { return s.buildChain().First(out, where...) } -func (s *DB) Attrs(attrs interface{}) *Chain { - return s.buildChain().Attrs(attrs) +func (s *DB) Attrs(attrs ...interface{}) *Chain { + return s.buildChain().Attrs(attrs...) } -func (s *DB) Assign(attrs interface{}) *Chain { - return s.buildChain().Assign(attrs) +func (s *DB) Assign(attrs ...interface{}) *Chain { + return s.buildChain().Assign(attrs...) } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *Chain { diff --git a/model.go b/model.go index 6887bd3d..58a63479 100644 --- a/model.go +++ b/model.go @@ -151,9 +151,8 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[stri } } - field := data.FieldByName("UpdatedAt") - if field.IsValid() && values["updated_at"] != nil && len(results) > 0 { - data.FieldByName("UpdatedAt").Set(reflect.ValueOf(time.Now())) + if values["updated_at"] != nil && len(results) > 0 { + setFieldValue(data.FieldByName("UpdatedAt"), time.Now()) } result := len(results) > 0 return map[string]interface{}{}, result @@ -238,9 +237,16 @@ func (m *Model) missingColumns() (results []string) { func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) { data := reflect.Indirect(reflect.ValueOf(out)) + setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value) +} - field := data.FieldByName(snakeToUpperCamel(name)) +func setFieldValue(field reflect.Value, value interface{}) { if field.IsValid() { - field.Set(reflect.ValueOf(value)) + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + field.SetInt(reflect.ValueOf(value).Int()) + default: + field.Set(reflect.ValueOf(value)) + } } } diff --git a/utils.go b/utils.go index 41791b7a..6277bc40 100644 --- a/utils.go +++ b/utils.go @@ -29,6 +29,23 @@ func snakeToUpperCamel(s string) string { return buf.String() } +func toSearchableMap(attrs ...interface{}) (result interface{}) { + if len(attrs) > 1 { + if str, ok := attrs[0].(string); ok { + result = map[string]interface{}{str: attrs[1]} + } + } else if len(attrs) == 1 { + if attr, ok := attrs[0].(map[string]interface{}); ok { + result = attr + } + + if attr, ok := attrs[0].(interface{}); ok { + result = attr + } + } + return +} + func debug(value interface{}) { fmt.Printf("***************\n") fmt.Printf("%+v\n\n", value)