From 953c347ba70bc1e591815ed0342f161e14c24445 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Sep 2014 19:03:01 +0800 Subject: [PATCH] Refactor Scope --- association.go | 32 +++++++++++++++++--------------- association_test.go | 4 +++- callback_create.go | 2 +- callback_shared.go | 26 ++++++++++++-------------- callback_update.go | 2 +- field.go | 26 +++++++++++++++++++++----- scope.go | 33 +++++++++++++++++---------------- scope_private.go | 22 +++++++++++----------- utils.go | 4 ++-- utils_private.go | 20 ++------------------ 10 files changed, 87 insertions(+), 84 deletions(-) diff --git a/association.go b/association.go index 2408476b..6a4dc600 100644 --- a/association.go +++ b/association.go @@ -28,27 +28,29 @@ func (association *Association) Find(value interface{}) *Association { func (association *Association) Append(values ...interface{}) *Association { scope := association.Scope - field := scope.IndirectValue().FieldByName(association.Column) + field := association.Field + fieldType := field.Field.Type() + for _, value := range values { reflectvalue := reflect.ValueOf(value) if reflectvalue.Kind() == reflect.Ptr { if reflectvalue.Elem().Kind() == reflect.Struct { - if field.Type().Elem().Kind() == reflect.Ptr { - field.Set(reflect.Append(field, reflectvalue)) - } else if field.Type().Elem().Kind() == reflect.Struct { - field.Set(reflect.Append(field, reflectvalue.Elem())) + if fieldType.Elem().Kind() == reflect.Ptr { + field.Set(reflect.Append(field.Field, reflectvalue)) + } else if fieldType.Elem().Kind() == reflect.Struct { + field.Set(reflect.Append(field.Field, reflectvalue.Elem())) } } else if reflectvalue.Elem().Kind() == reflect.Slice { - if field.Type().Elem().Kind() == reflect.Ptr { - field.Set(reflect.AppendSlice(field, reflectvalue)) - } else if field.Type().Elem().Kind() == reflect.Struct { - field.Set(reflect.AppendSlice(field, reflectvalue.Elem())) + if fieldType.Elem().Kind() == reflect.Ptr { + field.Set(reflect.AppendSlice(field.Field, reflectvalue)) + } else if fieldType.Elem().Kind() == reflect.Struct { + field.Set(reflect.AppendSlice(field.Field, reflectvalue.Elem())) } } - } else if reflectvalue.Kind() == reflect.Struct && field.Type().Elem().Kind() == reflect.Struct { - field.Set(reflect.Append(field, reflectvalue)) - } else if reflectvalue.Kind() == reflect.Slice && field.Type().Elem() == reflectvalue.Type().Elem() { - field.Set(reflect.AppendSlice(field, reflectvalue)) + } else if reflectvalue.Kind() == reflect.Struct && fieldType.Elem().Kind() == reflect.Struct { + field.Set(reflect.Append(field.Field, reflectvalue)) + } else if reflectvalue.Kind() == reflect.Slice && fieldType.Elem() == reflectvalue.Type().Elem() { + field.Set(reflect.AppendSlice(field.Field, reflectvalue)) } else { association.err(errors.New("invalid association type")) } @@ -107,7 +109,7 @@ func (association *Association) Replace(values ...interface{}) *Association { relationship := association.Field.Relationship scope := association.Scope if relationship.Kind == "many_to_many" { - field := scope.IndirectValue().FieldByName(association.Column) + field := association.Field.Field oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) association.Append(values...) @@ -154,7 +156,7 @@ func (association *Association) Count() int { count := -1 relationship := association.Field.Relationship scope := association.Scope - field := scope.IndirectValue().FieldByName(association.Column) + field := association.Field.Field fieldValue := field.Interface() newScope := scope.New(fieldValue) diff --git a/association_test.go b/association_test.go index 37a904bb..0a2e3fca 100644 --- a/association_test.go +++ b/association_test.go @@ -158,10 +158,12 @@ func TestManyToMany(t *testing.T) { languageA := Language{Name: "AA"} DB.Save(&languageA) DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA) + languageC := Language{Name: "CC"} DB.Save(&languageC) DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) - DB.Model(&User{Id: user.Id}).Association("Languages").Append([]Language{{Name: "DD"}, {Name: "EE"}}) + + DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} diff --git a/callback_create.go b/callback_create.go index 5e51e6d9..1e228319 100644 --- a/callback_create.go +++ b/callback_create.go @@ -28,7 +28,7 @@ func Create(scope *Scope) { for _, field := range scope.Fields() { if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) { columns = append(columns, scope.Quote(field.DBName)) - sqls = append(sqls, scope.AddToVars(field.Value)) + sqls = append(sqls, scope.AddToVars(field.Field.Interface())) } } diff --git a/callback_shared.go b/callback_shared.go index 30ccab87..f13cec9d 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -19,20 +19,18 @@ func SaveBeforeAssociations(scope *Scope) { if !field.IsBlank && !field.IsIgnored { relationship := field.Relationship if relationship != nil && relationship.Kind == "belongs_to" { - value := reflect.ValueOf(field.Value) + value := field.Field newDB := scope.NewDB() - if value.CanAddr() { - scope.Err(newDB.Save(value.Addr().Interface()).Error) - } else { + if !value.CanAddr() { // If can't take address, then clone the value and set it back - value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem() - for _, f := range newDB.NewScope(field.Value).Fields() { - value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + value = reflect.New(value.Type()).Elem() + for _, f := range newDB.NewScope(field.Field.Interface()).Fields() { + value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface())) } - scope.Err(newDB.Save(value.Addr().Interface()).Error) scope.SetColumn(field.Name, value.Interface()) } + scope.Err(newDB.Save(value.Addr().Interface()).Error) if relationship.ForeignKey != "" { scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) @@ -48,7 +46,7 @@ func SaveAfterAssociations(scope *Scope) { relationship := field.Relationship if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := reflect.ValueOf(field.Value) + value := field.Field switch value.Kind() { case reflect.Slice: @@ -89,14 +87,14 @@ func SaveAfterAssociations(scope *Scope) { newDB := scope.NewDB() if value.CanAddr() { if relationship.ForeignKey != "" { - newDB.NewScope(field.Value).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) + newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) } - scope.Err(newDB.Save(field.Value).Error) + scope.Err(newDB.Save(value.Addr().Interface()).Error) } else { - destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() + destValue := reflect.New(field.Field.Type()).Elem() - for _, f := range newDB.NewScope(field.Value).Fields() { - destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) + for _, f := range newDB.NewScope(field.Field.Interface()).Fields() { + destValue.FieldByName(f.Name).Set(f.Field) } elem := destValue.Addr().Interface() diff --git a/callback_update.go b/callback_update.go index 36654281..c59bcf1a 100644 --- a/callback_update.go +++ b/callback_update.go @@ -49,7 +49,7 @@ func Update(scope *Scope) { } else { for _, field := range scope.Fields() { if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } } diff --git a/field.go b/field.go index 6ff0cfb7..80eca2f6 100644 --- a/field.go +++ b/field.go @@ -17,7 +17,6 @@ type Field struct { Name string DBName string Field reflect.Value - Value interface{} Tag reflect.StructTag Relationship *relationship IsNormal bool @@ -26,12 +25,29 @@ type Field struct { IsPrimaryKey bool } -func (f *Field) IsScanner() bool { - _, isScanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner) +func (field *Field) IsScanner() bool { + _, isScanner := reflect.New(field.Field.Type()).Interface().(sql.Scanner) return isScanner } -func (f *Field) IsTime() bool { - _, isTime := f.Value.(time.Time) +func (field *Field) IsTime() bool { + _, isTime := field.Field.Interface().(time.Time) return isTime } + +func (field *Field) Set(value interface{}) (result bool) { + if field.Field.IsValid() && field.Field.CanAddr() { + result = true + if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { + scanner.Scan(value) + } else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) { + field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type())) + } else { + result = false + } + } + if result { + field.IsBlank = isBlank(field.Field) + } + return +} diff --git a/scope.go b/scope.go index 5bf66d0b..9db9f259 100644 --- a/scope.go +++ b/scope.go @@ -150,13 +150,17 @@ func (scope *Scope) FieldValueByName(name string) (interface{}, bool) { } // SetColumn to set the column's value -func (scope *Scope) SetColumn(column string, value interface{}) bool { - if scope.Value == nil { - return false - } - for _, field := range scope.Fields() { - if field.Name == column || field.DBName == column { - return setFieldValue(field.Field, value) +func (scope *Scope) SetColumn(column interface{}, value interface{}) bool { + if field, ok := column.(*Field); ok { + return field.Set(value) + } else if str, ok := column.(string); ok { + if scope.Value == nil { + return false + } + for _, field := range scope.Fields() { + if field.Name == str || field.DBName == str { + return field.Set(value) + } } } return false @@ -267,11 +271,9 @@ func (scope *Scope) CombinedConditionSql() string { } func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - if scope.Value != nil { - if scope.IndirectValue().Kind() == reflect.Struct { - if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok { - return scope.fieldFromStruct(f, true)[0], true - } + for _, field := range scope.Fields() { + if field.Name == name { + return field, true } } return nil, false @@ -285,7 +287,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio value := scope.IndirectValue().FieldByName(fieldStruct.Name) indirectValue := reflect.Indirect(value) field.Field = value - field.Value = value.Interface() field.IsBlank = isBlank(value) // Search for primary key tag identifier @@ -416,9 +417,9 @@ func (scope *Scope) Fields(noRelations ...bool) map[string]*Field { } } - // if withRelation { - // scope.fields = fields - // } + if withRelation { + scope.fields = fields + } return fields } diff --git a/scope_private.go b/scope_private.go index 3cfb47ab..b6949db3 100644 --- a/scope_private.go +++ b/scope_private.go @@ -42,7 +42,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri var sqls []string for _, field := range scope.New(value).Fields() { if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -103,7 +103,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string var sqls []string for _, field := range scope.New(value).Fields() { if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -264,17 +264,17 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore } for key, value := range values { - if field := data.FieldByName(SnakeToUpperCamel(key)); field.IsValid() { + if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() { func() { defer func() { if err := recover(); err != nil { hasUpdate = true - setFieldValue(field, value) + field.Set(value) } }() - if field.Interface() != value { - switch field.Kind() { + if field.Field.Interface() != value { + switch field.Field.Kind() { case reflect.Int, reflect.Int32, reflect.Int64: if s, ok := value.(string); ok { i, err := strconv.Atoi(s) @@ -283,13 +283,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore } } - if field.Int() != reflect.ValueOf(value).Int() { + if field.Field.Int() != reflect.ValueOf(value).Int() { hasUpdate = true - setFieldValue(field, value) + field.Set(value) } default: hasUpdate = true - setFieldValue(field, value) + field.Set(value) } } }() @@ -324,8 +324,8 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) { additionalType = additionalType + "DEFAULT " + value } - value := field.Value - reflectValue := reflect.ValueOf(value) + value := field.Field.Interface() + reflectValue := field.Field switch reflectValue.Kind() { case reflect.Slice: diff --git a/utils.go b/utils.go index b75aa1b6..f4fb6625 100644 --- a/utils.go +++ b/utils.go @@ -68,7 +68,7 @@ func ToSnake(u string) string { } s := strings.ToLower(buf.String()) - go smap.Set(u, s) + smap.Set(u, s) return s } @@ -86,7 +86,7 @@ func SnakeToUpperCamel(s string) string { } u := buf.String() - go umap.Set(s, u) + umap.Set(s, u) return u } diff --git a/utils_private.go b/utils_private.go index 75c38fcd..cf941f61 100644 --- a/utils_private.go +++ b/utils_private.go @@ -1,7 +1,6 @@ package gorm import ( - "database/sql" "fmt" "os" "reflect" @@ -11,7 +10,7 @@ import ( ) func fileWithLineNum() string { - for i := 1; i < 15; i++ { + for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) @@ -20,21 +19,6 @@ func fileWithLineNum() string { return "" } -func setFieldValue(field reflect.Value, value interface{}) (result bool) { - result = false - if field.IsValid() && field.CanAddr() { - result = true - if scanner, ok := field.Addr().Interface().(sql.Scanner); ok { - scanner.Scan(value) - } else if reflect.TypeOf(value).ConvertibleTo(field.Type()) { - field.Set(reflect.ValueOf(value).Convert(field.Type())) - } else { - result = false - } - } - return -} - func isBlank(value reflect.Value) bool { return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) } @@ -82,7 +66,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { scope := Scope{Value: values} for _, field := range scope.Fields() { if !field.IsBlank { - attrs[field.DBName] = field.Value + attrs[field.DBName] = field.Field.Interface() } } }