From c108cf022b531adaf85cad3b609e7580cbe212d7 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Wed, 22 Oct 2014 17:33:13 +0200 Subject: [PATCH 1/3] Add support for signed and unsigned integer types as primary key type --- query_test.go | 13 ++++++++++++ scope_private.go | 8 ++++---- scope_test.go | 1 - search.go | 51 +++++++++++++++++++++++++++++++++++++++++------- structs_test.go | 3 +-- 5 files changed, 62 insertions(+), 14 deletions(-) diff --git a/query_test.go b/query_test.go index b7ab1f5d..5867612f 100644 --- a/query_test.go +++ b/query_test.go @@ -50,6 +50,19 @@ func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { } } +func TestUIntPrimaryKey(t *testing.T) { + var animal Animal + DB.First(&animal, uint64(1)) + if animal.Counter != 1 { + t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + } + + DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) + if animal.Counter != 2 { + t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + } +} + func TestFindAsSliceOfPointers(t *testing.T) { DB.Save(&User{Name: "user"}) diff --git a/scope_private.go b/scope_private.go index 85c5462a..fb8275b6 100644 --- a/scope_private.go +++ b/scope_private.go @@ -25,11 +25,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } else { str = value } - case int, int64, int32: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return scope.primaryCondiation(scope.AddToVars(value)) case sql.NullInt64: return scope.primaryCondiation(scope.AddToVars(value.Int64)) - case []int64, []int, []int32, []string: + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} case map[string]interface{}: @@ -84,9 +84,9 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) } - case int, int64, int32: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value) - case []int64, []int, []int32, []string: + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: if reflect.ValueOf(value).Len() > 0 { str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} diff --git a/scope_test.go b/scope_test.go index cf2f626c..42458995 100644 --- a/scope_test.go +++ b/scope_test.go @@ -2,7 +2,6 @@ package gorm_test import ( "github.com/jinzhu/gorm" - "testing" ) diff --git a/search.go b/search.go index 44171365..cda297b6 100644 --- a/search.go +++ b/search.go @@ -125,17 +125,54 @@ func (s *search) table(name string) *search { } func (s *search) getInterfaceAsSql(value interface{}) (str string) { - switch value := value.(type) { + var s_num int64 + var u_num uint64 + var isString, unsigned bool = false, false + + switch value.(type) { case string: - str = value + str = value.(string) + isString = true case int: - if value < 0 { - str = "" - } else { - str = strconv.Itoa(value) - } + s_num = int64(value.(int)) + case int8: + s_num = int64(value.(int8)) + case int16: + s_num = int64(value.(int16)) + case int32: + s_num = int64(value.(int32)) + case int64: + s_num = int64(value.(int64)) + case uint: + u_num = uint64(value.(uint)) + unsigned = true + case uint8: + u_num = uint64(value.(uint8)) + unsigned = true + case uint16: + u_num = uint64(value.(uint16)) + unsigned = true + case uint32: + u_num = uint64(value.(uint32)) + unsigned = true + case uint64: + u_num = uint64(value.(uint64)) + unsigned = true default: s.db.err(InvalidSql) } + + if !isString { + if unsigned { + str = strconv.FormatUint(u_num, 10) + } else { + if s_num < 0 { + str = "" + } else { + str = strconv.FormatInt(s_num, 10) + } + } + } + return } diff --git a/structs_test.go b/structs_test.go index 023adcd0..d1877376 100644 --- a/structs_test.go +++ b/structs_test.go @@ -94,7 +94,6 @@ type Role struct { Name string } - func (role *Role) Scan(value interface{}) error { if b, ok := value.([]uint8); ok { role.Name = string(b) @@ -126,7 +125,7 @@ func (i *Num) Scan(src interface{}) error { } type Animal struct { - Counter int64 `gorm:"primary_key:yes"` + Counter uint64 `gorm:"primary_key:yes"` Name string From string //test reserved sql keyword as field name CreatedAt time.Time From aab0b284575f1f6d4ae082eff682b2c0d3f5e02f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Nov 2014 11:46:21 +0800 Subject: [PATCH 2/3] Fix can't update record with customized primary key --- customize_column_test.go | 25 ++++++++++----- scope.go | 67 +++++++++++++++++++++------------------- 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/customize_column_test.go b/customize_column_test.go index c7f72f97..aec3f39e 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -6,7 +6,7 @@ import ( ) type CustomizeColumn struct { - Id int64 `gorm:"column:mapped_id; primary_key:yes"` + ID int64 `gorm:"column:mapped_id; primary_key:yes"` Name string `gorm:"column:mapped_name"` Date time.Time `gorm:"column:mapped_time"` } @@ -14,8 +14,8 @@ type CustomizeColumn struct { // Make sure an ignored field does not interfere with another field's custom // column name that matches the ignored field. type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` + Body string `sql:"-"` + RawBody string `gorm:"column:body"` } func TestCustomizeColumn(t *testing.T) { @@ -34,16 +34,25 @@ func TestCustomizeColumn(t *testing.T) { } expected := "foo" - cc := CustomizeColumn{Id: 666, Name: expected, Date: time.Now()} + cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} - if count := DB.Save(&cc).RowsAffected; count != 1 { + if count := DB.Create(&cc).RowsAffected; count != 1 { t.Error("There should be one record be affected when create record") } - var ccs []CustomizeColumn - DB.Find(&ccs) + var cc1 CustomizeColumn + DB.First(&cc1, 666) - if len(ccs) > 0 && ccs[0].Name != expected && ccs[0].Id != 666 { + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, 666) + if cc2.Name != "bar" { t.Errorf("Failed to query CustomizeColumn") } } diff --git a/scope.go b/scope.go index 15413e1b..87405a6d 100644 --- a/scope.go +++ b/scope.go @@ -12,16 +12,16 @@ import ( ) type Scope struct { - Value interface{} - indirectValue *reflect.Value - Search *search - Sql string - SqlVars []interface{} - db *DB - skipLeft bool - primaryKey string - instanceId string - fields map[string]*Field + Value interface{} + indirectValue *reflect.Value + Search *search + Sql string + SqlVars []interface{} + db *DB + skipLeft bool + primaryKeyField *Field + instanceId string + fields map[string]*Field } func (scope *Scope) IndirectValue() reflect.Value { @@ -90,27 +90,33 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } -// PrimaryKey get the primary key's column name -func (scope *Scope) PrimaryKey() string { - if scope.primaryKey != "" { - return scope.primaryKey - } +func (scope *Scope) PrimaryKeyField() *Field { + if scope.primaryKeyField == nil { + var indirectValue = scope.IndirectValue() - var indirectValue = scope.IndirectValue() + clone := scope + if indirectValue.Kind() == reflect.Slice { + clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface()) + } - clone := scope - if indirectValue.Kind() == reflect.Slice { - clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface()) - } - - for _, field := range clone.Fields() { - if field.IsPrimaryKey { - scope.primaryKey = field.DBName - break + for _, field := range clone.Fields() { + if field.IsPrimaryKey { + scope.primaryKeyField = field + break + } } } - return scope.primaryKey + return scope.primaryKeyField +} + +// PrimaryKey get the primary key's column name +func (scope *Scope) PrimaryKey() string { + if field := scope.PrimaryKeyField(); field != nil { + return field.DBName + } else { + return "" + } } // PrimaryKeyZero check the primary key is blank or not @@ -120,12 +126,11 @@ func (scope *Scope) PrimaryKeyZero() bool { // PrimaryKeyValue get the primary key's value func (scope *Scope) PrimaryKeyValue() interface{} { - if scope.IndirectValue().Kind() == reflect.Struct { - if field := scope.IndirectValue().FieldByName(SnakeToUpperCamel(scope.PrimaryKey())); field.IsValid() { - return field.Interface() - } + if field := scope.PrimaryKeyField(); field != nil { + return field.Field.Interface() + } else { + return 0 } - return 0 } // HasColumn to check if has column From dcdbb725701050ef1d1a6106be365ed62f7f3c75 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Nov 2014 13:36:12 +0800 Subject: [PATCH 3/3] Refactor getInterfaceAsSql --- search.go | 49 +++++-------------------------------------------- 1 file changed, 5 insertions(+), 44 deletions(-) diff --git a/search.go b/search.go index cda297b6..8591d659 100644 --- a/search.go +++ b/search.go @@ -1,6 +1,6 @@ package gorm -import "strconv" +import "fmt" type search struct { db *DB @@ -125,54 +125,15 @@ func (s *search) table(name string) *search { } func (s *search) getInterfaceAsSql(value interface{}) (str string) { - var s_num int64 - var u_num uint64 - var isString, unsigned bool = false, false - switch value.(type) { - case string: - str = value.(string) - isString = true - case int: - s_num = int64(value.(int)) - case int8: - s_num = int64(value.(int8)) - case int16: - s_num = int64(value.(int16)) - case int32: - s_num = int64(value.(int32)) - case int64: - s_num = int64(value.(int64)) - case uint: - u_num = uint64(value.(uint)) - unsigned = true - case uint8: - u_num = uint64(value.(uint8)) - unsigned = true - case uint16: - u_num = uint64(value.(uint16)) - unsigned = true - case uint32: - u_num = uint64(value.(uint32)) - unsigned = true - case uint64: - u_num = uint64(value.(uint64)) - unsigned = true + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + str = fmt.Sprintf("%v", value) default: s.db.err(InvalidSql) } - if !isString { - if unsigned { - str = strconv.FormatUint(u_num, 10) - } else { - if s_num < 0 { - str = "" - } else { - str = strconv.FormatInt(s_num, 10) - } - } + if str == "-1" { + return "" } - return }