diff --git a/association.go b/association.go index 2df571f5..cd8fd912 100644 --- a/association.go +++ b/association.go @@ -63,9 +63,11 @@ func (association *Association) Replace(values ...interface{}) *Association { var associationForeignFieldNames []string if relationship.Kind == "many_to_many" { // if many to many relations, get association fields name from association foreign keys - associationFields := scope.New(reflect.New(field.Type()).Interface()).Fields() + associationScope := scope.New(reflect.New(field.Type()).Interface()) for _, dbName := range relationship.AssociationForeignFieldNames { - associationForeignFieldNames = append(associationForeignFieldNames, associationFields[dbName].Name) + if field, ok := associationScope.FieldByName(dbName); ok { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } } } else { // If other relations, use primary keys @@ -84,15 +86,12 @@ func (association *Association) Replace(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { // if many to many relations, delete related relations from join table - - // get source fields name from source foreign keys - var ( - sourceFields = scope.Fields() - sourceForeignFieldNames []string - ) + var sourceForeignFieldNames []string for _, dbName := range relationship.ForeignFieldNames { - sourceForeignFieldNames = append(sourceForeignFieldNames, sourceFields[dbName].Name) + if field, ok := scope.FieldByName(dbName); ok { + sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) + } } if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { @@ -147,10 +146,12 @@ func (association *Association) Delete(values ...interface{}) *Association { } // get association's foreign fields name - var associationFields = scope.New(reflect.New(field.Type()).Interface()).Fields() + var associationScope = scope.New(reflect.New(field.Type()).Interface()) var associationForeignFieldNames []string for _, associationDBName := range relationship.AssociationForeignFieldNames { - associationForeignFieldNames = append(associationForeignFieldNames, associationFields[associationDBName].Name) + if field, ok := associationScope.FieldByName(associationDBName); ok { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } } // association value's foreign keys diff --git a/callback_create.go b/callback_create.go index 314f505a..2a0b9b2a 100644 --- a/callback_create.go +++ b/callback_create.go @@ -45,10 +45,9 @@ func createCallback(scope *Scope) { var ( columns, placeholders []string blankColumnsWithDefaultValue []string - fields = scope.Fields() ) - for _, field := range fields { + for _, field := range scope.Fields() { if scope.changeableField(field) { if field.IsNormal { if !field.IsPrimaryKey || !field.IsBlank { @@ -62,7 +61,7 @@ func createCallback(scope *Scope) { } } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) { + if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { columns = append(columns, scope.Quote(foreignField.DBName)) placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) } diff --git a/callback_query.go b/callback_query.go index aa643557..acdf149b 100644 --- a/callback_query.go +++ b/callback_query.go @@ -68,7 +68,7 @@ func queryCallback(scope *Scope) { elem = reflect.New(resultType).Elem() } - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) + scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap()) if isSlice { if isPtr { diff --git a/callback_query_preload.go b/callback_query_preload.go index e57caad0..1c9bbc84 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -255,7 +255,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface for rows.Next() { var ( elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() + fields = scope.New(elem.Addr().Interface()).fieldsMap() ) // register foreign keys in join tables @@ -284,7 +284,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface indirectScopeValue = scope.IndirectValue() fieldsSourceMap = map[string]reflect.Value{} foreignFieldNames = []string{} - fields = scope.Fields() + fields = scope.fieldsMap() ) for _, dbName := range relation.ForeignFieldNames { diff --git a/callback_update.go b/callback_update.go index 287b927f..192d8a9e 100644 --- a/callback_update.go +++ b/callback_update.go @@ -60,14 +60,13 @@ func updateCallback(scope *Scope) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { - fields := scope.Fields() - for _, field := range fields { + for _, field := range scope.Fields() { if scope.changeableField(field) { if !field.IsPrimaryKey && field.IsNormal { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { - if foreignField := fields[foreignKey]; !scope.changeableField(foreignField) { + if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) } diff --git a/field.go b/field.go index 2f0daf77..ff252e1a 100644 --- a/field.go +++ b/field.go @@ -56,29 +56,34 @@ func (field *Field) Set(value interface{}) (err error) { } // Fields get value's fields -func (scope *Scope) Fields() map[string]*Field { - if scope.fields == nil { - var ( - fields = map[string]*Field{} - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) +func (scope *Scope) Fields() []*Field { + var ( + fields []*Field + indirectScopeValue = scope.IndirectValue() + isStruct = indirectScopeValue.Kind() == reflect.Struct + ) - for _, structField := range scope.GetModelStruct().StructFields { - if field, ok := fields[structField.DBName]; !ok || field.IsIgnored { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields[structField.DBName] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)} - } else { - fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} - } + for _, structField := range scope.GetModelStruct().StructFields { + if isStruct { + fieldValue := indirectScopeValue + for _, name := range structField.Names { + fieldValue = reflect.Indirect(fieldValue).FieldByName(name) } + fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) + } else { + fields = append(fields, &Field{StructField: structField, IsBlank: true}) } - - scope.fields = fields } - return scope.fields + + return fields +} + +func (scope *Scope) fieldsMap() map[string]*Field { + var results = map[string]*Field{} + for _, field := range scope.Fields() { + if field.IsNormal { + results[field.DBName] = field + } + } + return results } diff --git a/field_test.go b/field_test.go index 2172b059..30e9a778 100644 --- a/field_test.go +++ b/field_test.go @@ -32,12 +32,16 @@ type CalculateFieldCategory struct { func TestCalculateField(t *testing.T) { var field CalculateField - fields := DB.NewScope(&field).Fields() - if fields["children"].Relationship == nil || fields["category"].Relationship == nil { + var scope = DB.NewScope(&field) + if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { t.Errorf("Should calculate fields correctly for the first time") } - if field, ok := fields["embedded_name"]; !ok { + if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { + t.Errorf("Should calculate fields correctly for the first time") + } + + if field, ok := scope.FieldByName("embedded_name"); !ok { t.Errorf("should find embedded field") } else if _, ok := field.TagSettings["NOT NULL"]; !ok { t.Errorf("should find embedded field's tag settings") diff --git a/join_table_handler.go b/join_table_handler.go index 9e6c027a..6251cd22 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -74,11 +74,15 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin if s.Source.ModelType == modelType { for _, foreignKey := range s.Source.ForeignKeys { - values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + values[foreignKey.DBName] = field.Field.Interface() + } } } else if s.Destination.ModelType == modelType { for _, foreignKey := range s.Destination.ForeignKeys { - values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + values[foreignKey.DBName] = field.Field.Interface() + } } } } @@ -151,7 +155,9 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so for _, foreignKey := range s.Source.ForeignKeys { foreignDBNames = append(foreignDBNames, foreignKey.DBName) - foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + foreignFieldNames = append(foreignFieldNames, field.Name) + } } foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) diff --git a/main.go b/main.go index 46f35d01..09b6df74 100644 --- a/main.go +++ b/main.go @@ -232,7 +232,7 @@ func (s *DB) ScanRows(rows *sql.Rows, value interface{}) error { ) if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) + scope.scan(rows, columns, scope.fieldsMap()) } return clone.Error diff --git a/scope.go b/scope.go index c84a8179..9f4e821d 100644 --- a/scope.go +++ b/scope.go @@ -100,10 +100,11 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } -func (scope *Scope) PrimaryFields() []*Field { - var fields = []*Field{} - for _, field := range scope.GetModelStruct().PrimaryFields { - fields = append(fields, scope.Fields()[field.DBName]) +func (scope *Scope) PrimaryFields() (fields []*Field) { + for _, field := range scope.Fields() { + if field.IsPrimaryKey { + fields = append(fields, field) + } } return fields } @@ -111,11 +112,11 @@ func (scope *Scope) PrimaryFields() []*Field { func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if len(primaryFields) > 1 { - if field, ok := scope.Fields()["id"]; ok { + if field, ok := scope.FieldByName("id"); ok { return field } } - return scope.Fields()[primaryFields[0].DBName] + return scope.PrimaryFields()[0] } return nil } @@ -164,20 +165,23 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { updateAttrs[field.DBName] = value return field.Set(value) } else if name, ok := column.(string); ok { - if field, ok := scope.Fields()[name]; ok { - updateAttrs[field.DBName] = value - return field.Set(value) + var ( + dbName = ToDBName(name) + mostMatchedField *Field + ) + for _, field := range scope.Fields() { + if field.DBName == value { + updateAttrs[field.DBName] = value + return field.Set(value) + } + if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { + mostMatchedField = field + } } - dbName := ToDBName(name) - if field, ok := scope.Fields()[dbName]; ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } - - if field, ok := scope.FieldByName(name); ok { - updateAttrs[field.DBName] = value - return field.Set(value) + if mostMatchedField != nil { + updateAttrs[mostMatchedField.DBName] = value + return mostMatchedField.Set(value) } } return errors.New("could not convert column to field") @@ -286,12 +290,20 @@ func (scope *Scope) CombinedConditionSql() string { // FieldByName find gorm.Field with name and db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { + var ( + dbName = ToDBName(name) + mostMatchedField *Field + ) + for _, field := range scope.Fields() { if field.Name == name || field.DBName == name { return field, true } + if field.DBName == dbName { + mostMatchedField = field + } } - return nil, false + return mostMatchedField, mostMatchedField != nil } // Raw set sql @@ -390,12 +402,12 @@ func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields map[string]*Field) { +func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) { var values = make([]interface{}, len(columns)) var ignored interface{} for index, column := range columns { - if field, ok := fields[column]; ok { + if field, ok := fieldsMap[column]; ok { if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() } else { @@ -411,7 +423,7 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields map[string]*Fi scope.Err(rows.Scan(values...)) for index, column := range columns { - if field, ok := fields[column]; ok { + if field, ok := fieldsMap[column]; ok { if field.Field.Kind() != reflect.Ptr { if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { field.Field.Set(v) diff --git a/scope_private.go b/scope_private.go index b2f504d1..31db4a0b 100644 --- a/scope_private.go +++ b/scope_private.go @@ -437,21 +437,10 @@ func (scope *Scope) shouldSaveAssociations() bool { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) - fromFields := scope.Fields() - toFields := toScope.Fields() for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - var fromField, toField *Field - if field, ok := scope.FieldByName(foreignKey); ok { - fromField = field - } else { - fromField = fromFields[ToDBName(foreignKey)] - } - if field, ok := toScope.FieldByName(foreignKey); ok { - toField = field - } else { - toField = toFields[ToDBName(foreignKey)] - } + fromField, _ := scope.FieldByName(foreignKey) + toField, _ := toScope.FieldByName(foreignKey) if fromField != nil { if relationship := fromField.Relationship; relationship != nil { @@ -515,7 +504,7 @@ func (scope *Scope) createJoinTable(field *StructField) { var sqlTypes, primaryKeys []string for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.Fields()[fieldName]; ok { + if field, ok := scope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" @@ -525,7 +514,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.Fields()[fieldName]; ok { + if field, ok := toScope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"