forked from mirror/gorm
Fix tests after refactor
This commit is contained in:
parent
803343fbe5
commit
d6439f4147
|
@ -28,7 +28,7 @@ func SaveBeforeAssociations(scope *Scope) {
|
|||
for _, f := range newDB.NewScope(field.Field.Addr().Interface()).Fields() {
|
||||
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface()))
|
||||
}
|
||||
scope.SetColumn(field.Name, value.Interface())
|
||||
scope.SetColumn(field, value.Interface())
|
||||
}
|
||||
scope.Err(newDB.Save(value.Addr().Interface()).Error)
|
||||
|
||||
|
|
8
field.go
8
field.go
|
@ -40,6 +40,7 @@ func (field *Field) Set(value interface{}) (err error) {
|
|||
|
||||
// Fields get value's fields
|
||||
func (scope *Scope) Fields() map[string]*Field {
|
||||
if scope.fields == nil {
|
||||
fields := map[string]*Field{}
|
||||
structFields := scope.GetStructFields()
|
||||
|
||||
|
@ -47,7 +48,9 @@ func (scope *Scope) Fields() map[string]*Field {
|
|||
fields[structField.DBName] = scope.getField(structField)
|
||||
}
|
||||
|
||||
return fields
|
||||
scope.fields = fields
|
||||
}
|
||||
return scope.fields
|
||||
}
|
||||
|
||||
func (scope *Scope) getField(structField *StructField) *Field {
|
||||
|
@ -55,9 +58,10 @@ func (scope *Scope) getField(structField *StructField) *Field {
|
|||
indirectValue := scope.IndirectValue()
|
||||
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
|
||||
for _, name := range structField.Names {
|
||||
indirectValue = reflect.Indirect(indirectValue.FieldByName(name))
|
||||
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
|
||||
}
|
||||
field.Field = indirectValue
|
||||
}
|
||||
field.IsBlank = isBlank(indirectValue)
|
||||
return &field
|
||||
}
|
||||
|
|
|
@ -96,20 +96,28 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
var modelStruct ModelStruct
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
if !reflectValue.IsValid() {
|
||||
return &modelStruct
|
||||
}
|
||||
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
|
||||
}
|
||||
scopeTyp := reflectValue.Type()
|
||||
scopeType := reflectValue.Type()
|
||||
|
||||
if scopeType.Kind() != reflect.Struct {
|
||||
return &modelStruct
|
||||
}
|
||||
|
||||
// Set tablename
|
||||
if fm := reflect.New(scopeTyp).MethodByName("TableName"); fm.IsValid() {
|
||||
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
|
||||
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
|
||||
if name, ok := results[0].Interface().(string); ok {
|
||||
modelStruct.TableName = name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
modelStruct.TableName = ToSnake(scopeTyp.Name())
|
||||
modelStruct.TableName = ToSnake(scopeType.Name())
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
for index, reg := range pluralMapKeys {
|
||||
if reg.MatchString(modelStruct.TableName) {
|
||||
|
@ -120,8 +128,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
}
|
||||
|
||||
// Set fields
|
||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||
fieldStruct := scopeTyp.Field(i)
|
||||
for i := 0; i < scopeType.NumField(); i++ {
|
||||
fieldStruct := scopeType.Field(i)
|
||||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
|
@ -156,7 +164,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
field.IsScanner, field.IsNormal = true, true
|
||||
}
|
||||
|
||||
if _, isTime := reflect.New(indirectType).Interface().(time.Time); isTime {
|
||||
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
|
||||
field.IsTime, field.IsNormal = true, true
|
||||
}
|
||||
|
||||
|
@ -181,7 +189,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
kind := "has_many"
|
||||
|
||||
if foreignKey == "" {
|
||||
foreignKey = indirectType.Name() + "Id"
|
||||
foreignKey = scopeType.Name() + "Id"
|
||||
}
|
||||
|
||||
if associationForeignKey == "" {
|
||||
|
@ -199,6 +207,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
ForeignType: foreignType,
|
||||
ForeignFieldName: foreignKey,
|
||||
AssociationForeignFieldName: associationForeignKey,
|
||||
ForeignDBName: ToSnake(foreignKey),
|
||||
AssociationForeignDBName: ToSnake(associationForeignKey),
|
||||
Kind: kind,
|
||||
}
|
||||
} else {
|
||||
|
@ -215,22 +225,27 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
var belongsToForeignKey, hasOneForeignKey, kind string
|
||||
|
||||
if foreignKey == "" {
|
||||
belongsToForeignKey = indirectType.Name() + "Id"
|
||||
hasOneForeignKey = scopeTyp.Name() + "Id"
|
||||
belongsToForeignKey = field.Name + "Id"
|
||||
hasOneForeignKey = scopeType.Name() + "Id"
|
||||
} else {
|
||||
belongsToForeignKey = foreignKey
|
||||
hasOneForeignKey = foreignKey
|
||||
}
|
||||
|
||||
if _, ok := scopeTyp.FieldByName(belongsToForeignKey); ok {
|
||||
foreignKey = belongsToForeignKey
|
||||
if _, ok := scopeType.FieldByName(belongsToForeignKey); ok {
|
||||
kind = "belongs_to"
|
||||
foreignKey = belongsToForeignKey
|
||||
} else {
|
||||
foreignKey = hasOneForeignKey
|
||||
kind = "has_one"
|
||||
}
|
||||
|
||||
field.Relationship = &Relationship{ForeignFieldName: foreignKey, ForeignType: foreignType, Kind: kind}
|
||||
field.Relationship = &Relationship{
|
||||
ForeignFieldName: foreignKey,
|
||||
ForeignDBName: ToSnake(foreignKey),
|
||||
ForeignType: foreignType,
|
||||
Kind: kind,
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
|
@ -248,9 +263,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
modelStruct.PrimaryKeyField = field
|
||||
}
|
||||
|
||||
if scope.db != nil {
|
||||
scope.generateSqlTag(field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &modelStruct
|
||||
}
|
||||
|
|
16
scope.go
16
scope.go
|
@ -124,16 +124,14 @@ func (scope *Scope) PrimaryKeyValue() interface{} {
|
|||
|
||||
// HasColumn to check if has column
|
||||
func (scope *Scope) HasColumn(column string) bool {
|
||||
clone := scope
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
value := reflect.New(scope.IndirectValue().Type().Elem()).Interface()
|
||||
clone = scope.New(value)
|
||||
for _, field := range scope.GetStructFields() {
|
||||
if !field.IsIgnored {
|
||||
if field.Name == column || field.DBName == column {
|
||||
return true
|
||||
}
|
||||
|
||||
dbName := ToSnake(column)
|
||||
|
||||
field, hasColumn := clone.Fields()[dbName]
|
||||
return hasColumn && !field.IsIgnored
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FieldValueByName to get column's value and existence
|
||||
|
|
|
@ -415,76 +415,47 @@ func (scope *Scope) typeName() string {
|
|||
|
||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
toScope := scope.db.NewScope(value)
|
||||
fromScopeType := scope.typeName()
|
||||
toScopeType := toScope.typeName()
|
||||
scopeType := ""
|
||||
|
||||
fromFields := scope.Fields()
|
||||
toFields := toScope.Fields()
|
||||
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
|
||||
if keys := strings.Split(foreignKey, "."); len(keys) > 1 {
|
||||
scopeType = keys[0]
|
||||
foreignKey = keys[1]
|
||||
}
|
||||
fromField := fromFields[ToSnake(foreignKey)]
|
||||
toField := toFields[ToSnake(foreignKey)]
|
||||
|
||||
var relationship *Relationship
|
||||
var field *Field
|
||||
var scopeHasField bool
|
||||
if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField {
|
||||
relationship = field.Relationship
|
||||
}
|
||||
|
||||
if scopeType == "" || scopeType == fromScopeType {
|
||||
if scopeHasField {
|
||||
if relationship != nil && relationship.ForeignFieldName != "" {
|
||||
foreignKey = relationship.ForeignFieldName
|
||||
}
|
||||
|
||||
if relationship != nil && relationship.Kind == "many_to_many" {
|
||||
if relationship.ForeignType != "" {
|
||||
scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
|
||||
}
|
||||
if fromField != nil {
|
||||
if relationship := fromField.Relationship; relationship != nil {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
joinSql := fmt.Sprintf(
|
||||
"INNER JOIN %v ON %v.%v = %v.%v",
|
||||
scope.Quote(relationship.JoinTable),
|
||||
scope.Quote(relationship.JoinTable),
|
||||
scope.Quote(ToSnake(relationship.AssociationForeignFieldName)),
|
||||
scope.Quote(relationship.AssociationForeignDBName),
|
||||
toScope.QuotedTableName(),
|
||||
scope.Quote(toScope.PrimaryKey()))
|
||||
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignFieldName)))
|
||||
toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
|
||||
return scope
|
||||
}
|
||||
|
||||
// has many or has one
|
||||
if toScope.HasColumn(foreignKey) {
|
||||
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue())
|
||||
if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
|
||||
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
|
||||
}
|
||||
toScope.callCallbacks(scope.db.parent.callback.queries)
|
||||
return scope
|
||||
}
|
||||
|
||||
// belongs to
|
||||
if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil {
|
||||
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName))
|
||||
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||
if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) {
|
||||
scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
|
||||
return scope
|
||||
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
||||
query := toScope.db.Where(sql, scope.PrimaryKeyValue())
|
||||
if relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
|
||||
}
|
||||
toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries)
|
||||
return scope
|
||||
scope.Err(query.Find(value).Error)
|
||||
}
|
||||
} else {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
|
||||
}
|
||||
return scope
|
||||
} else if toField != nil {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
|
||||
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
|
||||
return scope
|
||||
}
|
||||
}
|
||||
|
||||
if scopeType == "" || scopeType == toScopeType {
|
||||
// has many or has one in foreign scope
|
||||
if toScope.HasColumn(foreignKey) {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
|
||||
return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)
|
||||
}
|
||||
}
|
||||
}
|
||||
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
|
||||
return scope
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue