Fix tests after refactor

This commit is contained in:
Jinzhu 2015-02-17 17:40:21 +08:00
parent 803343fbe5
commit d6439f4147
5 changed files with 82 additions and 92 deletions

View File

@ -28,7 +28,7 @@ func SaveBeforeAssociations(scope *Scope) {
for _, f := range newDB.NewScope(field.Field.Addr().Interface()).Fields() {
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface()))
}
scope.SetColumn(field.Name, value.Interface())
scope.SetColumn(field, value.Interface())
}
scope.Err(newDB.Save(value.Addr().Interface()).Error)

View File

@ -40,14 +40,17 @@ func (field *Field) Set(value interface{}) (err error) {
// Fields get value's fields
func (scope *Scope) Fields() map[string]*Field {
fields := map[string]*Field{}
structFields := scope.GetStructFields()
if scope.fields == nil {
fields := map[string]*Field{}
structFields := scope.GetStructFields()
for _, structField := range structFields {
fields[structField.DBName] = scope.getField(structField)
for _, structField := range structFields {
fields[structField.DBName] = scope.getField(structField)
}
scope.fields = fields
}
return fields
return scope.fields
}
func (scope *Scope) getField(structField *StructField) *Field {
@ -55,9 +58,10 @@ func (scope *Scope) getField(structField *StructField) *Field {
indirectValue := scope.IndirectValue()
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue.FieldByName(name))
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
}
field.IsBlank = isBlank(indirectValue)
return &field
}

View File

@ -96,20 +96,28 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
if !reflectValue.IsValid() {
return &modelStruct
}
if reflectValue.Kind() == reflect.Slice {
reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
}
scopeTyp := reflectValue.Type()
scopeType := reflectValue.Type()
if scopeType.Kind() != reflect.Struct {
return &modelStruct
}
// Set tablename
if fm := reflect.New(scopeTyp).MethodByName("TableName"); fm.IsValid() {
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
if name, ok := results[0].Interface().(string); ok {
modelStruct.TableName = name
}
}
} else {
modelStruct.TableName = ToSnake(scopeTyp.Name())
modelStruct.TableName = ToSnake(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable {
for index, reg := range pluralMapKeys {
if reg.MatchString(modelStruct.TableName) {
@ -120,8 +128,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
// Set fields
for i := 0; i < scopeTyp.NumField(); i++ {
fieldStruct := scopeTyp.Field(i)
for i := 0; i < scopeType.NumField(); i++ {
fieldStruct := scopeType.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
@ -156,7 +164,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.IsScanner, field.IsNormal = true, true
}
if _, isTime := reflect.New(indirectType).Interface().(time.Time); isTime {
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
field.IsTime, field.IsNormal = true, true
}
@ -181,7 +189,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
kind := "has_many"
if foreignKey == "" {
foreignKey = indirectType.Name() + "Id"
foreignKey = scopeType.Name() + "Id"
}
if associationForeignKey == "" {
@ -199,6 +207,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
ForeignType: foreignType,
ForeignFieldName: foreignKey,
AssociationForeignFieldName: associationForeignKey,
ForeignDBName: ToSnake(foreignKey),
AssociationForeignDBName: ToSnake(associationForeignKey),
Kind: kind,
}
} else {
@ -215,22 +225,27 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
var belongsToForeignKey, hasOneForeignKey, kind string
if foreignKey == "" {
belongsToForeignKey = indirectType.Name() + "Id"
hasOneForeignKey = scopeTyp.Name() + "Id"
belongsToForeignKey = field.Name + "Id"
hasOneForeignKey = scopeType.Name() + "Id"
} else {
belongsToForeignKey = foreignKey
hasOneForeignKey = foreignKey
}
if _, ok := scopeTyp.FieldByName(belongsToForeignKey); ok {
foreignKey = belongsToForeignKey
if _, ok := scopeType.FieldByName(belongsToForeignKey); ok {
kind = "belongs_to"
foreignKey = belongsToForeignKey
} else {
foreignKey = hasOneForeignKey
kind = "has_one"
}
field.Relationship = &Relationship{ForeignFieldName: foreignKey, ForeignType: foreignType, Kind: kind}
field.Relationship = &Relationship{
ForeignFieldName: foreignKey,
ForeignDBName: ToSnake(foreignKey),
ForeignType: foreignType,
Kind: kind,
}
}
default:
@ -248,7 +263,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
modelStruct.PrimaryKeyField = field
}
scope.generateSqlTag(field)
if scope.db != nil {
scope.generateSqlTag(field)
}
}
}

View File

@ -124,16 +124,14 @@ func (scope *Scope) PrimaryKeyValue() interface{} {
// HasColumn to check if has column
func (scope *Scope) HasColumn(column string) bool {
clone := scope
if scope.IndirectValue().Kind() == reflect.Slice {
value := reflect.New(scope.IndirectValue().Type().Elem()).Interface()
clone = scope.New(value)
for _, field := range scope.GetStructFields() {
if !field.IsIgnored {
if field.Name == column || field.DBName == column {
return true
}
}
}
dbName := ToSnake(column)
field, hasColumn := clone.Fields()[dbName]
return hasColumn && !field.IsIgnored
return false
}
// FieldValueByName to get column's value and existence

View File

@ -415,76 +415,47 @@ func (scope *Scope) typeName() string {
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value)
fromScopeType := scope.typeName()
toScopeType := toScope.typeName()
scopeType := ""
fromFields := scope.Fields()
toFields := toScope.Fields()
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
if keys := strings.Split(foreignKey, "."); len(keys) > 1 {
scopeType = keys[0]
foreignKey = keys[1]
}
fromField := fromFields[ToSnake(foreignKey)]
toField := toFields[ToSnake(foreignKey)]
var relationship *Relationship
var field *Field
var scopeHasField bool
if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField {
relationship = field.Relationship
}
if scopeType == "" || scopeType == fromScopeType {
if scopeHasField {
if relationship != nil && relationship.ForeignFieldName != "" {
foreignKey = relationship.ForeignFieldName
}
if relationship != nil && relationship.Kind == "many_to_many" {
if relationship.ForeignType != "" {
scope.Err(fmt.Errorf("gorm does not support polymorphic many-to-many associations"))
}
if fromField != nil {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinSql := fmt.Sprintf(
"INNER JOIN %v ON %v.%v = %v.%v",
scope.Quote(relationship.JoinTable),
scope.Quote(relationship.JoinTable),
scope.Quote(ToSnake(relationship.AssociationForeignFieldName)),
scope.Quote(relationship.AssociationForeignDBName),
toScope.QuotedTableName(),
scope.Quote(toScope.PrimaryKey()))
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignFieldName)))
toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value)
return scope
}
// has many or has one
if toScope.HasColumn(foreignKey) {
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey))), scope.PrimaryKeyValue())
if relationship != nil && relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
toScope.inlineCondition(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
}
toScope.callCallbacks(scope.db.parent.callback.queries)
return scope
}
// belongs to
if foreignValue, err := scope.FieldValueByName(foreignKey); err == nil {
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName))
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
} else if relationship.Kind == "belongs_to" {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
if relationship != nil && relationship.ForeignType != "" && scope.HasColumn(relationship.ForeignType) {
scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations"))
return scope
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := toScope.db.Where(sql, scope.PrimaryKeyValue())
if relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
}
toScope.inlineCondition(sql, foreignValue).callCallbacks(scope.db.parent.callback.queries)
return scope
scope.Err(query.Find(value).Error)
}
} else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
}
}
if scopeType == "" || scopeType == toScopeType {
// has many or has one in foreign scope
if toScope.HasColumn(foreignKey) {
sql := fmt.Sprintf("%v = ?", scope.Quote(ToSnake(foreignKey)))
return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries)
}
return scope
} else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
return scope
}
}
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
return scope
}
@ -553,12 +524,12 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query = `
ALTER TABLE %s
ADD CONSTRAINT %s
FOREIGN KEY (%s)
REFERENCES %s
ON DELETE %s
ON UPDATE %s;
ALTER TABLE %s
ADD CONSTRAINT %s
FOREIGN KEY (%s)
REFERENCES %s
ON DELETE %s
ON UPDATE %s;
`
scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec()
}