From 828bf5d5fd9f24a045b6e5e9ed0429cf3d81f387 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Nov 2014 11:46:21 +0800 Subject: [PATCH] Fix can't update record with customized primary key --- customize_column_test.go | 25 ++++++++++----- scope.go | 67 +++++++++++++++++++++------------------- 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/customize_column_test.go b/customize_column_test.go index c7f72f97..aec3f39e 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -6,7 +6,7 @@ import ( ) type CustomizeColumn struct { - Id int64 `gorm:"column:mapped_id; primary_key:yes"` + ID int64 `gorm:"column:mapped_id; primary_key:yes"` Name string `gorm:"column:mapped_name"` Date time.Time `gorm:"column:mapped_time"` } @@ -14,8 +14,8 @@ type CustomizeColumn struct { // Make sure an ignored field does not interfere with another field's custom // column name that matches the ignored field. type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` + Body string `sql:"-"` + RawBody string `gorm:"column:body"` } func TestCustomizeColumn(t *testing.T) { @@ -34,16 +34,25 @@ func TestCustomizeColumn(t *testing.T) { } expected := "foo" - cc := CustomizeColumn{Id: 666, Name: expected, Date: time.Now()} + cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} - if count := DB.Save(&cc).RowsAffected; count != 1 { + if count := DB.Create(&cc).RowsAffected; count != 1 { t.Error("There should be one record be affected when create record") } - var ccs []CustomizeColumn - DB.Find(&ccs) + var cc1 CustomizeColumn + DB.First(&cc1, 666) - if len(ccs) > 0 && ccs[0].Name != expected && ccs[0].Id != 666 { + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, 666) + if cc2.Name != "bar" { t.Errorf("Failed to query CustomizeColumn") } } diff --git a/scope.go b/scope.go index 15413e1b..87405a6d 100644 --- a/scope.go +++ b/scope.go @@ -12,16 +12,16 @@ import ( ) type Scope struct { - Value interface{} - indirectValue *reflect.Value - Search *search - Sql string - SqlVars []interface{} - db *DB - skipLeft bool - primaryKey string - instanceId string - fields map[string]*Field + Value interface{} + indirectValue *reflect.Value + Search *search + Sql string + SqlVars []interface{} + db *DB + skipLeft bool + primaryKeyField *Field + instanceId string + fields map[string]*Field } func (scope *Scope) IndirectValue() reflect.Value { @@ -90,27 +90,33 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } -// PrimaryKey get the primary key's column name -func (scope *Scope) PrimaryKey() string { - if scope.primaryKey != "" { - return scope.primaryKey - } +func (scope *Scope) PrimaryKeyField() *Field { + if scope.primaryKeyField == nil { + var indirectValue = scope.IndirectValue() - var indirectValue = scope.IndirectValue() + clone := scope + if indirectValue.Kind() == reflect.Slice { + clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface()) + } - clone := scope - if indirectValue.Kind() == reflect.Slice { - clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface()) - } - - for _, field := range clone.Fields() { - if field.IsPrimaryKey { - scope.primaryKey = field.DBName - break + for _, field := range clone.Fields() { + if field.IsPrimaryKey { + scope.primaryKeyField = field + break + } } } - return scope.primaryKey + return scope.primaryKeyField +} + +// PrimaryKey get the primary key's column name +func (scope *Scope) PrimaryKey() string { + if field := scope.PrimaryKeyField(); field != nil { + return field.DBName + } else { + return "" + } } // PrimaryKeyZero check the primary key is blank or not @@ -120,12 +126,11 @@ func (scope *Scope) PrimaryKeyZero() bool { // PrimaryKeyValue get the primary key's value func (scope *Scope) PrimaryKeyValue() interface{} { - if scope.IndirectValue().Kind() == reflect.Struct { - if field := scope.IndirectValue().FieldByName(SnakeToUpperCamel(scope.PrimaryKey())); field.IsValid() { - return field.Interface() - } + if field := scope.PrimaryKeyField(); field != nil { + return field.Field.Interface() + } else { + return 0 } - return 0 } // HasColumn to check if has column