mirror of https://github.com/go-gorm/gorm.git
Refactor GetModelStruct
This commit is contained in:
parent
9455215e61
commit
5d2b9bfe34
|
@ -97,48 +97,43 @@ type Relationship struct {
|
||||||
|
|
||||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
var modelStruct ModelStruct
|
var modelStruct ModelStruct
|
||||||
|
// Scope value can't be nil
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
|
if scope.Value == nil {
|
||||||
if !reflectValue.IsValid() {
|
|
||||||
return &modelStruct
|
return &modelStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
if reflectValue.Kind() == reflect.Slice {
|
reflectType := reflect.ValueOf(scope.Value).Type()
|
||||||
reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
|
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
scopeType := reflectValue.Type()
|
// Scope value need to be a struct
|
||||||
|
if reflectType.Kind() != reflect.Struct {
|
||||||
if scopeType.Kind() == reflect.Ptr {
|
return &modelStruct
|
||||||
scopeType = scopeType.Elem()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if value := modelStructsMap.Get(scopeType); value != nil {
|
// Get Cached model struct
|
||||||
|
if value := modelStructsMap.Get(reflectType); value != nil {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStruct.ModelType = scopeType
|
modelStruct.ModelType = reflectType
|
||||||
if scopeType.Kind() != reflect.Struct {
|
|
||||||
return &modelStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
if tabler, ok := reflect.New(scopeType).Interface().(interface {
|
// Set default table name
|
||||||
TableName() string
|
if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok {
|
||||||
}); ok {
|
|
||||||
modelStruct.defaultTableName = tabler.TableName()
|
modelStruct.defaultTableName = tabler.TableName()
|
||||||
} else {
|
} else {
|
||||||
name := ToDBName(scopeType.Name())
|
tableName := ToDBName(reflectType.Name())
|
||||||
if scope.db == nil || !scope.db.parent.singularTable {
|
if scope.db == nil || !scope.db.parent.singularTable {
|
||||||
name = inflection.Plural(name)
|
tableName = inflection.Plural(tableName)
|
||||||
}
|
}
|
||||||
|
modelStruct.defaultTableName = tableName
|
||||||
modelStruct.defaultTableName = name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all fields
|
// Get all fields
|
||||||
fields := []*StructField{}
|
fields := []*StructField{}
|
||||||
for i := 0; i < scopeType.NumField(); i++ {
|
for i := 0; i < reflectType.NumField(); i++ {
|
||||||
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
|
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||||
field := &StructField{
|
field := &StructField{
|
||||||
Struct: fieldStruct,
|
Struct: fieldStruct,
|
||||||
Name: fieldStruct.Name,
|
Name: fieldStruct.Name,
|
||||||
|
@ -244,7 +239,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
for _, foreignKey := range foreignKeys {
|
for _, foreignKey := range foreignKeys {
|
||||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
||||||
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -268,7 +263,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
}
|
}
|
||||||
|
|
||||||
joinTableHandler := JoinTableHandler{}
|
joinTableHandler := JoinTableHandler{}
|
||||||
joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
|
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
||||||
relationship.JoinTableHandler = &joinTableHandler
|
relationship.JoinTableHandler = &joinTableHandler
|
||||||
field.Relationship = relationship
|
field.Relationship = relationship
|
||||||
} else {
|
} else {
|
||||||
|
@ -276,7 +271,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
|
|
||||||
if len(foreignKeys) == 0 {
|
if len(foreignKeys) == 0 {
|
||||||
for _, field := range scope.PrimaryFields() {
|
for _, field := range scope.PrimaryFields() {
|
||||||
if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
|
if foreignField := getForeignField(reflectType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
|
||||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
|
||||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||||
|
@ -386,7 +381,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
finished <- true
|
finished <- true
|
||||||
}(finished)
|
}(finished)
|
||||||
|
|
||||||
modelStructsMap.Set(scopeType, &modelStruct)
|
modelStructsMap.Set(reflectType, &modelStruct)
|
||||||
|
|
||||||
<-finished
|
<-finished
|
||||||
modelStruct.cached = true
|
modelStruct.cached = true
|
||||||
|
|
Loading…
Reference in New Issue