diff --git a/association.go b/association.go index 539218bf..2408476b 100644 --- a/association.go +++ b/association.go @@ -93,9 +93,9 @@ func (association *Association) Delete(values ...interface{}) *Association { } else { relationship := association.Field.Relationship // many to many - if relationship.kind == "many_to_many" { - whereSql := fmt.Sprintf("%v.%v IN (?)", relationship.joinTable, association.Scope.Quote(ToSnake(relationship.associationForeignKey))) - association.Scope.db.Model("").Table(relationship.joinTable).Where(whereSql, primaryKeys).Delete("") + if relationship.Kind == "many_to_many" { + whereSql := fmt.Sprintf("%v.%v IN (?)", relationship.JoinTable, association.Scope.Quote(ToSnake(relationship.AssociationForeignKey))) + association.Scope.db.Model("").Table(relationship.JoinTable).Where(whereSql, primaryKeys).Delete("") } else { association.err(errors.New("delete only support many to many")) } @@ -106,7 +106,7 @@ func (association *Association) Delete(values ...interface{}) *Association { func (association *Association) Replace(values ...interface{}) *Association { relationship := association.Field.Relationship scope := association.Scope - if relationship.kind == "many_to_many" { + if relationship.Kind == "many_to_many" { field := scope.IndirectValue().FieldByName(association.Column) oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) @@ -130,8 +130,8 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } - whereSql := fmt.Sprintf("%v.%v NOT IN (?)", relationship.joinTable, scope.Quote(ToSnake(relationship.associationForeignKey))) - scope.db.Model("").Table(relationship.joinTable).Where(whereSql, addedPrimaryKeys).Delete("") + whereSql := fmt.Sprintf("%v.%v NOT IN (?)", relationship.JoinTable, scope.Quote(ToSnake(relationship.AssociationForeignKey))) + scope.db.Model("").Table(relationship.JoinTable).Where(whereSql, addedPrimaryKeys).Delete("") } else { association.err(errors.New("replace only support many to many")) } @@ -141,9 +141,9 @@ func (association *Association) Replace(values ...interface{}) *Association { func (association *Association) Clear() *Association { relationship := association.Field.Relationship scope := association.Scope - if relationship.kind == "many_to_many" { - whereSql := fmt.Sprintf("%v.%v = ?", relationship.joinTable, scope.Quote(ToSnake(relationship.foreignKey))) - scope.db.Model("").Table(relationship.joinTable).Where(whereSql, association.PrimaryKey).Delete("") + if relationship.Kind == "many_to_many" { + whereSql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(ToSnake(relationship.ForeignKey))) + scope.db.Model("").Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey).Delete("") } else { association.err(errors.New("clear only support many to many")) } @@ -158,25 +158,25 @@ func (association *Association) Count() int { fieldValue := field.Interface() newScope := scope.New(fieldValue) - if relationship.kind == "many_to_many" { + if relationship.Kind == "many_to_many" { whereSql := fmt.Sprintf("%v.%v IN (SELECT %v.%v FROM %v WHERE %v.%v = ?)", newScope.QuotedTableName(), scope.Quote(newScope.PrimaryKey()), - relationship.joinTable, - scope.Quote(ToSnake(relationship.associationForeignKey)), - relationship.joinTable, - relationship.joinTable, - scope.Quote(ToSnake(relationship.foreignKey))) + relationship.JoinTable, + scope.Quote(ToSnake(relationship.AssociationForeignKey)), + relationship.JoinTable, + relationship.JoinTable, + scope.Quote(ToSnake(relationship.ForeignKey))) scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) - } else if relationship.kind == "has_many" { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.foreignKey))) + } else if relationship.Kind == "has_many" { + whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignKey))) scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) - } else if relationship.kind == "has_one" { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.foreignKey) + } else if relationship.Kind == "has_one" { + whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.ForeignKey) scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey).Count(&count) - } else if relationship.kind == "belongs_to" { - if v, ok := scope.FieldByName(association.Column); ok { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.foreignKey) + } else if relationship.Kind == "belongs_to" { + if v, ok := scope.FieldValueByName(association.Column); ok { + whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), relationship.ForeignKey) scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count) } } diff --git a/callback_shared.go b/callback_shared.go index c0de840a..30ccab87 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -18,7 +18,7 @@ func SaveBeforeAssociations(scope *Scope) { for _, field := range scope.Fields() { if !field.IsBlank && !field.IsIgnored { relationship := field.Relationship - if relationship != nil && relationship.kind == "belongs_to" { + if relationship != nil && relationship.Kind == "belongs_to" { value := reflect.ValueOf(field.Value) newDB := scope.NewDB() @@ -34,8 +34,8 @@ func SaveBeforeAssociations(scope *Scope) { scope.SetColumn(field.Name, value.Interface()) } - if relationship.foreignKey != "" { - scope.SetColumn(relationship.foreignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) + if relationship.ForeignKey != "" { + scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) } } } @@ -47,7 +47,7 @@ func SaveAfterAssociations(scope *Scope) { if !field.IsBlank && !field.IsIgnored { relationship := field.Relationship if relationship != nil && - (relationship.kind == "has_one" || relationship.kind == "has_many" || relationship.kind == "many_to_many") { + (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := reflect.ValueOf(field.Value) switch value.Kind() { @@ -56,18 +56,18 @@ func SaveAfterAssociations(scope *Scope) { newDB := scope.NewDB() elem := value.Index(i).Addr().Interface() - if relationship.joinTable == "" && relationship.foreignKey != "" { - newDB.NewScope(elem).SetColumn(relationship.foreignKey, scope.PrimaryKeyValue()) + if relationship.JoinTable == "" && relationship.ForeignKey != "" { + newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) } scope.Err(newDB.Save(elem).Error) - if relationship.joinTable != "" { + if relationship.JoinTable != "" { newScope := scope.New(elem) - joinTable := relationship.joinTable - foreignKey := ToSnake(relationship.foreignKey) + joinTable := relationship.JoinTable + foreignKey := ToSnake(relationship.ForeignKey) foreignValue := fmt.Sprintf("%v", scope.PrimaryKeyValue()) - associationForeignKey := ToSnake(relationship.associationForeignKey) + associationForeignKey := ToSnake(relationship.AssociationForeignKey) associationForeignValue := fmt.Sprintf("%v", newScope.PrimaryKeyValue()) newScope.Raw(fmt.Sprintf( @@ -88,8 +88,8 @@ func SaveAfterAssociations(scope *Scope) { default: newDB := scope.NewDB() if value.CanAddr() { - if relationship.foreignKey != "" { - newDB.NewScope(field.Value).SetColumn(relationship.foreignKey, scope.PrimaryKeyValue()) + if relationship.ForeignKey != "" { + newDB.NewScope(field.Value).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) } scope.Err(newDB.Save(field.Value).Error) } else { @@ -100,8 +100,8 @@ func SaveAfterAssociations(scope *Scope) { } elem := destValue.Addr().Interface() - if relationship.foreignKey != "" { - newDB.NewScope(elem).SetColumn(relationship.foreignKey, scope.PrimaryKeyValue()) + if relationship.ForeignKey != "" { + newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) } scope.Err(newDB.Save(elem).Error) scope.SetColumn(field.Name, destValue.Interface()) diff --git a/field.go b/field.go index 4134e58f..f49bc7fe 100644 --- a/field.go +++ b/field.go @@ -7,22 +7,22 @@ import ( ) type relationship struct { - joinTable string - foreignKey string - associationForeignKey string - kind string + JoinTable string + ForeignKey string + AssociationForeignKey string + Kind string } type Field struct { Name string DBName string Value interface{} - IsBlank bool - IsIgnored bool Tag reflect.StructTag SqlTag string - isPrimaryKey bool Relationship *relationship + IsBlank bool + IsIgnored bool + IsPrimaryKey bool } func (f *Field) IsScanner() bool { diff --git a/main.go b/main.go index 5f82a361..4df38ff0 100644 --- a/main.go +++ b/main.go @@ -366,7 +366,7 @@ func (s *DB) Association(column string) *Association { scopeType := scope.IndirectValue().Type() if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok { field = scope.fieldFromStruct(f) - if field.Relationship == nil || field.Relationship.foreignKey == "" { + if field.Relationship == nil || field.Relationship.ForeignKey == "" { scope.Err(errors.New(fmt.Sprintf("invalid association %v for %v", column, scopeType))) } } else { diff --git a/scope.go b/scope.go index f1a07b98..ae203d1d 100644 --- a/scope.go +++ b/scope.go @@ -112,13 +112,13 @@ func (scope *Scope) PrimaryKeyValue() interface{} { // HasColumn to check if has column func (scope *Scope) HasColumn(name string) bool { - _, result := scope.FieldByName(name) + _, result := scope.FieldValueByName(name) return result } -// FieldByName to get column's value and existence -func (scope *Scope) FieldByName(name string) (interface{}, bool) { - return FieldByName(name, scope.Value) +// FieldValueByName to get column's value and existence +func (scope *Scope) FieldValueByName(name string) (interface{}, bool) { + return FieldValueByName(name, scope.Value) } // SetColumn to set the column's value @@ -234,6 +234,18 @@ func (scope *Scope) CombinedConditionSql() string { scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() } +func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { + var f reflect.StructField + if scope.Value != nil { + if scope.IndirectValue().Kind() == reflect.Struct { + if f, ok = scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok { + field = scope.fieldFromStruct(f) + } + } + } + return +} + func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { var field Field field.Name = fieldStruct.Name @@ -246,68 +258,62 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { // Search for primary key tag identifier settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) - if _, ok := settings["PRIMARY_KEY"]; scope.PrimaryKey() == field.DBName || ok { - field.isPrimaryKey = true + if scope.PrimaryKey() == field.DBName { + field.IsPrimaryKey = true } - if field.isPrimaryKey { - scope.primaryKey = field.DBName - } + field.Tag = fieldStruct.Tag + field.SqlTag = scope.sqlTagForField(&field) - if scope.db != nil { - field.Tag = fieldStruct.Tag - field.SqlTag = scope.sqlTagForField(&field) + // parse association + typ := indirectValue.Type() + foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) + associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) + many2many := settings["MANY2MANY"] + scopeTyp := scope.IndirectValue().Type() - // parse association - typ := indirectValue.Type() - foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) - associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) - many2many := settings["MANY2MANY"] - scopeTyp := scope.IndirectValue().Type() + switch indirectValue.Kind() { + case reflect.Slice: + typ = typ.Elem() - switch indirectValue.Kind() { - case reflect.Slice: - typ = typ.Elem() + if typ.Kind() == reflect.Struct { + if foreignKey == "" { + foreignKey = scopeTyp.Name() + "Id" + } + if associationForeignKey == "" { + associationForeignKey = typ.Name() + "Id" + } - if typ.Kind() == reflect.Struct { + // if not many to many, foreign key could be null + if many2many == "" { + if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + foreignKey = "" + } + } + + field.Relationship = &relationship{ + JoinTable: many2many, + ForeignKey: foreignKey, + AssociationForeignKey: associationForeignKey, + Kind: "has_many", + } + + if many2many != "" { + field.Relationship.Kind = "many_to_many" + } + } + case reflect.Struct: + if !field.IsTime() && !field.IsScanner() { + if foreignKey == "" && scope.HasColumn(field.Name+"Id") { + field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"} + } else if scope.HasColumn(foreignKey) { + field.Relationship = &relationship{ForeignKey: foreignKey, Kind: "belongs_to"} + } else { if foreignKey == "" { foreignKey = scopeTyp.Name() + "Id" } - if associationForeignKey == "" { - associationForeignKey = typ.Name() + "Id" - } - - // if not many to many, foreign key could be null - if many2many == "" { - if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - foreignKey = "" - } - } - - field.Relationship = &relationship{ - joinTable: many2many, - foreignKey: foreignKey, - associationForeignKey: associationForeignKey, - kind: "has_many", - } - - if many2many != "" { - field.Relationship.kind = "many_to_many" - } - } - case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { - if foreignKey == "" && scope.HasColumn(field.Name+"Id") { - field.Relationship = &relationship{foreignKey: field.Name + "Id", kind: "belongs_to"} - } else if scope.HasColumn(foreignKey) { - field.Relationship = &relationship{foreignKey: foreignKey, kind: "belongs_to"} - } else { - if foreignKey == "" { - foreignKey = scopeTyp.Name() + "Id" - } - if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - field.Relationship = &relationship{foreignKey: foreignKey, kind: "has_one"} - } + if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + field.Relationship = &relationship{ForeignKey: foreignKey, Kind: "has_one"} } } } diff --git a/scope_private.go b/scope_private.go index b35ca3f8..cfe30458 100644 --- a/scope_private.go +++ b/scope_private.go @@ -300,6 +300,9 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore } func (scope *Scope) sqlTagForField(field *Field) (typ string) { + if scope.db == nil { + return "" + } var size = 255 fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier) @@ -343,7 +346,7 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) { } if len(typ) == 0 { - if field.isPrimaryKey { + if field.IsPrimaryKey { typ = scope.Dialect().PrimaryKeyTag(reflectValue, size) } else { typ = scope.Dialect().SqlTag(reflectValue, size) @@ -416,30 +419,27 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - scopeType := scope.IndirectValue().Type() - if f, ok := scopeType.FieldByName(SnakeToUpperCamel(foreignKey)); ok { - field := scope.fieldFromStruct(f) + if field, ok := scope.FieldByName(foreignKey); ok { relationship := field.Relationship - if relationship != nil && relationship.foreignKey != "" { - foreignKey = relationship.foreignKey + if relationship != nil && relationship.ForeignKey != "" { + foreignKey = relationship.ForeignKey - // many to many relations - if relationship.joinTable != "" { + if relationship.Kind == "many_to_many" { joinSql := fmt.Sprintf( "INNER JOIN %v ON %v.%v = %v.%v", - scope.Quote(relationship.joinTable), - scope.Quote(relationship.joinTable), - scope.Quote(ToSnake(relationship.associationForeignKey)), + scope.Quote(relationship.JoinTable), + scope.Quote(relationship.JoinTable), + scope.Quote(ToSnake(relationship.AssociationForeignKey)), toScope.QuotedTableName(), scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.joinTable), scope.Quote(ToSnake(relationship.foreignKey))) + whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey))) toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) return scope } } // has one - if foreignValue, ok := scope.FieldByName(foreignKey); ok { + if foreignValue, ok := scope.FieldValueByName(foreignKey); ok { toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) return scope } @@ -456,15 +456,15 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } func (scope *Scope) createJoinTable(field *Field) { - if field.Relationship != nil && field.Relationship.joinTable != "" { - if !scope.Dialect().HasTable(scope, field.Relationship.joinTable) { + if field.Relationship != nil && field.Relationship.JoinTable != "" { + if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) { newScope := scope.db.NewScope("") primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255) newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", - field.Relationship.joinTable, + field.Relationship.JoinTable, strings.Join([]string{ - scope.Quote(ToSnake(field.Relationship.foreignKey)) + " " + primaryKeySqlType, - scope.Quote(ToSnake(field.Relationship.associationForeignKey)) + " " + primaryKeySqlType}, ",")), + scope.Quote(ToSnake(field.Relationship.ForeignKey)) + " " + primaryKeySqlType, + scope.Quote(ToSnake(field.Relationship.AssociationForeignKey)) + " " + primaryKeySqlType}, ",")), ).Exec() scope.Err(newScope.db.Error) } diff --git a/utils.go b/utils.go index ed10e854..50512143 100644 --- a/utils.go +++ b/utils.go @@ -25,7 +25,7 @@ func (s *safeMap) Get(key string) string { return s.m[key] } -func FieldByName(name string, value interface{}, withAddr ...bool) (interface{}, bool) { +func FieldValueByName(name string, value interface{}, withAddr ...bool) (interface{}, bool) { data := reflect.Indirect(reflect.ValueOf(value)) name = SnakeToUpperCamel(name)