Support custom types from base types

This commit is contained in:
Jinzhu 2014-03-16 09:28:43 +08:00
parent 4969fc9cb5
commit e6c953dd4c
6 changed files with 111 additions and 65 deletions

View File

@ -1,10 +1,17 @@
package dialect
import (
"reflect"
"time"
)
var timeType = reflect.TypeOf(time.Time{})
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
SqlTag(column interface{}, size int) string
PrimaryKeyTag(column interface{}, size int) string
SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(key string) string
Quote(key string) string
}

View File

@ -2,7 +2,7 @@ package dialect
import (
"fmt"
"time"
"reflect"
)
type mysql struct{}
@ -15,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool {
return true
}
func (d *mysql) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "datetime"
case bool:
func (d *mysql) SqlTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int"
case int64, uint64:
case reflect.Int64, reflect.Uint64:
return "bigint"
case float32, float64:
case reflect.Float32, reflect.Float64:
return "double"
case []byte:
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
}
case string:
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "longtext"
}
default:
panic("Invalid sql type for mysql")
case reflect.Struct:
if value.Type() == timeType {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
}
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix_str
case int64, uint64:
case reflect.Int64, reflect.Uint64:
return "bigint" + suffix_str
default:
panic("Invalid primary key type")

View File

@ -2,7 +2,7 @@ package dialect
import (
"fmt"
"time"
"reflect"
)
type postgres struct {
@ -16,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool {
return false
}
func (d *postgres) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "timestamp with time zone"
case bool:
func (d *postgres) SqlTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case int64, uint64:
case reflect.Int64, reflect.Uint64:
return "bigint"
case float32, float64:
case reflect.Float32, reflect.Float64:
return "numeric"
case []byte:
return "bytea"
case string:
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
}
return "text"
case reflect.Struct:
if value.Type() == timeType {
return "timestamp with time zone"
}
default:
panic("Invalid sql type for postgres")
if _, ok := value.Interface().([]byte); ok {
return "bytea"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "serial PRIMARY KEY"
case int64, uint64:
case reflect.Int64, reflect.Uint64:
return "bigserial PRIMARY KEY"
default:
panic("Invalid primary key type")

View File

@ -2,7 +2,7 @@ package dialect
import (
"fmt"
"time"
"reflect"
)
type sqlite3 struct{}
@ -15,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool {
return true
}
func (s *sqlite3) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "datetime"
case bool:
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Bool:
return "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case int64, uint64:
case reflect.Int64, reflect.Uint64:
return "bigint"
case float32, float64:
case reflect.Float32, reflect.Float64:
return "real"
case []byte:
return "blob"
case string:
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
}
default:
panic("Invalid sql type for sqlite3")
case reflect.Struct:
if value.Type() == timeType {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
return "INTEGER PRIMARY KEY"
default:
panic("Invalid primary key type")
}
}
func (s *sqlite3) ReturningStr(key string) (str string) {

View File

@ -22,6 +22,17 @@ type IgnoredEmbedStruct struct {
Name string
}
type Num int64
func (i *Num) Scan(src interface{}) error {
v := reflect.ValueOf(src)
if v.Kind() != reflect.Int64 {
return errors.New("Cannot scan NamedInt from " + v.String())
}
*i = Num(v.Int())
return nil
}
type User struct {
Id int64 // Id: Primary key
Age int64
@ -42,6 +53,7 @@ type User struct {
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"`
UserNum Num
}
type CreditCard struct {
@ -156,7 +168,7 @@ func init() {
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)})
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
@ -181,6 +193,22 @@ func TestFirstAndLast(t *testing.T) {
}
}
func TestSaveCustomType(t *testing.T) {
var user, user1 User
db.First(&user, "name = ?", "1")
if user.UserNum != Num(111) {
t.Errorf("UserNum should be saved correctly")
}
user.UserNum = Num(222)
db.Save(&user)
db.First(&user1, "name = ?", "1")
if user1.UserNum != Num(222) {
t.Errorf("UserNum should be updated correctly")
}
}
func TestPrecision(t *testing.T) {
f := 35.03554004971999
user := User{Name: "Precision", Latitude: f}

View File

@ -309,26 +309,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
value := field.Value
reflectValue := reflect.ValueOf(value)
if field.IsScanner() {
value = reflectValue.Field(0).Interface()
}
switch reflectValue.Kind() {
case reflect.Slice:
if _, ok := value.([]byte); !ok {
return
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
if field.IsScanner() {
reflectValue = reflectValue.Field(0)
} else if !field.IsTime() {
return
}
}
if len(tag) == 0 {
if field.isPrimaryKey {
tag = scope.Dialect().PrimaryKeyTag(value, size)
tag = scope.Dialect().PrimaryKeyTag(reflectValue, size)
} else {
tag = scope.Dialect().SqlTag(value, size)
tag = scope.Dialect().SqlTag(reflectValue, size)
}
}