mirror of https://github.com/go-gorm/gorm.git
Support Multi primary keys
This commit is contained in:
parent
7b9272a15e
commit
49454839bd
|
@ -52,12 +52,12 @@ func (association *Association) getPrimaryKeys(values ...interface{}) []interfac
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
if reflectValue.Kind() == reflect.Slice {
|
if reflectValue.Kind() == reflect.Slice {
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryKeyField(); !primaryField.IsBlank {
|
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank {
|
||||||
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if reflectValue.Kind() == reflect.Struct {
|
} else if reflectValue.Kind() == reflect.Struct {
|
||||||
if primaryField := scope.New(value).PrimaryKeyField(); !primaryField.IsBlank {
|
if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank {
|
||||||
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||||
value := association.Field.Field.Index(i)
|
value := association.Field.Field.Index(i)
|
||||||
if primaryField := association.Scope.New(value.Interface()).PrimaryKeyField(); primaryField != nil {
|
if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil {
|
||||||
var included = false
|
var included = false
|
||||||
for _, primaryKey := range primaryKeys {
|
for _, primaryKey := range primaryKeys {
|
||||||
if equalAsString(primaryKey, primaryField.Field.Interface()) {
|
if equalAsString(primaryKey, primaryField.Field.Interface()) {
|
||||||
|
|
|
@ -34,7 +34,7 @@ func Create(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
returningKey := "*"
|
returningKey := "*"
|
||||||
primaryField := scope.PrimaryKeyField()
|
primaryField := scope.PrimaryField()
|
||||||
if primaryField != nil {
|
if primaryField != nil {
|
||||||
returningKey = scope.Quote(primaryField.DBName)
|
returningKey = scope.Quote(primaryField.DBName)
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||||
t.Errorf("embedded struct's value should be scanned correctly")
|
t.Errorf("embedded struct's value should be scanned correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.NewScope(&HNPost{}).PrimaryKeyField() == nil {
|
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
|
||||||
t.Errorf("primary key with embedded struct should works")
|
t.Errorf("primary key with embedded struct should works")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
main.go
2
main.go
|
@ -431,7 +431,7 @@ func (s *DB) Association(column string) *Association {
|
||||||
var err error
|
var err error
|
||||||
scope := s.clone().NewScope(s.Value)
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
|
||||||
if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank {
|
if primaryField := scope.PrimaryField(); primaryField.IsBlank {
|
||||||
err = errors.New("primary key can't be nil")
|
err = errors.New("primary key can't be nil")
|
||||||
} else {
|
} else {
|
||||||
if field, ok := scope.FieldByName(column); ok {
|
if field, ok := scope.FieldByName(column); ok {
|
||||||
|
|
|
@ -12,10 +12,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelStruct struct {
|
type ModelStruct struct {
|
||||||
PrimaryKeyField *StructField
|
PrimaryFields []*StructField
|
||||||
StructFields []*StructField
|
StructFields []*StructField
|
||||||
ModelType reflect.Type
|
ModelType reflect.Type
|
||||||
TableName string
|
TableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type StructField struct {
|
type StructField struct {
|
||||||
|
@ -131,7 +131,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
||||||
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
|
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
modelStruct.PrimaryKeyField = field
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := sqlSettings["DEFAULT"]; ok {
|
if _, ok := sqlSettings["DEFAULT"]; ok {
|
||||||
|
@ -240,7 +240,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
|
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, toField)
|
modelStruct.StructFields = append(modelStruct.StructFields, toField)
|
||||||
if toField.IsPrimaryKey {
|
if toField.IsPrimaryKey {
|
||||||
modelStruct.PrimaryKeyField = toField
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
@ -277,9 +277,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
}
|
}
|
||||||
|
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
if modelStruct.PrimaryKeyField == nil && field.DBName == "id" {
|
if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
modelStruct.PrimaryKeyField = field
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.db != nil {
|
if scope.db != nil {
|
||||||
|
|
|
@ -29,8 +29,8 @@ func Preload(scope *Scope) {
|
||||||
if field.Name == key && field.Relationship != nil {
|
if field.Name == key && field.Relationship != nil {
|
||||||
results := makeSlice(field.Struct.Type)
|
results := makeSlice(field.Struct.Type)
|
||||||
relation := field.Relationship
|
relation := field.Relationship
|
||||||
primaryName := scope.PrimaryKeyField().Name
|
primaryName := scope.PrimaryField().Name
|
||||||
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
|
associationPrimaryKey := scope.New(results).PrimaryField().Name
|
||||||
|
|
||||||
switch relation.Kind {
|
switch relation.Kind {
|
||||||
case "has_one":
|
case "has_one":
|
||||||
|
|
17
scope.go
17
scope.go
|
@ -109,16 +109,21 @@ func (scope *Scope) HasError() bool {
|
||||||
return scope.db.Error != nil
|
return scope.db.Error != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) PrimaryKeyField() *Field {
|
func (scope *Scope) PrimaryField() *Field {
|
||||||
if field := scope.GetModelStruct().PrimaryKeyField; field != nil {
|
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
||||||
return scope.Fields()[field.DBName]
|
if len(primaryFields) > 1 {
|
||||||
|
if field, ok := scope.Fields()["id"]; ok {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scope.Fields()[primaryFields[0].DBName]
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrimaryKey get the primary key's column name
|
// PrimaryKey get the primary key's column name
|
||||||
func (scope *Scope) PrimaryKey() string {
|
func (scope *Scope) PrimaryKey() string {
|
||||||
if field := scope.PrimaryKeyField(); field != nil {
|
if field := scope.PrimaryField(); field != nil {
|
||||||
return field.DBName
|
return field.DBName
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
@ -126,13 +131,13 @@ func (scope *Scope) PrimaryKey() string {
|
||||||
|
|
||||||
// PrimaryKeyZero check the primary key is blank or not
|
// PrimaryKeyZero check the primary key is blank or not
|
||||||
func (scope *Scope) PrimaryKeyZero() bool {
|
func (scope *Scope) PrimaryKeyZero() bool {
|
||||||
field := scope.PrimaryKeyField()
|
field := scope.PrimaryField()
|
||||||
return field == nil || field.IsBlank
|
return field == nil || field.IsBlank
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrimaryKeyValue get the primary key's value
|
// PrimaryKeyValue get the primary key's value
|
||||||
func (scope *Scope) PrimaryKeyValue() interface{} {
|
func (scope *Scope) PrimaryKeyValue() interface{} {
|
||||||
if field := scope.PrimaryKeyField(); field != nil && field.Field.IsValid() {
|
if field := scope.PrimaryField(); field != nil && field.Field.IsValid() {
|
||||||
return field.Field.Interface()
|
return field.Field.Interface()
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|
|
@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
||||||
joinTable := joinTableHandler.Table(scope.db, relationship)
|
joinTable := joinTableHandler.Table(scope.db, relationship)
|
||||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
|
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255)
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||||
scope.Quote(joinTable),
|
scope.Quote(joinTable),
|
||||||
strings.Join([]string{
|
strings.Join([]string{
|
||||||
|
|
Loading…
Reference in New Issue