Fix tests after refactor

This commit is contained in:
Jinzhu 2015-02-17 17:40:21 +08:00
parent 803343fbe5
commit d6439f4147
5 changed files with 82 additions and 92 deletions

View File

@ -28,7 +28,7 @@ func SaveBeforeAssociations(scope *Scope) {
for _, f := range newDB.NewScope(field.Field.Addr().Interface()).Fields() { for _, f := range newDB.NewScope(field.Field.Addr().Interface()).Fields() {
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface())) value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface()))
} }
scope.SetColumn(field.Name, value.Interface()) scope.SetColumn(field, value.Interface())
} }
scope.Err(newDB.Save(value.Addr().Interface()).Error) scope.Err(newDB.Save(value.Addr().Interface()).Error)

View File

@ -40,14 +40,17 @@ func (field *Field) Set(value interface{}) (err error) {
// Fields get value's fields // Fields get value's fields
func (scope *Scope) Fields() map[string]*Field { func (scope *Scope) Fields() map[string]*Field {
fields := map[string]*Field{} if scope.fields == nil {
structFields := scope.GetStructFields() fields := map[string]*Field{}
structFields := scope.GetStructFields()
for _, structField := range structFields { for _, structField := range structFields {
fields[structField.DBName] = scope.getField(structField) fields[structField.DBName] = scope.getField(structField)
}
scope.fields = fields
} }
return scope.fields
return fields
} }
func (scope *Scope) getField(structField *StructField) *Field { func (scope *Scope) getField(structField *StructField) *Field {
@ -55,9 +58,10 @@ func (scope *Scope) getField(structField *StructField) *Field {
indirectValue := scope.IndirectValue() indirectValue := scope.IndirectValue()
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct { if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
for _, name := range structField.Names { for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue.FieldByName(name)) indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
} }
field.Field = indirectValue field.Field = indirectValue
} }
field.IsBlank = isBlank(indirectValue)
return &field return &field
} }

View File

@ -96,20 +96,28 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct var modelStruct ModelStruct
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
if !reflectValue.IsValid() {
return &modelStruct
}
if reflectValue.Kind() == reflect.Slice { if reflectValue.Kind() == reflect.Slice {
reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
} }
scopeTyp := reflectValue.Type() scopeType := reflectValue.Type()
if scopeType.Kind() != reflect.Struct {
return &modelStruct
}
// Set tablename // Set tablename
if fm := reflect.New(scopeTyp).MethodByName("TableName"); fm.IsValid() { if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
if results := fm.Call([]reflect.Value{}); len(results) > 0 { if results := fm.Call([]reflect.Value{}); len(results) > 0 {
if name, ok := results[0].Interface().(string); ok { if name, ok := results[0].Interface().(string); ok {
modelStruct.TableName = name modelStruct.TableName = name
} }
} }
} else { } else {
modelStruct.TableName = ToSnake(scopeTyp.Name()) modelStruct.TableName = ToSnake(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable { if scope.db == nil || !scope.db.parent.singularTable {
for index, reg := range pluralMapKeys { for index, reg := range pluralMapKeys {
if reg.MatchString(modelStruct.TableName) { if reg.MatchString(modelStruct.TableName) {
@ -120,8 +128,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
// Set fields // Set fields
for i := 0; i < scopeTyp.NumField(); i++ { for i := 0; i < scopeType.NumField(); i++ {
fieldStruct := scopeTyp.Field(i) fieldStruct := scopeType.Field(i)
if !ast.IsExported(fieldStruct.Name) { if !ast.IsExported(fieldStruct.Name) {
continue continue
} }
@ -156,7 +164,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.IsScanner, field.IsNormal = true, true field.IsScanner, field.IsNormal = true, true
} }
if _, isTime := reflect.New(indirectType).Interface().(time.Time); isTime { if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
field.IsTime, field.IsNormal = true, true field.IsTime, field.IsNormal = true, true
} }
@ -181,7 +189,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
kind := "has_many" kind := "has_many"
if foreignKey == "" { if foreignKey == "" {
foreignKey = indirectType.Name() + "Id" foreignKey = scopeType.Name() + "Id"
} }
if associationForeignKey == "" { if associationForeignKey == "" {
@ -199,6 +207,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
ForeignType: foreignType, ForeignType: foreignType,
ForeignFieldName: foreignKey, ForeignFieldName: foreignKey,
AssociationForeignFieldName: associationForeignKey, AssociationForeignFieldName: associationForeignKey,
ForeignDBName: ToSnake(foreignKey),
AssociationForeignDBName: ToSnake(associationForeignKey),
Kind: kind, Kind: kind,
} }
} else { } else {
@ -215,22 +225,27 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
var belongsToForeignKey, hasOneForeignKey, kind string var belongsToForeignKey, hasOneForeignKey, kind string
if foreignKey == "" { if foreignKey == "" {
belongsToForeignKey = indirectType.Name() + "Id" belongsToForeignKey = field.Name + "Id"
hasOneForeignKey = scopeTyp.Name() + "Id" hasOneForeignKey = scopeType.Name() + "Id"
} else { } else {
belongsToForeignKey = foreignKey belongsToForeignKey = foreignKey
hasOneForeignKey = foreignKey hasOneForeignKey = foreignKey
} }
if _, ok := scopeTyp.FieldByName(belongsToForeignKey); ok { if _, ok := scopeType.FieldByName(belongsToForeignKey); ok {
foreignKey = belongsToForeignKey
kind = "belongs_to" kind = "belongs_to"
foreignKey = belongsToForeignKey
} else { } else {
foreignKey = hasOneForeignKey foreignKey = hasOneForeignKey
kind = "has_one" kind = "has_one"
} }
field.Relationship = &Relationship{ForeignFieldName: foreignKey, ForeignType: foreignType, Kind: kind} field.Relationship = &Relationship{
ForeignFieldName: foreignKey,
ForeignDBName: ToSnake(foreignKey),
ForeignType: foreignType,
Kind: kind,
}
} }
default: default:
@ -248,7 +263,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
modelStruct.PrimaryKeyField = field modelStruct.PrimaryKeyField = field
} }
scope.generateSqlTag(field) if scope.db != nil {
scope.generateSqlTag(field)
}
} }
} }

