forked from mirror/gorm
Support custom types from base types
This commit is contained in:
parent
4969fc9cb5
commit
e6c953dd4c
|
@ -1,10 +1,17 @@
|
||||||
package dialect
|
package dialect
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var timeType = reflect.TypeOf(time.Time{})
|
||||||
|
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
BinVar(i int) string
|
BinVar(i int) string
|
||||||
SupportLastInsertId() bool
|
SupportLastInsertId() bool
|
||||||
SqlTag(column interface{}, size int) string
|
SqlTag(value reflect.Value, size int) string
|
||||||
PrimaryKeyTag(column interface{}, size int) string
|
PrimaryKeyTag(value reflect.Value, size int) string
|
||||||
ReturningStr(key string) string
|
ReturningStr(key string) string
|
||||||
Quote(key string) string
|
Quote(key string) string
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@ package dialect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mysql struct{}
|
type mysql struct{}
|
||||||
|
@ -15,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *mysql) SqlTag(column interface{}, size int) string {
|
func (d *mysql) SqlTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case time.Time:
|
case reflect.Bool:
|
||||||
return "datetime"
|
|
||||||
case bool:
|
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "int"
|
return "int"
|
||||||
case int64, uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "double"
|
return "double"
|
||||||
case []byte:
|
case reflect.String:
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varbinary(%d)", size)
|
|
||||||
} else {
|
|
||||||
return "longblob"
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
} else {
|
} else {
|
||||||
return "longtext"
|
return "longtext"
|
||||||
}
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if value.Type() == timeType {
|
||||||
|
return "datetime"
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
panic("Invalid sql type for mysql")
|
if _, ok := value.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
return fmt.Sprintf("varbinary(%d)", size)
|
||||||
|
} else {
|
||||||
|
return "longblob"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
|
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||||
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "int" + suffix_str
|
return "int" + suffix_str
|
||||||
case int64, uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint" + suffix_str
|
return "bigint" + suffix_str
|
||||||
default:
|
default:
|
||||||
panic("Invalid primary key type")
|
panic("Invalid primary key type")
|
||||||
|
|
|
@ -2,7 +2,7 @@ package dialect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
|
@ -16,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *postgres) SqlTag(column interface{}, size int) string {
|
func (d *postgres) SqlTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case time.Time:
|
case reflect.Bool:
|
||||||
return "timestamp with time zone"
|
|
||||||
case bool:
|
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "integer"
|
return "integer"
|
||||||
case int64, uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "numeric"
|
return "numeric"
|
||||||
case []byte:
|
case reflect.String:
|
||||||
return "bytea"
|
|
||||||
case string:
|
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
} else {
|
}
|
||||||
return "text"
|
return "text"
|
||||||
|
case reflect.Struct:
|
||||||
|
if value.Type() == timeType {
|
||||||
|
return "timestamp with time zone"
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
panic("Invalid sql type for postgres")
|
if _, ok := value.Interface().([]byte); ok {
|
||||||
|
return "bytea"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
|
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "serial PRIMARY KEY"
|
return "serial PRIMARY KEY"
|
||||||
case int64, uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigserial PRIMARY KEY"
|
return "bigserial PRIMARY KEY"
|
||||||
default:
|
default:
|
||||||
panic("Invalid primary key type")
|
panic("Invalid primary key type")
|
||||||
|
|
|
@ -2,7 +2,7 @@ package dialect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sqlite3 struct{}
|
type sqlite3 struct{}
|
||||||
|
@ -15,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) SqlTag(column interface{}, size int) string {
|
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
||||||
switch column.(type) {
|
switch value.Kind() {
|
||||||
case time.Time:
|
case reflect.Bool:
|
||||||
return "datetime"
|
|
||||||
case bool:
|
|
||||||
return "bool"
|
return "bool"
|
||||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "integer"
|
return "integer"
|
||||||
case int64, uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case float32, float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "real"
|
return "real"
|
||||||
case []byte:
|
case reflect.String:
|
||||||
return "blob"
|
|
||||||
case string:
|
|
||||||
if size > 0 && size < 65532 {
|
if size > 0 && size < 65532 {
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
} else {
|
} else {
|
||||||
return "text"
|
return "text"
|
||||||
}
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if value.Type() == timeType {
|
||||||
|
return "datetime"
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
panic("Invalid sql type for sqlite3")
|
if _, ok := value.Interface().([]byte); ok {
|
||||||
|
return "blob"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
|
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||||
return "INTEGER PRIMARY KEY"
|
switch value.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
|
||||||
|
return "INTEGER PRIMARY KEY"
|
||||||
|
default:
|
||||||
|
panic("Invalid primary key type")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) ReturningStr(key string) (str string) {
|
func (s *sqlite3) ReturningStr(key string) (str string) {
|
||||||
|
|
30
main_test.go
30
main_test.go
|
@ -22,6 +22,17 @@ type IgnoredEmbedStruct struct {
|
||||||
Name string
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Num int64
|
||||||
|
|
||||||
|
func (i *Num) Scan(src interface{}) error {
|
||||||
|
v := reflect.ValueOf(src)
|
||||||
|
if v.Kind() != reflect.Int64 {
|
||||||
|
return errors.New("Cannot scan NamedInt from " + v.String())
|
||||||
|
}
|
||||||
|
*i = Num(v.Int())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id int64 // Id: Primary key
|
Id int64 // Id: Primary key
|
||||||
Age int64
|
Age int64
|
||||||
|
@ -42,6 +53,7 @@ type User struct {
|
||||||
PasswordHash []byte
|
PasswordHash []byte
|
||||||
IgnoreMe int64 `sql:"-"`
|
IgnoreMe int64 `sql:"-"`
|
||||||
IgnoreStringSlice []string `sql:"-"`
|
IgnoreStringSlice []string `sql:"-"`
|
||||||
|
UserNum Num
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreditCard struct {
|
type CreditCard struct {
|
||||||
|
@ -156,7 +168,7 @@ func init() {
|
||||||
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
||||||
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
||||||
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
||||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
|
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)})
|
||||||
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
||||||
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
||||||
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
||||||
|
@ -181,6 +193,22 @@ func TestFirstAndLast(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveCustomType(t *testing.T) {
|
||||||
|
var user, user1 User
|
||||||
|
db.First(&user, "name = ?", "1")
|
||||||
|
if user.UserNum != Num(111) {
|
||||||
|
t.Errorf("UserNum should be saved correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
user.UserNum = Num(222)
|
||||||
|
db.Save(&user)
|
||||||
|
|
||||||
|
db.First(&user1, "name = ?", "1")
|
||||||
|
if user1.UserNum != Num(222) {
|
||||||
|
t.Errorf("UserNum should be updated correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPrecision(t *testing.T) {
|
func TestPrecision(t *testing.T) {
|
||||||
f := 35.03554004971999
|
f := 35.03554004971999
|
||||||
user := User{Name: "Precision", Latitude: f}
|
user := User{Name: "Precision", Latitude: f}
|
||||||
|
|
|
@ -309,26 +309,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
|
||||||
value := field.Value
|
value := field.Value
|
||||||
reflectValue := reflect.ValueOf(value)
|
reflectValue := reflect.ValueOf(value)
|
||||||
|
|
||||||
if field.IsScanner() {
|
|
||||||
value = reflectValue.Field(0).Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if _, ok := value.([]byte); !ok {
|
if _, ok := value.([]byte); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if !field.IsTime() && !field.IsScanner() {
|
if field.IsScanner() {
|
||||||
|
reflectValue = reflectValue.Field(0)
|
||||||
|
} else if !field.IsTime() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tag) == 0 {
|
if len(tag) == 0 {
|
||||||
if field.isPrimaryKey {
|
if field.isPrimaryKey {
|
||||||
tag = scope.Dialect().PrimaryKeyTag(value, size)
|
tag = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||||
} else {
|
} else {
|
||||||
tag = scope.Dialect().SqlTag(value, size)
|
tag = scope.Dialect().SqlTag(reflectValue, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue