diff --git a/callback_shared.go b/callback_shared.go index 7fab779e..5e501c4c 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -28,7 +28,7 @@ func SaveBeforeAssociations(scope *Scope) { for _, f := range newDB.NewScope(field.Field.Addr().Interface()).Fields() { value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface())) } - scope.SetColumn(field.Name, value.Interface()) + scope.SetColumn(field, value.Interface()) } scope.Err(newDB.Save(value.Addr().Interface()).Error) diff --git a/field.go b/field.go index b69dfceb..66be9c31 100644 --- a/field.go +++ b/field.go @@ -40,14 +40,17 @@ func (field *Field) Set(value interface{}) (err error) { // Fields get value's fields func (scope *Scope) Fields() map[string]*Field { - fields := map[string]*Field{} - structFields := scope.GetStructFields() + if scope.fields == nil { + fields := map[string]*Field{} + structFields := scope.GetStructFields() - for _, structField := range structFields { - fields[structField.DBName] = scope.getField(structField) + for _, structField := range structFields { + fields[structField.DBName] = scope.getField(structField) + } + + scope.fields = fields } - - return fields + return scope.fields } func (scope *Scope) getField(structField *StructField) *Field { @@ -55,9 +58,10 @@ func (scope *Scope) getField(structField *StructField) *Field { indirectValue := scope.IndirectValue() if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct { for _, name := range structField.Names { - indirectValue = reflect.Indirect(indirectValue.FieldByName(name)) + indirectValue = reflect.Indirect(indirectValue).FieldByName(name) } field.Field = indirectValue } + field.IsBlank = isBlank(indirectValue) return &field } diff --git a/model_struct.go b/model_struct.go index 01e8f95d..ed8e579f 100644 --- a/model_struct.go +++ b/model_struct.go @@ -96,20 +96,28 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) + if !reflectValue.IsValid() { + return &modelStruct + } + if reflectValue.Kind() == reflect.Slice { reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) } - scopeTyp := reflectValue.Type() + scopeType := reflectValue.Type() + + if scopeType.Kind() != reflect.Struct { + return &modelStruct + } // Set tablename - if fm := reflect.New(scopeTyp).MethodByName("TableName"); fm.IsValid() { + if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { if results := fm.Call([]reflect.Value{}); len(results) > 0 { if name, ok := results[0].Interface().(string); ok { modelStruct.TableName = name } } } else { - modelStruct.TableName = ToSnake(scopeTyp.Name()) + modelStruct.TableName = ToSnake(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { for index, reg := range pluralMapKeys { if reg.MatchString(modelStruct.TableName) { @@ -120,8 +128,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Set fields - for i := 0; i < scopeTyp.NumField(); i++ { - fieldStruct := scopeTyp.Field(i) + for i := 0; i < scopeType.NumField(); i++ { + fieldStruct := scopeType.Field(i) if !ast.IsExported(fieldStruct.Name) { continue } @@ -156,7 +164,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.IsScanner, field.IsNormal = true, true } - if _, isTime := reflect.New(indirectType).Interface().(time.Time); isTime { + if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { field.IsTime, field.IsNormal = true, true } @@ -181,7 +189,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { kind := "has_many" if foreignKey == "" { - foreignKey = indirectType.Name() + "Id" + foreignKey = scopeType.Name() + "Id" } if associationForeignKey == "" { @@ -199,6 +207,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ForeignType: foreignType, ForeignFieldName: foreignKey, AssociationForeignFieldName: associationForeignKey, + ForeignDBName: ToSnake(foreignKey), + AssociationForeignDBName: ToSnake(associationForeignKey), Kind: kind, } } else { @@ -215,22 +225,27 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var belongsToForeignKey, hasOneForeignKey, kind string if foreignKey == "" { - belongsToForeignKey = indirectType.Name() + "Id" - hasOneForeignKey = scopeTyp.Name() + "Id" + belongsToForeignKey = field.Name + "Id" + hasOneForeignKey = scopeType.Name() + "Id" } else { belongsToForeignKey = foreignKey hasOneForeignKey = foreignKey } - if _, ok := scopeTyp.FieldByName(belongsToForeignKey); ok { - foreignKey = belongsToForeignKey + if _, ok := scopeType.FieldByName(belongsToForeignKey); ok { kind = "belongs_to" + foreignKey = belongsToForeignKey } else { foreignKey = hasOneForeignKey kind = "has_one" } - field.Relationship = &Relationship{ForeignFieldName: foreignKey, ForeignType: foreignType, Kind: kind} + field.Relationship = &Relationship{ + ForeignFieldName: foreignKey, + ForeignDBName: ToSnake(foreignKey), + ForeignType: foreignType, + Kind: kind, + } } default: @@ -248,7 +263,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.PrimaryKeyField = field } - scope.generateSqlTag(field) + if scope.db != nil { + scope.generateSqlTag(field) + } } } diff --git a/scope.go b/scope.go index 7e91eaf4..1185be9d 100644 --- a/scope.go +++ b/scope.go @@ -124,16 +124,14 @@ func (scope *Scope) PrimaryKeyValue() interface{} { // HasColumn to check if has column func (scope *Scope) HasColumn(column string) bool { - clone := scope - if scope.IndirectValue().Kind() == reflect.Slice { - value := reflect.New(scope.IndirectValue().Type().Elem()).Interface() - clone = scope.New(value) + for _, field := range scope.GetStructFields() { + if !field.IsIgnored { + if field.Name == column || field.DBName == column { + return true + } + } } - - dbName := ToSnake(column) - - field, hasColumn := clone.Fields()[dbName] - return hasColumn && !field.IsIgnored + return false } // FieldValueByName to get column's value and existence diff --git a/scope_private.go b/scope_private.go index 56e1b64f..217df03f 100644 --- a/scope_private.go +++ b/scope_private.go @@ -415,76 +415,47 @@ func (scope *Scope) typeName() string { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) - fromScopeType := scope.typeName() - toScopeType := toScope.typeName() - scopeType := "" - + fromFields := scope.Fields() + toFields := toScope.Fields() for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - if keys := strings.Split(foreignKey, "."); len(keys) > 1 { - scopeType = keys[0] - foreignKey = keys[1] - } + fromField := fromFields[ToSnake(foreignKey)] + toField := toFields[ToSnake(foreignKey)] - var relationship *Relationship - var field *Field - var scopeHasField bool - if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField { - relationship = field.Relationship - } - - if scopeType == "" || scopeType == fromScopeType { - if scopeHasField { - if relationship != nil && relationship.ForeignFieldName != "" { - foreignKey = relationship.ForeignFieldName - } - - if relationship != nil && relationship.Kind == "many_to_many" { - if relationship.ForeignType != "" { - scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations")) - } + if fromField != nil { + if relationship := fromField.Relationship; relationship != nil { + if relationship.Kind == "many_to_many" { joinSql := fmt.Sprintf( "INNER JOIN %v ON %v.%v = %v.%v", scope.Quote(relationship.JoinTable), scope.Quote(relationship.JoinTable), - scope.Quote(ToSnake(relationship.AssociationForeignFieldName)), + scope.Quote(relationship.AssociationForeignDBName), toScope.QuotedTableName(), scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignFieldName))) - toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) - return scope - } - - // has many or has one - if toScope.HasColumn(foreignKey) { - toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue()) - if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) { - toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName()) - } - toScope.callCallbacks(scope.db.parent.callback.queries) - return scope - } - - // belongs to - if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil { + whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName)) + scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error) + } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) { - scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations")) - return scope + scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) + } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) + query := toScope.db.Where(sql, scope.PrimaryKeyValue()) + if relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName()) } - toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries) - return scope + scope.Err(query.Find(value).Error) } + } else { + sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) + scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) } - } - - if scopeType == "" || scopeType == toScopeType { - // has many or has one in foreign scope - if toScope.HasColumn(foreignKey) { - sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))) - return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) - } + return scope + } else if toField != nil { + sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) + scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) + return scope } } + scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) return scope } @@ -553,12 +524,12 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on var table = scope.TableName() var keyName = fmt.Sprintf("%s_%s_foreign", table, field) var query = ` - ALTER TABLE %s - ADD CONSTRAINT %s - FOREIGN KEY (%s) - REFERENCES %s - ON DELETE %s - ON UPDATE %s; + ALTER TABLE %s + ADD CONSTRAINT %s + FOREIGN KEY (%s) + REFERENCES %s + ON DELETE %s + ON UPDATE %s; ` scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec() }