Make it easy when only change one field with Attr, Assign

This commit is contained in:
Jinzhu 2013-10-31 12:59:04 +08:00
parent 0e692a13b1
commit 549c7450ea
7 changed files with 64 additions and 31 deletions

View File

@ -98,6 +98,8 @@ db.Where(User{Name: "noexisting_user"}).Attrs(User{Age: 20}).FirstOrInit(&user)
//// user -> select * from users where name = 'noexisting_user';
//// If no record found, will assign the attrs to user, so user become:
//// User{Name: "noexisting_user", Age: 20}
db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user)
// Same as above
db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 20}).FirstOrInit(&user)
//// user -> select * from users where name = 'jinzhu';
//// If found the user, will ingore the attrs:
@ -108,6 +110,8 @@ db.Where(User{Name: "noexisting_user"}).Assign(User{Age: 20}).FirstOrInit(&user)
//// user -> select * from users where name = 'noexisting_user';
//// If no record found, will assign the value to user, so user become:
//// User{Name: "noexisting_user", Age: 20} (same as FirstOrInit With Attrs)
db.Where(User{Name: "noexisting_user"}).Assign("age", 20).FirstOrInit(&user)
// Same as above
//// user -> User{Name: "noexisting_user", Age: 20}
db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 20}).FirstOrInit(&user)
//// user -> select * from users where name = 'jinzhu';
@ -127,6 +131,8 @@ db.Where(User{Name: "noexisting_user"}).Attrs(User{Age: 20}).FirstOrCreate(&user
//// user -> select * from users where name = 'noexisting_user';
//// If not record found, will assing the attrs to the user first, then create it
//// Same as db.Where(User{Name: "noexisting_user"}).FirstOrCreate(&user).Update("age": 20), but one less sql
db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrCreate(&user)
// Save as above
//// user -> User{Id: 112, Name: "noexisting_user", Age: 20}
db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 20}).FirstOrCreate(&user)
//// user -> select * from users where name = 'jinzhu';

View File

