From c108cf022b531adaf85cad3b609e7580cbe212d7 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Wed, 22 Oct 2014 17:33:13 +0200 Subject: [PATCH] Add support for signed and unsigned integer types as primary key type --- query_test.go | 13 ++++++++++++ scope_private.go | 8 ++++---- scope_test.go | 1 - search.go | 51 +++++++++++++++++++++++++++++++++++++++++------- structs_test.go | 3 +-- 5 files changed, 62 insertions(+), 14 deletions(-) diff --git a/query_test.go b/query_test.go index b7ab1f5d..5867612f 100644 --- a/query_test.go +++ b/query_test.go @@ -50,6 +50,19 @@ func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { } } +func TestUIntPrimaryKey(t *testing.T) { + var animal Animal + DB.First(&animal, uint64(1)) + if animal.Counter != 1 { + t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + } + + DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) + if animal.Counter != 2 { + t.Errorf("Fetch a record from with a non-int primary key should work, but failed") + } +} + func TestFindAsSliceOfPointers(t *testing.T) { DB.Save(&User{Name: "user"}) diff --git a/scope_private.go b/scope_private.go index 85c5462a..fb8275b6 100644 --- a/scope_private.go +++ b/scope_private.go @@ -25,11 +25,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } else { str = value } - case int, int64, int32: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return scope.primaryCondiation(scope.AddToVars(value)) case sql.NullInt64: return scope.primaryCondiation(scope.AddToVars(value.Int64)) - case []int64, []int, []int32, []string: + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} case map[string]interface{}: @@ -84,9 +84,9 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) } - case int, int64, int32: + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value) - case []int64, []int, []int32, []string: + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: if reflect.ValueOf(value).Len() > 0 { str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} diff --git a/scope_test.go b/scope_test.go index cf2f626c..42458995 100644 --- a/scope_test.go +++ b/scope_test.go @@ -2,7 +2,6 @@ package gorm_test import ( "github.com/jinzhu/gorm" - "testing" ) diff --git a/search.go b/search.go index 44171365..cda297b6 100644 --- a/search.go +++ b/search.go @@ -125,17 +125,54 @@ func (s *search) table(name string) *search { } func (s *search) getInterfaceAsSql(value interface{}) (str string) { - switch value := value.(type) { + var s_num int64 + var u_num uint64 + var isString, unsigned bool = false, false + + switch value.(type) { case string: - str = value + str = value.(string) + isString = true case int: - if value < 0 { - str = "" - } else { - str = strconv.Itoa(value) - } + s_num = int64(value.(int)) + case int8: + s_num = int64(value.(int8)) + case int16: + s_num = int64(value.(int16)) + case int32: + s_num = int64(value.(int32)) + case int64: + s_num = int64(value.(int64)) + case uint: + u_num = uint64(value.(uint)) + unsigned = true + case uint8: + u_num = uint64(value.(uint8)) + unsigned = true + case uint16: + u_num = uint64(value.(uint16)) + unsigned = true + case uint32: + u_num = uint64(value.(uint32)) + unsigned = true + case uint64: + u_num = uint64(value.(uint64)) + unsigned = true default: s.db.err(InvalidSql) } + + if !isString { + if unsigned { + str = strconv.FormatUint(u_num, 10) + } else { + if s_num < 0 { + str = "" + } else { + str = strconv.FormatInt(s_num, 10) + } + } + } + return } diff --git a/structs_test.go b/structs_test.go index 023adcd0..d1877376 100644 --- a/structs_test.go +++ b/structs_test.go @@ -94,7 +94,6 @@ type Role struct { Name string } - func (role *Role) Scan(value interface{}) error { if b, ok := value.([]uint8); ok { role.Name = string(b) @@ -126,7 +125,7 @@ func (i *Num) Scan(src interface{}) error { } type Animal struct { - Counter int64 `gorm:"primary_key:yes"` + Counter uint64 `gorm:"primary_key:yes"` Name string From string //test reserved sql keyword as field name CreatedAt time.Time