Refactor GetModelStruct

This commit is contained in:
Jinzhu 2016-01-03 09:46:07 +08:00
parent 9455215e61
commit 5d2b9bfe34
1 changed files with 22 additions and 27 deletions

View File

@ -97,48 +97,43 @@ type Relationship struct {
func (scope *Scope) GetModelStruct() *ModelStruct { func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct var modelStruct ModelStruct
// Scope value can't be nil
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) if scope.Value == nil {
if !reflectValue.IsValid() {
return &modelStruct return &modelStruct
} }
if reflectValue.Kind() == reflect.Slice { reflectType := reflect.ValueOf(scope.Value).Type()
reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
} }
scopeType := reflectValue.Type() // Scope value need to be a struct
if reflectType.Kind() != reflect.Struct {
if scopeType.Kind() == reflect.Ptr { return &modelStruct
scopeType = scopeType.Elem()
} }
if value := modelStructsMap.Get(scopeType); value != nil { // Get Cached model struct
if value := modelStructsMap.Get(reflectType); value != nil {
return value return value
} }
modelStruct.ModelType = scopeType modelStruct.ModelType = reflectType
if scopeType.Kind() != reflect.Struct {
return &modelStruct
}
if tabler, ok := reflect.New(scopeType).Interface().(interface { // Set default table name
TableName() string if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok {
}); ok {
modelStruct.defaultTableName = tabler.TableName() modelStruct.defaultTableName = tabler.TableName()
} else { } else {
name := ToDBName(scopeType.Name()) tableName := ToDBName(reflectType.Name())
if scope.db == nil || !scope.db.parent.singularTable { if scope.db == nil || !scope.db.parent.singularTable {
name = inflection.Plural(name) tableName = inflection.Plural(tableName)
} }
modelStruct.defaultTableName = tableName
modelStruct.defaultTableName = name
} }
// Get all fields // Get all fields
fields := []*StructField{} fields := []*StructField{}
for i := 0; i < scopeType.NumField(); i++ { for i := 0; i < reflectType.NumField(); i++ {
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
field := &StructField{ field := &StructField{
Struct: fieldStruct, Struct: fieldStruct,
Name: fieldStruct.Name, Name: fieldStruct.Name,
@ -244,7 +239,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for _, foreignKey := range foreignKeys { for _, foreignKey := range foreignKeys {
if field, ok := scope.FieldByName(foreignKey); ok { if field, ok := scope.FieldByName(foreignKey); ok {
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
} }
} }
@ -268,7 +263,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
joinTableHandler := JoinTableHandler{} joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, scopeType, elemType) joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
relationship.JoinTableHandler = &joinTableHandler relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship field.Relationship = relationship
} else { } else {
@ -276,7 +271,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if len(foreignKeys) == 0 { if len(foreignKeys) == 0 {
for _, field := range scope.PrimaryFields() { for _, field := range scope.PrimaryFields() {
if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { if foreignField := getForeignField(reflectType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
@ -386,7 +381,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
finished <- true finished <- true
}(finished) }(finished)
modelStructsMap.Set(scopeType, &modelStruct) modelStructsMap.Set(reflectType, &modelStruct)
<-finished <-finished
modelStruct.cached = true modelStruct.cached = true