diff --git a/callback_shared.go b/callback_shared.go index 3d8a245c..7fab779e 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -32,8 +32,8 @@ func SaveBeforeAssociations(scope *Scope) { } scope.Err(newDB.Save(value.Addr().Interface()).Error) - if relationship.ForeignKey != "" { - scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Addr().Interface()).PrimaryKeyValue()) + if relationship.ForeignFieldName != "" { + scope.SetColumn(relationship.ForeignFieldName, newDB.NewScope(value.Addr().Interface()).PrimaryKeyValue()) } if relationship.ForeignType != "" { scope.Err(fmt.Errorf("gorm does not support polymorphic belongs_to associations")) @@ -58,8 +58,8 @@ func SaveAfterAssociations(scope *Scope) { newDB := scope.NewDB() elem := value.Index(i).Addr().Interface() - if relationship.JoinTable == "" && relationship.ForeignKey != "" { - newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) + if relationship.JoinTable == "" && relationship.ForeignFieldName != "" { + newDB.NewScope(elem).SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()) } if relationship.ForeignType != "" { newDB.NewScope(elem).SetColumn(relationship.ForeignType, scope.TableName()) @@ -74,9 +74,9 @@ func SaveAfterAssociations(scope *Scope) { newScope := scope.New(elem) joinTable := relationship.JoinTable - foreignKey := ToSnake(relationship.ForeignKey) + foreignKey := ToSnake(relationship.ForeignFieldName) foreignValue := fmt.Sprintf("%v", scope.PrimaryKeyValue()) - associationForeignKey := ToSnake(relationship.AssociationForeignKey) + associationForeignKey := ToSnake(relationship.AssociationForeignFieldName) associationForeignValue := fmt.Sprintf("%v", newScope.PrimaryKeyValue()) newScope.Raw(fmt.Sprintf( @@ -97,8 +97,8 @@ func SaveAfterAssociations(scope *Scope) { default: newDB := scope.NewDB() if value.CanAddr() { - if relationship.ForeignKey != "" { - newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) + if relationship.ForeignFieldName != "" { + newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()) } if relationship.ForeignType != "" { newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName()) @@ -112,8 +112,8 @@ func SaveAfterAssociations(scope *Scope) { } elem := destValue.Addr().Interface() - if relationship.ForeignKey != "" { - newDB.NewScope(elem).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) + if relationship.ForeignFieldName != "" { + newDB.NewScope(elem).SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()) } if relationship.ForeignType != "" { newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignType, scope.TableName()) diff --git a/main.go b/main.go index 377e0a25..0cf94eb1 100644 --- a/main.go +++ b/main.go @@ -430,7 +430,7 @@ func (s *DB) Association(column string) *Association { var field *Field var ok bool if field, ok = scope.FieldByName(column); ok { - if field.Relationship == nil || field.Relationship.ForeignKey == "" { + if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())) } } else { diff --git a/model_struct.go b/model_struct.go index 9f641443..af062567 100644 --- a/model_struct.go +++ b/model_struct.go @@ -2,9 +2,12 @@ package gorm import ( "database/sql" + "fmt" "go/ast" "reflect" + "regexp" "strconv" + "strings" "time" ) @@ -23,7 +26,10 @@ type StructField struct { IsNormal bool IsIgnored bool DefaultValue *string + GormSettings map[string]string + SqlSettings map[string]string SqlTag string + Struct reflect.StructField Relationship *Relationship } @@ -37,14 +43,81 @@ type Relationship struct { JoinTable string } -func (scope *Scope) GetStructFields() (fields []*StructField) { +func (scope *Scope) GenerateSqlTag(field *StructField) { + var sqlType string + reflectValue := reflect.New(field.Struct.Type) + + if value, ok := field.SqlSettings["TYPE"]; ok { + sqlType = value + } + + additionalType := field.SqlSettings["NOT NULL"] + " " + field.SqlSettings["UNIQUE"] + if value, ok := field.SqlSettings["DEFAULT"]; ok { + additionalType = additionalType + "DEFAULT " + value + } + + if field.IsScanner { + var getScannerValue func(reflect.Value) + getScannerValue = func(reflectValue reflect.Value) { + if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner { + getScannerValue(reflectValue.Field(0)) + } + } + getScannerValue(reflectValue.Field(0)) + } + + if sqlType == "" { + var size = 255 + + if value, ok := field.SqlSettings["SIZE"]; ok { + size, _ = strconv.Atoi(value) + } + + if field.IsPrimaryKey { + sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size) + } else { + sqlType = scope.Dialect().SqlTag(reflectValue, size) + } + } + + if strings.TrimSpace(additionalType) == "" { + field.SqlTag = sqlType + } else { + field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType) + } +} + +var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} +var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} + +func (scope *Scope) GetModelStruct() *ModelStruct { + var modelStruct ModelStruct + reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) if reflectValue.Kind() == reflect.Slice { reflectValue = reflect.Indirect(reflect.New(reflectValue.Elem().Type())) } - scopeTyp := reflectValue.Type() - hasPrimaryKey := false + + // Set tablename + if fm := reflect.New(scopeTyp).MethodByName("TableName"); fm.IsValid() { + if results := fm.Call([]reflect.Value{}); len(results) > 0 { + if name, ok := results[0].Interface().(string); ok { + modelStruct.TableName = name + } + } + } else { + modelStruct.TableName = ToSnake(scopeTyp.Name()) + if scope.db == nil || !scope.db.parent.singularTable { + for index, reg := range pluralMapKeys { + if reg.MatchString(modelStruct.TableName) { + modelStruct.TableName = reg.ReplaceAllString(modelStruct.TableName, pluralMapValues[index]) + } + } + } + } + + // Set fields for i := 0; i < scopeTyp.NumField(); i++ { fieldStruct := scopeTyp.Field(i) if !ast.IsExported(fieldStruct.Name) { @@ -52,21 +125,22 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { } var field *StructField + field.Struct = fieldStruct if fieldStruct.Tag.Get("sql") == "-" { field.IsIgnored = true } else { - sqlSettings := parseTagSetting(fieldStruct.Tag.Get("sql")) - settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) - if _, ok := settings["PRIMARY_KEY"]; ok { + field.SqlSettings = parseTagSetting(fieldStruct.Tag.Get("sql")) + field.GormSettings = parseTagSetting(fieldStruct.Tag.Get("gorm")) + if _, ok := field.GormSettings["PRIMARY_KEY"]; ok { field.IsPrimaryKey = true - hasPrimaryKey = true + modelStruct.PrimaryKeyField = field } - if value, ok := sqlSettings["DEFAULT"]; ok { + if value, ok := field.SqlSettings["DEFAULT"]; ok { field.DefaultValue = &value } - if value, ok := settings["COLUMN"]; ok { + if value, ok := field.GormSettings["COLUMN"]; ok { field.DBName = value } else { field.DBName = ToSnake(fieldStruct.Name) @@ -85,11 +159,11 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { field.IsTime, field.IsNormal = true, true } - many2many := settings["MANY2MANY"] - foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) - foreignType := SnakeToUpperCamel(settings["FOREIGNTYPE"]) - associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) - if polymorphic := SnakeToUpperCamel(settings["POLYMORPHIC"]); polymorphic != "" { + many2many := field.GormSettings["MANY2MANY"] + foreignKey := SnakeToUpperCamel(field.GormSettings["FOREIGNKEY"]) + foreignType := SnakeToUpperCamel(field.GormSettings["FOREIGNTYPE"]) + associationForeignKey := SnakeToUpperCamel(field.GormSettings["ASSOCIATIONFOREIGNKEY"]) + if polymorphic := SnakeToUpperCamel(field.GormSettings["POLYMORPHIC"]); polymorphic != "" { foreignKey = polymorphic + "Id" foreignType = polymorphic + "Type" } @@ -119,20 +193,20 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { foreignKey = "" } - field.Relationship = &relationship{ - JoinTable: many2many, - ForeignKey: foreignKey, - ForeignType: foreignType, - AssociationForeignKey: associationForeignKey, + field.Relationship = &Relationship{ + JoinTable: many2many, + ForeignType: foreignType, + ForeignFieldName: foreignKey, + AssociationForeignFieldName: associationForeignKey, Kind: kind, } } else { field.IsNormal = true } case reflect.Struct: - if _, ok := settings["EMBEDDED"]; ok || fieldStruct.Anonymous { + if _, ok := field.GormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { for _, field := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() { - fields = append(fields, field) + modelStruct.StructFields = append(modelStruct.StructFields, field) } break } else { @@ -154,7 +228,7 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { kind = "has_one" } - field.Relationship = &relationship{ForeignKey: foreignKey, ForeignType: foreignType, Kind: kind} + field.Relationship = &Relationship{ForeignFieldName: foreignKey, ForeignType: foreignType, Kind: kind} } default: @@ -162,64 +236,21 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { } } } - fields = append(fields, field) + modelStruct.StructFields = append(modelStruct.StructFields, field) } - if !hasPrimaryKey { - for _, field := range fields { - if field.DBName == "id" { - field.IsPrimaryKey = true - } + for _, field := range modelStruct.StructFields { + if modelStruct.PrimaryKeyField == nil && field.DBName == "id" { + field.IsPrimaryKey = true + modelStruct.PrimaryKeyField = field } + + scope.GenerateSqlTag(field) } - for _, field := range fields { - var sqlType string - size := 255 - sqlTag := field.Tag.Get("sql") - sqlSetting = parseTagSetting(sqlTag) - - if value, ok := sqlSetting["SIZE"]; ok { - if i, err := strconv.Atoi(value); err == nil { - size = i - } else { - size = 0 - } - } - - if value, ok := sqlSetting["TYPE"]; ok { - typ = value - } - - additionalType := sqlSetting["NOT NULL"] + " " + sqlSetting["UNIQUE"] - if value, ok := sqlSetting["DEFAULT"]; ok { - additionalType = additionalType + "DEFAULT " + value - } - - if field.IsScanner { - var getScannerValue func(reflect.Value) - getScannerValue = func(reflectValue reflect.Value) { - if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner { - getScannerValue(reflectValue.Field(0)) - } - } - getScannerValue(reflectValue.Field(0)) - } - if field.IsNormal { - typ + " " + additionalType - } else if !field.IsTime { - return typ + " " + additionalType - } - - if len(typ) == 0 { - if field.IsPrimaryKey { - typ = scope.Dialect().PrimaryKeyTag(reflectValue, size) - } else { - typ = scope.Dialect().SqlTag(reflectValue, size) - } - } - - return typ + " " + additionalType - } - return + return &modelStruct +} + +func (scope *Scope) GetStructFields() (fields []*StructField) { + return scope.GetModelStruct().StructFields } diff --git a/preload.go b/preload.go index 87ee5065..7ac0db0d 100644 --- a/preload.go +++ b/preload.go @@ -46,14 +46,14 @@ func Preload(scope *Scope) { switch relation.Kind { case "has_one": - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName())) + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) scope.NewDB().Where(condition, scope.getColumnAsArray(primaryName)).Find(results, conditions...) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if isSlice { - value := getFieldValue(result, relation.ForeignKey) + value := getFieldValue(result, relation.ForeignFieldName) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { if equalAsString(getFieldValue(objects.Index(j), primaryName), value) { @@ -66,13 +66,13 @@ func Preload(scope *Scope) { } } case "has_many": - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName())) + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) scope.NewDB().Where(condition, scope.getColumnAsArray(primaryName)).Find(results, conditions...) resultValues := reflect.Indirect(reflect.ValueOf(results)) if isSlice { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) - value := getFieldValue(result, relation.ForeignKey) + value := getFieldValue(result, relation.ForeignFieldName) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) @@ -87,7 +87,7 @@ func Preload(scope *Scope) { scope.SetColumn(field, resultValues) } case "belongs_to": - scope.NewDB().Where(scope.getColumnAsArray(relation.ForeignKey)).Find(results, conditions...) + scope.NewDB().Where(scope.getColumnAsArray(relation.ForeignFieldName)).Find(results, conditions...) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) @@ -96,7 +96,7 @@ func Preload(scope *Scope) { objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getFieldValue(object, relation.ForeignKey), value) { + if equalAsString(getFieldValue(object, relation.ForeignFieldName), value) { object.FieldByName(field.Name).Set(result) } } diff --git a/scope.go b/scope.go index 5850a517..0bee5b9b 100644 --- a/scope.go +++ b/scope.go @@ -7,7 +7,6 @@ import ( "time" "reflect" - "regexp" ) type Scope struct { @@ -227,8 +226,6 @@ func (scope *Scope) AddToVars(value interface{}) string { } // TableName get table name -var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} -var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} func (scope *Scope) TableName() string { if scope.Search != nil && len(scope.Search.TableName) > 0 { diff --git a/scope_private.go b/scope_private.go index 2e44c7b0..8473fab9 100644 --- a/scope_private.go +++ b/scope_private.go @@ -364,69 +364,6 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore return } -func (scope *Scope) sqlTagForField(field *Field) (typ string) { - if scope.db == nil { - return "" - } - var size = 255 - - fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier) - var setting = parseTagSetting(fieldTag) - - if value, ok := setting["SIZE"]; ok { - if i, err := strconv.Atoi(value); err == nil { - size = i - } else { - size = 0 - } - } - - if value, ok := setting["TYPE"]; ok { - typ = value - } - - additionalType := setting["NOT NULL"] + " " + setting["UNIQUE"] - if value, ok := setting["DEFAULT"]; ok { - additionalType = additionalType + "DEFAULT " + value - } - - value := field.Field.Interface() - reflectValue := field.Field - if reflectValue.Kind() == reflect.Ptr { - reflectValue = reflect.New(reflectValue.Type().Elem()).Elem() - } - - switch reflectValue.Kind() { - case reflect.Slice: - if _, ok := value.([]byte); !ok { - return typ + " " + additionalType - } - case reflect.Struct: - if field.IsScanner { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - reflectValue = value - if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner { - getScannerValue(reflectValue.Field(0)) - } - } - getScannerValue(reflectValue.Field(0)) - } else if !field.IsTime { - return typ + " " + additionalType - } - } - - if len(typ) == 0 { - if field.IsPrimaryKey { - typ = scope.Dialect().PrimaryKeyTag(reflectValue, size) - } else { - typ = scope.Dialect().SqlTag(reflectValue, size) - } - } - - return typ + " " + additionalType -} - func (scope *Scope) row() *sql.Row { defer scope.Trace(NowFunc()) scope.prepareQuerySql() @@ -495,7 +432,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { foreignKey = keys[1] } - var relationship *relationship + var relationship *Relationship var field *Field var scopeHasField bool if field, scopeHasField = scope.FieldByName(foreignKey); scopeHasField { @@ -504,8 +441,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if scopeType == "" || scopeType == fromScopeType { if scopeHasField { - if relationship != nil && relationship.ForeignKey != "" { - foreignKey = relationship.ForeignKey + if relationship != nil && relationship.ForeignFieldName != "" { + foreignKey = relationship.ForeignFieldName } if relationship != nil && relationship.Kind == "many_to_many" { @@ -516,10 +453,10 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { "INNER JOIN %v ON %v.%v = %v.%v", scope.Quote(relationship.JoinTable), scope.Quote(relationship.JoinTable), - scope.Quote(ToSnake(relationship.AssociationForeignKey)), + scope.Quote(ToSnake(relationship.AssociationForeignFieldName)), toScope.QuotedTableName(), scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignKey))) + whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(ToSnake(relationship.ForeignFieldName))) toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value) return scope } @@ -567,8 +504,8 @@ func (scope *Scope) createJoinTable(field *Field) { newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", field.Relationship.JoinTable, strings.Join([]string{ - scope.Quote(ToSnake(field.Relationship.ForeignKey)) + " " + primaryKeySqlType, - scope.Quote(ToSnake(field.Relationship.AssociationForeignKey)) + " " + primaryKeySqlType}, ",")), + scope.Quote(ToSnake(field.Relationship.ForeignFieldName)) + " " + primaryKeySqlType, + scope.Quote(ToSnake(field.Relationship.AssociationForeignFieldName)) + " " + primaryKeySqlType}, ",")), ).Exec() scope.Err(newScope.db.Error) }