View File

@ -124,16 +124,14 @@ func (scope *Scope) PrimaryKeyValue() interface{} {
// HasColumn to check if has column // HasColumn to check if has column
func (scope *Scope) HasColumn(column string) bool { func (scope *Scope) HasColumn(column string) bool {
clone := scope for _, field := range scope.GetStructFields() {
if scope.IndirectValue().Kind() == reflect.Slice { if !field.IsIgnored {
value := reflect.New(scope.IndirectValue().Type().Elem()).Interface() if field.Name == column || field.DBName == column {
clone = scope.New(value) return true
}
}
} }
return false
dbName := ToSnake(column)
field, hasColumn := clone.Fields()[dbName]
return hasColumn && !field.IsIgnored
} }
// FieldValueByName to get column's value and existence // FieldValueByName to get column's value and existence

View File

@ -415,76 +415,47 @@ func (scope *Scope) typeName() string {
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value) toScope := scope.db.NewScope(value)
fromScopeType := scope.typeName() fromFields := scope.Fields()
toScopeType := toScope.typeName() toFields := toScope.Fields()
scopeType := ""
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
if keys := strings.Split(foreignKey, "."); len(keys) > 1 { fromField := fromFields[ToSnake(foreignKey)]
scopeType = keys[0] toField := toFields[ToSnake(foreignKey)]
foreignKey = keys[1]
}
var relationship *Relationship if fromField != nil {
var field *Field if relationship := fromField.Relationship; relationship != nil {
var scopeHasField bool if relationship.Kind == "many_to_many" {
if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField {
relationship = field.Relationship
}
if scopeType == "" || scopeType == fromScopeType {
if scopeHasField {
if relationship != nil && relationship.ForeignFieldName != "" {
foreignKey = relationship.ForeignFieldName
}
if relationship != nil && relationship.Kind == "many_to_many" {
if relationship.ForeignType != "" {
scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
}
joinSql := fmt.Sprintf( joinSql := fmt.Sprintf(
"INNER JOIN %v ON %v.%v = %v.%v", "INNER JOIN %v ON %v.%v = %v.%v",
scope.Quote(relationship.JoinTable), scope.Quote(relationship.JoinTable),
scope.Quote(relationship.JoinTable), scope.Quote(relationship.JoinTable),
scope.Quote(ToSnake(relationship.AssociationForeignFieldName)), scope.Quote(relationship.AssociationForeignDBName),
toScope.QuotedTableName(), toScope.QuotedTableName(),
scope.Quote(toScope.PrimaryKey())) scope.Quote(toScope.PrimaryKey()))
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignFieldName))) whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName))
toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
return scope } else if relationship.Kind == "belongs_to" {
}
// has many or has one
if toScope.HasColumn(foreignKey) {
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue())
if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
}
toScope.callCallbacks(scope.db.parent.callback.queries)
return scope
}
// belongs to
if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) { scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations")) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
return scope sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := toScope.db.Where(sql, scope.PrimaryKeyValue())
if relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
} }
toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries) scope.Err(query.Find(value).Error)
return scope
} }
} else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
} }
} return scope
} else if toField != nil {
if scopeType == "" || scopeType == toScopeType { sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
// has many or has one in foreign scope scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
if toScope.HasColumn(foreignKey) { return scope
sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)
}
} }
} }
scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
return scope return scope
} }
@ -553,12 +524,12 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
var table = scope.TableName() var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field) var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query = ` var query = `
ALTER TABLE %s ALTER TABLE %s
ADD CONSTRAINT %s ADD CONSTRAINT %s
FOREIGN KEY (%s) FOREIGN KEY (%s)
REFERENCES %s REFERENCES %s
ON DELETE %s ON DELETE %s
ON UPDATE %s; ON UPDATE %s;
` `
scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec() scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec()
} }