mirror of https://github.com/go-gorm/gorm.git
Refactor GetModelStruct
This commit is contained in:
parent
9455215e61
commit
5d2b9bfe34
|
@ -97,48 +97,43 @@ type Relationship struct {
|
|||
|
||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
var modelStruct ModelStruct
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
if !reflectValue.IsValid() {
|
||||
// Scope value can't be nil
|
||||
if scope.Value == nil {
|
||||
return &modelStruct
|
||||
}
|
||||
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
|
||||
reflectType := reflect.ValueOf(scope.Value).Type()
|
||||
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
|
||||
scopeType := reflectValue.Type()
|
||||
|
||||
if scopeType.Kind() == reflect.Ptr {
|
||||
scopeType = scopeType.Elem()
|
||||
// Scope value need to be a struct
|
||||
if reflectType.Kind() != reflect.Struct {
|
||||
return &modelStruct
|
||||
}
|
||||
|
||||
if value := modelStructsMap.Get(scopeType); value != nil {
|
||||
// Get Cached model struct
|
||||
if value := modelStructsMap.Get(reflectType); value != nil {
|
||||
return value
|
||||
}
|
||||
|
||||
modelStruct.ModelType = scopeType
|
||||
if scopeType.Kind() != reflect.Struct {
|
||||
return &modelStruct
|
||||
}
|
||||
modelStruct.ModelType = reflectType
|
||||
|
||||
if tabler, ok := reflect.New(scopeType).Interface().(interface {
|
||||
TableName() string
|
||||
}); ok {
|
||||
// Set default table name
|
||||
if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok {
|
||||
modelStruct.defaultTableName = tabler.TableName()
|
||||
} else {
|
||||
name := ToDBName(scopeType.Name())
|
||||
tableName := ToDBName(reflectType.Name())
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
name = inflection.Plural(name)
|
||||
tableName = inflection.Plural(tableName)
|
||||
}
|
||||
|
||||
modelStruct.defaultTableName = name
|
||||
modelStruct.defaultTableName = tableName
|
||||
}
|
||||
|
||||
// Get all fields
|
||||
fields := []*StructField{}
|
||||
for i := 0; i < scopeType.NumField(); i++ {
|
||||
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
for i := 0; i < reflectType.NumField(); i++ {
|
||||
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
field := &StructField{
|
||||
Struct: fieldStruct,
|
||||
Name: fieldStruct.Name,
|
||||
|
@ -244,7 +239,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
for _, foreignKey := range foreignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
||||
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
||||
joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||
}
|
||||
}
|
||||
|
@ -268,7 +263,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
}
|
||||
|
||||
joinTableHandler := JoinTableHandler{}
|
||||
joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
|
||||
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
||||
relationship.JoinTableHandler = &joinTableHandler
|
||||
field.Relationship = relationship
|
||||
} else {
|
||||
|
@ -276,7 +271,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
|
||||
if len(foreignKeys) == 0 {
|
||||
for _, field := range scope.PrimaryFields() {
|
||||
if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
|
||||
if foreignField := getForeignField(reflectType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
|
@ -386,7 +381,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
finished <- true
|
||||
}(finished)
|
||||
|
||||
modelStructsMap.Set(scopeType, &modelStruct)
|
||||
modelStructsMap.Set(reflectType, &modelStruct)
|
||||
|
||||
<-finished
|
||||
modelStruct.cached = true
|
||||
|
|
Loading…
Reference in New Issue