diff --git a/callback_create.go b/callback_create.go index 750ddab9..5e51e6d9 100644 --- a/callback_create.go +++ b/callback_create.go @@ -26,7 +26,7 @@ func Create(scope *Scope) { var sqls, columns []string for _, field := range scope.Fields() { - if len(field.SqlTag) > 0 && !field.IsIgnored && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) { + if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Value)) } diff --git a/callback_update.go b/callback_update.go index 805cd7c6..36654281 100644 --- a/callback_update.go +++ b/callback_update.go @@ -25,16 +25,14 @@ func AssignUpdateAttributes(scope *Scope) { } func BeforeUpdate(scope *Scope) { - _, ok := scope.Get("gorm:update_column") - if !ok { + if _, ok := scope.Get("gorm:update_column"); !ok { scope.CallMethod("BeforeSave") scope.CallMethod("BeforeUpdate") } } func UpdateTimeStampWhenUpdate(scope *Scope) { - _, ok := scope.Get("gorm:update_column") - if !ok { + if _, ok := scope.Get("gorm:update_column"); !ok { scope.SetColumn("UpdatedAt", NowFunc()) } } @@ -50,7 +48,7 @@ func Update(scope *Scope) { } } else { for _, field := range scope.Fields() { - if !field.IsPrimaryKey && len(field.SqlTag) > 0 && !field.IsIgnored { + if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value))) } } diff --git a/field.go b/field.go index 9e9e6660..6ff0cfb7 100644 --- a/field.go +++ b/field.go @@ -19,8 +19,8 @@ type Field struct { Field reflect.Value Value interface{} Tag reflect.StructTag - SqlTag string Relationship *relationship + IsNormal bool IsBlank bool IsIgnored bool IsPrimaryKey bool diff --git a/scope.go b/scope.go index b36669da..bd853ca2 100644 --- a/scope.go +++ b/scope.go @@ -95,7 +95,20 @@ func (scope *Scope) PrimaryKey() string { return scope.primaryKey } - scope.primaryKey = ToSnake(GetPrimaryKey(scope.Value)) + var indirectValue = scope.IndirectValue() + + clone := scope + if indirectValue.Kind() == reflect.Slice { + clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface()) + } + + for _, field := range clone.Fields() { + if field.IsPrimaryKey { + scope.primaryKey = field.DBName + break + } + } + return scope.primaryKey } @@ -130,8 +143,12 @@ func (scope *Scope) SetColumn(column string, value interface{}) bool { if scope.Value == nil { return false } - - return setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value) + for _, field := range scope.Fields() { + if field.Name == column || field.DBName == column { + return setFieldValue(field.Field, value) + } + } + return false } // CallMethod invoke method with necessary argument @@ -262,13 +279,19 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { // Search for primary key tag identifier settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) - - if scope.PrimaryKey() == field.DBName { + if _, ok := settings["PRIMARY_KEY"]; ok { field.IsPrimaryKey = true } field.Tag = fieldStruct.Tag - field.SqlTag = scope.sqlTagForField(&field) + + tagIdentifier := "sql" + if scope.db != nil { + tagIdentifier = scope.db.parent.tagIdentifier + } + if fieldStruct.Tag.Get(tagIdentifier) == "-" { + field.IsIgnored = true + } if !field.IsIgnored { // parse association @@ -311,6 +334,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { if many2many != "" { field.Relationship.Kind = "many_to_many" } + } else { + field.IsNormal = true } case reflect.Struct: embedded := settings["EMBEDDED"] @@ -321,7 +346,9 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { fields = append(fields, field) } return fields - } else if !field.IsTime() && !field.IsScanner() { + } else if field.IsTime() || field.IsScanner() { + field.IsNormal = true + } else { if foreignKey == "" && scope.HasColumn(field.Name+"Id") { field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"} } else if scope.HasColumn(foreignKey) { @@ -335,6 +362,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { } } } + default: + field.IsNormal = true } } return []*Field{&field} @@ -345,12 +374,16 @@ func (scope *Scope) Fields() map[string]*Field { var fields = map[string]*Field{} if scope.IndirectValue().IsValid() { scopeTyp := scope.IndirectValue().Type() + var hasPrimaryKey = false for i := 0; i < scopeTyp.NumField(); i++ { fieldStruct := scopeTyp.Field(i) if !ast.IsExported(fieldStruct.Name) { continue } for _, field := range scope.fieldFromStruct(fieldStruct) { + if field.IsPrimaryKey { + hasPrimaryKey = true + } if _, ok := fields[field.DBName]; ok { panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum())) } else { @@ -358,6 +391,12 @@ func (scope *Scope) Fields() map[string]*Field { } } } + + if !hasPrimaryKey { + if field, ok := fields["id"]; ok { + field.IsPrimaryKey = true + } + } } return fields } diff --git a/scope_private.go b/scope_private.go index aa75346f..3cfb47ab 100644 --- a/scope_private.go +++ b/scope_private.go @@ -305,11 +305,6 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) { var size = 255 fieldTag := field.Tag.Get(scope.db.parent.tagIdentifier) - if fieldTag == "-" { - field.IsIgnored = true - return - } - var setting = parseTagSetting(fieldTag) if value, ok := setting["SIZE"]; ok { @@ -481,8 +476,9 @@ func (scope *Scope) createJoinTable(field *Field) { func (scope *Scope) createTable() *Scope { var sqls []string for _, field := range scope.Fields() { - if !field.IsIgnored && len(field.SqlTag) > 0 { - sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) + if field.IsNormal { + sqlTag := scope.sqlTagForField(field) + sqls = append(sqls, scope.Quote(field.DBName)+" "+sqlTag) } scope.createJoinTable(field) } @@ -535,8 +531,9 @@ func (scope *Scope) autoMigrate() *Scope { } else { for _, field := range scope.Fields() { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { - if len(field.SqlTag) > 0 && !field.IsIgnored { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec() + if field.IsNormal { + sqlTag := scope.sqlTagForField(field) + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec() } } scope.createJoinTable(field) diff --git a/utils.go b/utils.go index 50512143..b75aa1b6 100644 --- a/utils.go +++ b/utils.go @@ -2,7 +2,6 @@ package gorm import ( "bytes" - "go/ast" "reflect" "strings" "sync" @@ -91,37 +90,6 @@ func SnakeToUpperCamel(s string) string { return u } -func GetPrimaryKey(value interface{}) string { - var indirectValue = reflect.Indirect(reflect.ValueOf(value)) - - if indirectValue.Kind() == reflect.Slice { - indirectValue = reflect.New(indirectValue.Type().Elem()).Elem() - } - - if indirectValue.IsValid() { - hasId := false - scopeTyp := indirectValue.Type() - for i := 0; i < scopeTyp.NumField(); i++ { - fieldStruct := scopeTyp.Field(i) - if !ast.IsExported(fieldStruct.Name) { - continue - } - - settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) - if _, ok := settings["PRIMARY_KEY"]; ok { - return fieldStruct.Name - } else if fieldStruct.Name == "Id" { - hasId = true - } - } - if hasId { - return "Id" - } - } - - return "" -} - func parseTagSetting(str string) map[string]string { tags := strings.Split(str, ";") setting := map[string]string{}