mirror of https://github.com/go-gorm/gorm.git
Support Multi primary keys
This commit is contained in:
parent
7b9272a15e
commit
49454839bd
|
@ -52,12 +52,12 @@ func (association *Association) getPrimaryKeys(values ...interface{}) []interfac
|
|||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryKeyField(); !primaryField.IsBlank {
|
||||
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank {
|
||||
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
||||
}
|
||||
}
|
||||
} else if reflectValue.Kind() == reflect.Struct {
|
||||
if primaryField := scope.New(value).PrimaryKeyField(); !primaryField.IsBlank {
|
||||
if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank {
|
||||
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||
value := association.Field.Field.Index(i)
|
||||
if primaryField := association.Scope.New(value.Interface()).PrimaryKeyField(); primaryField != nil {
|
||||
if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil {
|
||||
var included = false
|
||||
for _, primaryKey := range primaryKeys {
|
||||
if equalAsString(primaryKey, primaryField.Field.Interface()) {
|
||||
|
|
|
@ -34,7 +34,7 @@ func Create(scope *Scope) {
|
|||
}
|
||||
|
||||
returningKey := "*"
|
||||
primaryField := scope.PrimaryKeyField()
|
||||
primaryField := scope.PrimaryField()
|
||||
if primaryField != nil {
|
||||
returningKey = scope.Quote(primaryField.DBName)
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
|||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
if DB.NewScope(&HNPost{}).PrimaryKeyField() == nil {
|
||||
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
|
||||
t.Errorf("primary key with embedded struct should works")
|
||||
}
|
||||
|
||||
|
|
2
main.go
2
main.go
|
@ -431,7 +431,7 @@ func (s *DB) Association(column string) *Association {
|
|||
var err error
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
|
||||
if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank {
|
||||
if primaryField := scope.PrimaryField(); primaryField.IsBlank {
|
||||
err = errors.New("primary key can't be nil")
|
||||
} else {
|
||||
if field, ok := scope.FieldByName(column); ok {
|
||||
|
|
|
@ -12,10 +12,10 @@ import (
|
|||
)
|
||||
|
||||
type ModelStruct struct {
|
||||
PrimaryKeyField *StructField
|
||||
StructFields []*StructField
|
||||
ModelType reflect.Type
|
||||
TableName string
|
||||
PrimaryFields []*StructField
|
||||
StructFields []*StructField
|
||||
ModelType reflect.Type
|
||||
TableName string
|
||||
}
|
||||
|
||||
type StructField struct {
|
||||
|
@ -131,7 +131,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
||||
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
|
||||
field.IsPrimaryKey = true
|
||||
modelStruct.PrimaryKeyField = field
|
||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||
}
|
||||
|
||||
if _, ok := sqlSettings["DEFAULT"]; ok {
|
||||
|
@ -240,7 +240,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
|
||||
modelStruct.StructFields = append(modelStruct.StructFields, toField)
|
||||
if toField.IsPrimaryKey {
|
||||
modelStruct.PrimaryKeyField = toField
|
||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
|
||||
}
|
||||
}
|
||||
continue
|
||||
|
@ -277,9 +277,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
}
|
||||
|
||||
if field.IsNormal {
|
||||
if modelStruct.PrimaryKeyField == nil && field.DBName == "id" {
|
||||
if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
|
||||
field.IsPrimaryKey = true
|
||||
modelStruct.PrimaryKeyField = field
|
||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||
}
|
||||
|
||||
if scope.db != nil {
|
||||
|
|
|
@ -29,8 +29,8 @@ func Preload(scope *Scope) {
|
|||
if field.Name == key && field.Relationship != nil {
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
primaryName := scope.PrimaryKeyField().Name
|
||||
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
|
||||
primaryName := scope.PrimaryField().Name
|
||||
associationPrimaryKey := scope.New(results).PrimaryField().Name
|
||||
|
||||
switch relation.Kind {
|
||||
case "has_one":
|
||||
|
|
17
scope.go
17
scope.go
|
@ -109,16 +109,21 @@ func (scope *Scope) HasError() bool {
|
|||
return scope.db.Error != nil
|
||||
}
|
||||
|
||||
func (scope *Scope) PrimaryKeyField() *Field {
|
||||
if field := scope.GetModelStruct().PrimaryKeyField; field != nil {
|
||||
return scope.Fields()[field.DBName]
|
||||
func (scope *Scope) PrimaryField() *Field {
|
||||
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
||||
if len(primaryFields) > 1 {
|
||||
if field, ok := scope.Fields()["id"]; ok {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return scope.Fields()[primaryFields[0].DBName]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrimaryKey get the primary key's column name
|
||||
func (scope *Scope) PrimaryKey() string {
|
||||
if field := scope.PrimaryKeyField(); field != nil {
|
||||
if field := scope.PrimaryField(); field != nil {
|
||||
return field.DBName
|
||||
}
|
||||
return ""
|
||||
|
@ -126,13 +131,13 @@ func (scope *Scope) PrimaryKey() string {
|
|||
|
||||
// PrimaryKeyZero check the primary key is blank or not
|
||||
func (scope *Scope) PrimaryKeyZero() bool {
|
||||
field := scope.PrimaryKeyField()
|
||||
field := scope.PrimaryField()
|
||||
return field == nil || field.IsBlank
|
||||
}
|
||||
|
||||
// PrimaryKeyValue get the primary key's value
|
||||
func (scope *Scope) PrimaryKeyValue() interface{} {
|
||||
if field := scope.PrimaryKeyField(); field != nil && field.Field.IsValid() {
|
||||
if field := scope.PrimaryField(); field != nil && field.Field.IsValid() {
|
||||
return field.Field.Interface()
|
||||
}
|
||||
return 0
|
||||
|
|
|
@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
||||
joinTable := joinTableHandler.Table(scope.db, relationship)
|
||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255)
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||
scope.Quote(joinTable),
|
||||
strings.Join([]string{
|
||||
|
|
Loading…
Reference in New Issue