@ -153,8 +153,8 @@ func (s *Chain) Delete(value interface{}) *Chain {
return s
}
func (s *Chain) Update(column string, value interface{}) *Chain {
return s.Updates(map[string]interface{}{column: value}, true)
func (s *Chain) Update(attrs ...interface{}) *Chain {
return s.Updates(toSearchableMap(attrs...), true)
}
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
@ -174,13 +174,13 @@ func (s *Chain) First(out interface{}, where ...interface{}) *Chain {
return s
}
func (s *Chain) Attrs(attrs interface{}) *Chain {
s.initAttrs = append(s.initAttrs, attrs)
func (s *Chain) Attrs(attrs ...interface{}) *Chain {
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
return s
}
func (s *Chain) Assign(attrs interface{}) *Chain {
s.assignAttrs = append(s.assignAttrs, attrs)
func (s *Chain) Assign(attrs ...interface{}) *Chain {
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
return s
}

28
do.go
View File

@ -133,10 +133,7 @@ func (s *Do) create() {
if !s.hasError() {
result := reflect.ValueOf(s.value).Elem()
primary_key := result.FieldByName(s.model.primaryKey())
if primary_key.IsValid() {
primary_key.SetInt(id)
}
setFieldValue(result.FieldByName(s.model.primaryKey()), id)
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
@ -323,7 +320,7 @@ func (s *Do) count(value interface{}) {
for rows.Next() {
var dest int64
if s.err(rows.Scan(&dest)) == nil {
dest_out.SetInt(dest)
setFieldValue(dest_out, dest)
}
}
}
@ -523,17 +520,24 @@ func (s *Do) initializeWithSearchCondition() {
for _, clause := range s.whereClause {
query := clause["query"]
switch query.(type) {
case []interface{}:
for _, obj := range query.([]interface{}) {
m := &Model{data: obj, driver: s.driver}
for _, field := range m.columnsHasValue("") {
m.setValueByColumn(field.DbName, field.Value, s.value)
}
}
case map[string]interface{}:
for key, value := range query.(map[string]interface{}) {
m.setValueByColumn(key, value, s.value)
}
case []interface{}:
for _, obj := range query.([]interface{}) {
switch reflect.ValueOf(obj).Kind() {
case reflect.Struct:
m := &Model{data: obj, driver: s.driver}
for _, field := range m.columnsHasValue("") {
m.setValueByColumn(field.DbName, field.Value, s.value)
}
case reflect.Map:
for key, value := range obj.(map[string]interface{}) {
m.setValueByColumn(key, value, s.value)
}
}
}
case interface{}:
m := &Model{data: query, driver: s.driver}
for _, field := range m.columnsHasValue("") {

View File

@ -870,14 +870,14 @@ func TestFindOrInitialize(t *testing.T) {
t.Errorf("user should be initialized with search value and attrs")
}
db.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user4)
db.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4)
if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
t.Errorf("user should be initialized with search value and assigned attrs")
}
db.Save(&User{Name: "find or init", Age: 33})
db.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user5)
db.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5)
if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 {
t.Errorf("user should be found and not initialized by Attrs")
}
@ -910,7 +910,7 @@ func TestFindOrCreate(t *testing.T) {
t.Errorf("user should be created with inline search value")
}
db.Where(&User{Name: "find or create 3"}).Attrs(User{Age: 44}).FirstOrCreate(&user4)
db.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4)
if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 {
t.Errorf("user should be created with search value and attrs")
}
@ -920,7 +920,7 @@ func TestFindOrCreate(t *testing.T) {
t.Errorf("user should be created with search value and assigned attrs")
}
db.Where(&User{Name: "find or create"}).Attrs(User{Age: 44}).FirstOrInit(&user5)
db.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5)
if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 {
t.Errorf("user should be found and not initialized by Attrs")
}

View File

@ -30,12 +30,12 @@ func (s *DB) First(out interface{}, where ...interface{}) *Chain {
return s.buildChain().First(out, where...)
}
func (s *DB) Attrs(attrs interface{}) *Chain {
return s.buildChain().Attrs(attrs)
func (s *DB) Attrs(attrs ...interface{}) *Chain {
return s.buildChain().Attrs(attrs...)
}
func (s *DB) Assign(attrs interface{}) *Chain {
return s.buildChain().Assign(attrs)
func (s *DB) Assign(attrs ...interface{}) *Chain {
return s.buildChain().Assign(attrs...)
}
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *Chain {

View File

@ -151,9 +151,8 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[stri
}
}
field := data.FieldByName("UpdatedAt")
if field.IsValid() && values["updated_at"] != nil && len(results) > 0 {
data.FieldByName("UpdatedAt").Set(reflect.ValueOf(time.Now()))
if values["updated_at"] != nil && len(results) > 0 {
setFieldValue(data.FieldByName("UpdatedAt"), time.Now())
}
result := len(results) > 0
return map[string]interface{}{}, result
@ -238,9 +237,16 @@ func (m *Model) missingColumns() (results []string) {
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
data := reflect.Indirect(reflect.ValueOf(out))
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
}
field := data.FieldByName(snakeToUpperCamel(name))
func setFieldValue(field reflect.Value, value interface{}) {
if field.IsValid() {
field.Set(reflect.ValueOf(value))
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
field.SetInt(reflect.ValueOf(value).Int())
default:
field.Set(reflect.ValueOf(value))
}
}
}

View File

@ -29,6 +29,23 @@ func snakeToUpperCamel(s string) string {
return buf.String()
}
func toSearchableMap(attrs ...interface{}) (result interface{}) {
if len(attrs) > 1 {
if str, ok := attrs[0].(string); ok {
result = map[string]interface{}{str: attrs[1]}
}
} else if len(attrs) == 1 {
if attr, ok := attrs[0].(map[string]interface{}); ok {
result = attr
}
if attr, ok := attrs[0].(interface{}); ok {
result = attr
}
}
return
}
func debug(value interface{}) {
fmt.Printf("***************\n")
fmt.Printf("%+v\n\n", value)