forked from mirror/gorm
Support custom types from base types
This commit is contained in:
parent
4969fc9cb5
commit
e6c953dd4c
|
@ -1,10 +1,17 @@
|
|||
package dialect
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
var timeType = reflect.TypeOf(time.Time{})
|
||||
|
||||
type Dialect interface {
|
||||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
SqlTag(column interface{}, size int) string
|
||||
PrimaryKeyTag(column interface{}, size int) string
|
||||
SqlTag(value reflect.Value, size int) string
|
||||
PrimaryKeyTag(value reflect.Value, size int) string
|
||||
ReturningStr(key string) string
|
||||
Quote(key string) string
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ package dialect
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type mysql struct{}
|
||||
|
@ -15,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (d *mysql) SqlTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "datetime"
|
||||
case bool:
|
||||
func (d *mysql) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int"
|
||||
case int64, uint64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint"
|
||||
case float32, float64:
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "double"
|
||||
case []byte:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
return "longblob"
|
||||
}
|
||||
case string:
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "longtext"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
panic("Invalid sql type for mysql")
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
return "longblob"
|
||||
}
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
|
||||
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||
switch column.(type) {
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int" + suffix_str
|
||||
case int64, uint64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint" + suffix_str
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
|
|
|
@ -2,7 +2,7 @@ package dialect
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
|
@ -16,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (d *postgres) SqlTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "timestamp with time zone"
|
||||
case bool:
|
||||
func (d *postgres) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "integer"
|
||||
case int64, uint64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint"
|
||||
case float32, float64:
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "numeric"
|
||||
case []byte:
|
||||
return "bytea"
|
||||
case string:
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
default:
|
||||
panic("Invalid sql type for postgres")
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "bytea"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "serial PRIMARY KEY"
|
||||
case int64, uint64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigserial PRIMARY KEY"
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
|
|
|
@ -2,7 +2,7 @@ package dialect
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type sqlite3 struct{}
|
||||
|
@ -15,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (s *sqlite3) SqlTag(column interface{}, size int) string {
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "datetime"
|
||||
case bool:
|
||||
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "integer"
|
||||
case int64, uint64:
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint"
|
||||
case float32, float64:
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "real"
|
||||
case []byte:
|
||||
return "blob"
|
||||
case string:
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
panic("Invalid sql type for sqlite3")
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
return "blob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
|
||||
return "INTEGER PRIMARY KEY"
|
||||
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
|
||||
return "INTEGER PRIMARY KEY"
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sqlite3) ReturningStr(key string) (str string) {
|
||||
|
|
30
main_test.go
30
main_test.go
|
@ -22,6 +22,17 @@ type IgnoredEmbedStruct struct {
|
|||
Name string
|
||||
}
|
||||
|
||||
type Num int64
|
||||
|
||||
func (i *Num) Scan(src interface{}) error {
|
||||
v := reflect.ValueOf(src)
|
||||
if v.Kind() != reflect.Int64 {
|
||||
return errors.New("Cannot scan NamedInt from " + v.String())
|
||||
}
|
||||
*i = Num(v.Int())
|
||||
return nil
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Id int64 // Id: Primary key
|
||||
Age int64
|
||||
|
@ -42,6 +53,7 @@ type User struct {
|
|||
PasswordHash []byte
|
||||
IgnoreMe int64 `sql:"-"`
|
||||
IgnoreStringSlice []string `sql:"-"`
|
||||
UserNum Num
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
|
@ -156,7 +168,7 @@ func init() {
|
|||
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
||||
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
||||
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
|
||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)})
|
||||
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
||||
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
||||
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
||||
|
@ -181,6 +193,22 @@ func TestFirstAndLast(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSaveCustomType(t *testing.T) {
|
||||
var user, user1 User
|
||||
db.First(&user, "name = ?", "1")
|
||||
if user.UserNum != Num(111) {
|
||||
t.Errorf("UserNum should be saved correctly")
|
||||
}
|
||||
|
||||
user.UserNum = Num(222)
|
||||
db.Save(&user)
|
||||
|
||||
db.First(&user1, "name = ?", "1")
|
||||
if user1.UserNum != Num(222) {
|
||||
t.Errorf("UserNum should be updated correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecision(t *testing.T) {
|
||||
f := 35.03554004971999
|
||||
user := User{Name: "Precision", Latitude: f}
|
||||
|
|
|
@ -309,26 +309,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
|
|||
value := field.Value
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
|
||||
if field.IsScanner() {
|
||||
value = reflectValue.Field(0).Interface()
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
if _, ok := value.([]byte); !ok {
|
||||
return
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !field.IsTime() && !field.IsScanner() {
|
||||
if field.IsScanner() {
|
||||
reflectValue = reflectValue.Field(0)
|
||||
} else if !field.IsTime() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(tag) == 0 {
|
||||
if field.isPrimaryKey {
|
||||
tag = scope.Dialect().PrimaryKeyTag(value, size)
|
||||
tag = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||
} else {
|
||||
tag = scope.Dialect().SqlTag(value, size)
|
||||
tag = scope.Dialect().SqlTag(reflectValue, size)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue