Support custom types from base types

This commit is contained in:
Jinzhu 2014-03-16 09:28:43 +08:00
parent 4969fc9cb5
commit e6c953dd4c
6 changed files with 111 additions and 65 deletions

View File

@ -1,10 +1,17 @@
package dialect package dialect
import (
"reflect"
"time"
)
var timeType = reflect.TypeOf(time.Time{})
type Dialect interface { type Dialect interface {
BinVar(i int) string BinVar(i int) string
SupportLastInsertId() bool SupportLastInsertId() bool
SqlTag(column interface{}, size int) string SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(column interface{}, size int) string PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(key string) string ReturningStr(key string) string
Quote(key string) string Quote(key string) string
} }

View File

@ -2,7 +2,7 @@ package dialect
import ( import (
"fmt" "fmt"
"time" "reflect"
) )
type mysql struct{} type mysql struct{}
@ -15,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool {
return true return true
} }
func (d *mysql) SqlTag(column interface{}, size int) string { func (d *mysql) SqlTag(value reflect.Value, size int) string {
switch column.(type) { switch value.Kind() {
case time.Time: case reflect.Bool:
return "datetime"
case bool:
return "boolean" return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" return "int"
case int64, uint64: case reflect.Int64, reflect.Uint64:
return "bigint" return "bigint"
case float32, float64: case reflect.Float32, reflect.Float64:
return "double" return "double"
case []byte: case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
}
case string:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} else { } else {
return "longtext" return "longtext"
} }
case reflect.Struct:
if value.Type() == timeType {
return "datetime"
}
default: default:
panic("Invalid sql type for mysql") if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
}
}
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
} }
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string { func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) { switch value.Kind() {
case int, int8, int16, int32, uint, uint8, uint16, uint32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix_str return "int" + suffix_str
case int64, uint64: case reflect.Int64, reflect.Uint64:
return "bigint" + suffix_str return "bigint" + suffix_str
default: default:
panic("Invalid primary key type") panic("Invalid primary key type")

View File

@ -2,7 +2,7 @@ package dialect
import ( import (
"fmt" "fmt"
"time" "reflect"
) )
type postgres struct { type postgres struct {
@ -16,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool {
return false return false
} }
func (d *postgres) SqlTag(column interface{}, size int) string { func (d *postgres) SqlTag(value reflect.Value, size int) string {
switch column.(type) { switch value.Kind() {
case time.Time: case reflect.Bool:
return "timestamp with time zone"
case bool:
return "boolean" return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer" return "integer"
case int64, uint64: case reflect.Int64, reflect.Uint64:
return "bigint" return "bigint"
case float32, float64: case reflect.Float32, reflect.Float64:
return "numeric" return "numeric"
case []byte: case reflect.String:
return "bytea"
case string:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} else { }
return "text" return "text"
case reflect.Struct:
if value.Type() == timeType {
return "timestamp with time zone"
} }
default: default:
panic("Invalid sql type for postgres") if _, ok := value.Interface().([]byte); ok {
return "bytea"
}
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
} }
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string { func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
switch column.(type) { switch value.Kind() {
case int, int8, int16, int32, uint, uint8, uint16, uint32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "serial PRIMARY KEY" return "serial PRIMARY KEY"
case int64, uint64: case reflect.Int64, reflect.Uint64:
return "bigserial PRIMARY KEY" return "bigserial PRIMARY KEY"
default: default:
panic("Invalid primary key type") panic("Invalid primary key type")

View File

@ -2,7 +2,7 @@ package dialect
import ( import (
"fmt" "fmt"
"time" "reflect"
) )
type sqlite3 struct{} type sqlite3 struct{}
@ -15,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool {
return true return true
} }
func (s *sqlite3) SqlTag(column interface{}, size int) string { func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
switch column.(type) { switch value.Kind() {
case time.Time: case reflect.Bool:
return "datetime"
case bool:
return "bool" return "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer" return "integer"
case int64, uint64: case reflect.Int64, reflect.Uint64:
return "bigint" return "bigint"
case float32, float64: case reflect.Float32, reflect.Float64:
return "real" return "real"
case []byte: case reflect.String:
return "blob"
case string:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} else { } else {
return "text" return "text"
} }
case reflect.Struct:
if value.Type() == timeType {
return "datetime"
}
default: default:
panic("Invalid sql type for sqlite3") if _, ok := value.Interface().([]byte); ok {
return "blob"
}
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
} }
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string { func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
return "INTEGER PRIMARY KEY" switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
return "INTEGER PRIMARY KEY"
default:
panic("Invalid primary key type")
}
} }
func (s *sqlite3) ReturningStr(key string) (str string) { func (s *sqlite3) ReturningStr(key string) (str string) {

View File

@ -22,6 +22,17 @@ type IgnoredEmbedStruct struct {
Name string Name string
} }
type Num int64
func (i *Num) Scan(src interface{}) error {
v := reflect.ValueOf(src)
if v.Kind() != reflect.Int64 {
return errors.New("Cannot scan NamedInt from " + v.String())
}
*i = Num(v.Int())
return nil
}
type User struct { type User struct {
Id int64 // Id: Primary key Id int64 // Id: Primary key
Age int64 Age int64
@ -42,6 +53,7 @@ type User struct {
PasswordHash []byte PasswordHash []byte
IgnoreMe int64 `sql:"-"` IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"` IgnoreStringSlice []string `sql:"-"`
UserNum Num
} }
type CreditCard struct { type CreditCard struct {
@ -156,7 +168,7 @@ func init() {
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00") t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00") t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00") t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()}) db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)})
db.Save(&User{Name: "2", Age: 20, Birthday: t2}) db.Save(&User{Name: "2", Age: 20, Birthday: t2})
db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 22, Birthday: t3})
db.Save(&User{Name: "3", Age: 24, Birthday: t4}) db.Save(&User{Name: "3", Age: 24, Birthday: t4})
@ -181,6 +193,22 @@ func TestFirstAndLast(t *testing.T) {
} }
} }
func TestSaveCustomType(t *testing.T) {
var user, user1 User
db.First(&user, "name = ?", "1")
if user.UserNum != Num(111) {
t.Errorf("UserNum should be saved correctly")
}
user.UserNum = Num(222)
db.Save(&user)
db.First(&user1, "name = ?", "1")
if user1.UserNum != Num(222) {
t.Errorf("UserNum should be updated correctly")
}
}
func TestPrecision(t *testing.T) { func TestPrecision(t *testing.T) {
f := 35.03554004971999 f := 35.03554004971999
user := User{Name: "Precision", Latitude: f} user := User{Name: "Precision", Latitude: f}

View File

@ -309,26 +309,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) {
value := field.Value value := field.Value
reflectValue := reflect.ValueOf(value) reflectValue := reflect.ValueOf(value)
if field.IsScanner() {
value = reflectValue.Field(0).Interface()
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice: case reflect.Slice:
if _, ok := value.([]byte); !ok { if _, ok := value.([]byte); !ok {
return return
} }
case reflect.Struct: case reflect.Struct:
if !field.IsTime() && !field.IsScanner() { if field.IsScanner() {
reflectValue = reflectValue.Field(0)
} else if !field.IsTime() {
return return
} }
} }
if len(tag) == 0 { if len(tag) == 0 {
if field.isPrimaryKey { if field.isPrimaryKey {
tag = scope.Dialect().PrimaryKeyTag(value, size) tag = scope.Dialect().PrimaryKeyTag(reflectValue, size)
} else { } else {
tag = scope.Dialect().SqlTag(value, size) tag = scope.Dialect().SqlTag(reflectValue, size)
} }
} }