From 5d2b9bfe3420c95932d1ee0f3ff274c3efd71637 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 09:46:07 +0800 Subject: [PATCH] Refactor GetModelStruct --- model_struct.go | 49 ++++++++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/model_struct.go b/model_struct.go index 89e7a169..7a47540e 100644 --- a/model_struct.go +++ b/model_struct.go @@ -97,48 +97,43 @@ type Relationship struct { func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct - - reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) - if !reflectValue.IsValid() { + // Scope value can't be nil + if scope.Value == nil { return &modelStruct } - if reflectValue.Kind() == reflect.Slice { - reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) + reflectType := reflect.ValueOf(scope.Value).Type() + for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() } - scopeType := reflectValue.Type() - - if scopeType.Kind() == reflect.Ptr { - scopeType = scopeType.Elem() + // Scope value need to be a struct + if reflectType.Kind() != reflect.Struct { + return &modelStruct } - if value := modelStructsMap.Get(scopeType); value != nil { + // Get Cached model struct + if value := modelStructsMap.Get(reflectType); value != nil { return value } - modelStruct.ModelType = scopeType - if scopeType.Kind() != reflect.Struct { - return &modelStruct - } + modelStruct.ModelType = reflectType - if tabler, ok := reflect.New(scopeType).Interface().(interface { - TableName() string - }); ok { + // Set default table name + if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok { modelStruct.defaultTableName = tabler.TableName() } else { - name := ToDBName(scopeType.Name()) + tableName := ToDBName(reflectType.Name()) if scope.db == nil || !scope.db.parent.singularTable { - name = inflection.Plural(name) + tableName = inflection.Plural(tableName) } - - modelStruct.defaultTableName = name + modelStruct.defaultTableName = tableName } // Get all fields fields := []*StructField{} - for i := 0; i < scopeType.NumField(); i++ { - if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { + for i := 0; i < reflectType.NumField(); i++ { + if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { field := &StructField{ Struct: fieldStruct, Name: fieldStruct.Name, @@ -244,7 +239,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if field, ok := scope.FieldByName(foreignKey); ok { relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) - joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName + joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } } @@ -268,7 +263,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, scopeType, elemType) + joinTableHandler.Setup(relationship, many2many, reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -276,7 +271,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if len(foreignKeys) == 0 { for _, field := range scope.PrimaryFields() { - if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { + if foreignField := getForeignField(reflectType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) @@ -386,7 +381,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { finished <- true }(finished) - modelStructsMap.Set(scopeType, &modelStruct) + modelStructsMap.Set(reflectType, &modelStruct) <-finished modelStruct.cached = true