Support Multi primary keys

This commit is contained in:
Jinzhu 2015-03-11 11:28:30 +08:00
parent 7b9272a15e
commit 49454839bd
8 changed files with 28 additions and 23 deletions

View File

@ -52,12 +52,12 @@ func (association *Association) getPrimaryKeys(values ...interface{}) []interfac
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
if reflectValue.Kind() == reflect.Slice { if reflectValue.Kind() == reflect.Slice {
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryKeyField(); !primaryField.IsBlank { if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface()) primaryKeys = append(primaryKeys, primaryField.Field.Interface())
} }
} }
} else if reflectValue.Kind() == reflect.Struct { } else if reflectValue.Kind() == reflect.Struct {
if primaryField := scope.New(value).PrimaryKeyField(); !primaryField.IsBlank { if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface()) primaryKeys = append(primaryKeys, primaryField.Field.Interface())
} }
} }
@ -81,7 +81,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
leftValues := reflect.Zero(association.Field.Field.Type()) leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ { for i := 0; i < association.Field.Field.Len(); i++ {
value := association.Field.Field.Index(i) value := association.Field.Field.Index(i)
if primaryField := association.Scope.New(value.Interface()).PrimaryKeyField(); primaryField != nil { if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil {
var included = false var included = false
for _, primaryKey := range primaryKeys { for _, primaryKey := range primaryKeys {
if equalAsString(primaryKey, primaryField.Field.Interface()) { if equalAsString(primaryKey, primaryField.Field.Interface()) {

View File

@ -34,7 +34,7 @@ func Create(scope *Scope) {
} }
returningKey := "*" returningKey := "*"
primaryField := scope.PrimaryKeyField() primaryField := scope.PrimaryField()
if primaryField != nil { if primaryField != nil {
returningKey = scope.Quote(primaryField.DBName) returningKey = scope.Quote(primaryField.DBName)
} }

View File

@ -36,7 +36,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
if DB.NewScope(&HNPost{}).PrimaryKeyField() == nil { if DB.NewScope(&HNPost{}).PrimaryField() == nil {
t.Errorf("primary key with embedded struct should works") t.Errorf("primary key with embedded struct should works")
} }

View File

@ -431,7 +431,7 @@ func (s *DB) Association(column string) *Association {
var err error var err error
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank { if primaryField := scope.PrimaryField(); primaryField.IsBlank {
err = errors.New("primary key can't be nil") err = errors.New("primary key can't be nil")
} else { } else {
if field, ok := scope.FieldByName(column); ok { if field, ok := scope.FieldByName(column); ok {

View File

@ -12,7 +12,7 @@ import (
) )
type ModelStruct struct { type ModelStruct struct {
PrimaryKeyField *StructField PrimaryFields []*StructField
StructFields []*StructField StructFields []*StructField
ModelType reflect.Type ModelType reflect.Type
TableName string TableName string
@ -131,7 +131,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
gormSettings := parseTagSetting(field.Tag.Get("gorm")) gormSettings := parseTagSetting(field.Tag.Get("gorm"))
if _, ok := gormSettings["PRIMARY_KEY"]; ok { if _, ok := gormSettings["PRIMARY_KEY"]; ok {
field.IsPrimaryKey = true field.IsPrimaryKey = true
modelStruct.PrimaryKeyField = field modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
} }
if _, ok := sqlSettings["DEFAULT"]; ok { if _, ok := sqlSettings["DEFAULT"]; ok {
@ -240,7 +240,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
toField.Names = append([]string{fieldStruct.Name}, toField.Names...) toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, toField) modelStruct.StructFields = append(modelStruct.StructFields, toField)
if toField.IsPrimaryKey { if toField.IsPrimaryKey {
modelStruct.PrimaryKeyField = toField modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
} }
} }
continue continue
@ -277,9 +277,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
if field.IsNormal { if field.IsNormal {
if modelStruct.PrimaryKeyField == nil && field.DBName == "id" { if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
field.IsPrimaryKey = true field.IsPrimaryKey = true
modelStruct.PrimaryKeyField = field modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
} }
if scope.db != nil { if scope.db != nil {

View File

@ -29,8 +29,8 @@ func Preload(scope *Scope) {
if field.Name == key && field.Relationship != nil { if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
relation := field.Relationship relation := field.Relationship
primaryName := scope.PrimaryKeyField().Name primaryName := scope.PrimaryField().Name
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name associationPrimaryKey := scope.New(results).PrimaryField().Name
switch relation.Kind { switch relation.Kind {
case "has_one": case "has_one":

View File

@ -109,16 +109,21 @@ func (scope *Scope) HasError() bool {
return scope.db.Error != nil return scope.db.Error != nil
} }
func (scope *Scope) PrimaryKeyField() *Field { func (scope *Scope) PrimaryField() *Field {
if field := scope.GetModelStruct().PrimaryKeyField; field != nil { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
return scope.Fields()[field.DBName] if len(primaryFields) > 1 {
if field, ok := scope.Fields()["id"]; ok {
return field
}
}
return scope.Fields()[primaryFields[0].DBName]
} }
return nil return nil
} }
// PrimaryKey get the primary key's column name // PrimaryKey get the primary key's column name
func (scope *Scope) PrimaryKey() string { func (scope *Scope) PrimaryKey() string {
if field := scope.PrimaryKeyField(); field != nil { if field := scope.PrimaryField(); field != nil {
return field.DBName return field.DBName
} }
return "" return ""
@ -126,13 +131,13 @@ func (scope *Scope) PrimaryKey() string {
// PrimaryKeyZero check the primary key is blank or not // PrimaryKeyZero check the primary key is blank or not
func (scope *Scope) PrimaryKeyZero() bool { func (scope *Scope) PrimaryKeyZero() bool {
field := scope.PrimaryKeyField() field := scope.PrimaryField()
return field == nil || field.IsBlank return field == nil || field.IsBlank
} }
// PrimaryKeyValue get the primary key's value // PrimaryKeyValue get the primary key's value
func (scope *Scope) PrimaryKeyValue() interface{} { func (scope *Scope) PrimaryKeyValue() interface{} {
if field := scope.PrimaryKeyField(); field != nil && field.Field.IsValid() { if field := scope.PrimaryField(); field != nil && field.Field.IsValid() {
return field.Field.Interface() return field.Field.Interface()
} }
return 0 return 0

View File

@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
joinTable := joinTableHandler.Table(scope.db, relationship) joinTable := joinTableHandler.Table(scope.db, relationship)
if !scope.Dialect().HasTable(scope, joinTable) { if !scope.Dialect().HasTable(scope, joinTable) {
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255) primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
scope.Quote(joinTable), scope.Quote(joinTable),
strings.Join([]string{ strings.Join([]string{