diff --git a/association.go b/association.go index 8dd844ed..b011971a 100644 --- a/association.go +++ b/association.go @@ -52,12 +52,12 @@ func (association *Association) getPrimaryKeys(values ...interface{}) []interfac reflectValue := reflect.Indirect(reflect.ValueOf(value)) if reflectValue.Kind() == reflect.Slice { for i := 0; i < reflectValue.Len(); i++ { - if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryKeyField(); !primaryField.IsBlank { + if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank { primaryKeys = append(primaryKeys, primaryField.Field.Interface()) } } } else if reflectValue.Kind() == reflect.Struct { - if primaryField := scope.New(value).PrimaryKeyField(); !primaryField.IsBlank { + if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank { primaryKeys = append(primaryKeys, primaryField.Field.Interface()) } } @@ -81,7 +81,7 @@ func (association *Association) Delete(values ...interface{}) *Association { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) - if primaryField := association.Scope.New(value.Interface()).PrimaryKeyField(); primaryField != nil { + if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { var included = false for _, primaryKey := range primaryKeys { if equalAsString(primaryKey, primaryField.Field.Interface()) { diff --git a/callback_create.go b/callback_create.go index 1c41c03f..6eddddb3 100644 --- a/callback_create.go +++ b/callback_create.go @@ -34,7 +34,7 @@ func Create(scope *Scope) { } returningKey := "*" - primaryField := scope.PrimaryKeyField() + primaryField := scope.PrimaryField() if primaryField != nil { returningKey = scope.Quote(primaryField.DBName) } diff --git a/embedded_struct_test.go b/embedded_struct_test.go index 74997f13..7be75d99 100644 --- a/embedded_struct_test.go +++ b/embedded_struct_test.go @@ -36,7 +36,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) { t.Errorf("embedded struct's value should be scanned correctly") } - if DB.NewScope(&HNPost{}).PrimaryKeyField() == nil { + if DB.NewScope(&HNPost{}).PrimaryField() == nil { t.Errorf("primary key with embedded struct should works") } diff --git a/main.go b/main.go index 0db07079..17102cde 100644 --- a/main.go +++ b/main.go @@ -431,7 +431,7 @@ func (s *DB) Association(column string) *Association { var err error scope := s.clone().NewScope(s.Value) - if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank { + if primaryField := scope.PrimaryField(); primaryField.IsBlank { err = errors.New("primary key can't be nil") } else { if field, ok := scope.FieldByName(column); ok { diff --git a/model_struct.go b/model_struct.go index 2a0ab743..17605e50 100644 --- a/model_struct.go +++ b/model_struct.go @@ -12,10 +12,10 @@ import ( ) type ModelStruct struct { - PrimaryKeyField *StructField - StructFields []*StructField - ModelType reflect.Type - TableName string + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + TableName string } type StructField struct { @@ -131,7 +131,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { gormSettings := parseTagSetting(field.Tag.Get("gorm")) if _, ok := gormSettings["PRIMARY_KEY"]; ok { field.IsPrimaryKey = true - modelStruct.PrimaryKeyField = field + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } if _, ok := sqlSettings["DEFAULT"]; ok { @@ -240,7 +240,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { toField.Names = append([]string{fieldStruct.Name}, toField.Names...) modelStruct.StructFields = append(modelStruct.StructFields, toField) if toField.IsPrimaryKey { - modelStruct.PrimaryKeyField = toField + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) } } continue @@ -277,9 +277,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if field.IsNormal { - if modelStruct.PrimaryKeyField == nil && field.DBName == "id" { + if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { field.IsPrimaryKey = true - modelStruct.PrimaryKeyField = field + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } if scope.db != nil { diff --git a/preload.go b/preload.go index 7ca8ca55..8880870e 100644 --- a/preload.go +++ b/preload.go @@ -29,8 +29,8 @@ func Preload(scope *Scope) { if field.Name == key && field.Relationship != nil { results := makeSlice(field.Struct.Type) relation := field.Relationship - primaryName := scope.PrimaryKeyField().Name - associationPrimaryKey := scope.New(results).PrimaryKeyField().Name + primaryName := scope.PrimaryField().Name + associationPrimaryKey := scope.New(results).PrimaryField().Name switch relation.Kind { case "has_one": diff --git a/scope.go b/scope.go index cabe9743..2cfeaa9d 100644 --- a/scope.go +++ b/scope.go @@ -109,16 +109,21 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } -func (scope *Scope) PrimaryKeyField() *Field { - if field := scope.GetModelStruct().PrimaryKeyField; field != nil { - return scope.Fields()[field.DBName] +func (scope *Scope) PrimaryField() *Field { + if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { + if len(primaryFields) > 1 { + if field, ok := scope.Fields()["id"]; ok { + return field + } + } + return scope.Fields()[primaryFields[0].DBName] } return nil } // PrimaryKey get the primary key's column name func (scope *Scope) PrimaryKey() string { - if field := scope.PrimaryKeyField(); field != nil { + if field := scope.PrimaryField(); field != nil { return field.DBName } return "" @@ -126,13 +131,13 @@ func (scope *Scope) PrimaryKey() string { // PrimaryKeyZero check the primary key is blank or not func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryKeyField() + field := scope.PrimaryField() return field == nil || field.IsBlank } // PrimaryKeyValue get the primary key's value func (scope *Scope) PrimaryKeyValue() interface{} { - if field := scope.PrimaryKeyField(); field != nil && field.Field.IsValid() { + if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { return field.Field.Interface() } return 0 diff --git a/scope_private.go b/scope_private.go index 3745f0fc..e4262d64 100644 --- a/scope_private.go +++ b/scope_private.go @@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) { joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) joinTable := joinTableHandler.Table(scope.db, relationship) if !scope.Dialect().HasTable(scope, joinTable) { - primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255) + primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join([]string{