diff --git a/dialect/dialect.go b/dialect/dialect.go index 9418e533..a0f403c4 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -1,10 +1,17 @@ package dialect +import ( + "reflect" + "time" +) + +var timeType = reflect.TypeOf(time.Time{}) + type Dialect interface { BinVar(i int) string SupportLastInsertId() bool - SqlTag(column interface{}, size int) string - PrimaryKeyTag(column interface{}, size int) string + SqlTag(value reflect.Value, size int) string + PrimaryKeyTag(value reflect.Value, size int) string ReturningStr(key string) string Quote(key string) string } diff --git a/dialect/mysql.go b/dialect/mysql.go index 2a5ac0c0..9602bcc1 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -2,7 +2,7 @@ package dialect import ( "fmt" - "time" + "reflect" ) type mysql struct{} @@ -15,41 +15,44 @@ func (s *mysql) SupportLastInsertId() bool { return true } -func (d *mysql) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "datetime" - case bool: +func (d *mysql) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "int" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64: + case reflect.Float32, reflect.Float64: return "double" - case []byte: - if size > 0 && size < 65532 { - return fmt.Sprintf("varbinary(%d)", size) - } else { - return "longblob" - } - case string: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { return "longtext" } + case reflect.Struct: + if value.Type() == timeType { + return "datetime" + } default: - panic("Invalid sql type for mysql") + if _, ok := value.Interface().([]byte); ok { + if size > 0 && size < 65532 { + return fmt.Sprintf("varbinary(%d)", size) + } else { + return "longblob" + } + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s *mysql) PrimaryKeyTag(column interface{}, size int) string { +func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string { suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "int" + suffix_str - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" + suffix_str default: panic("Invalid primary key type") diff --git a/dialect/postgres.go b/dialect/postgres.go index 7cffe9dc..7b744fa4 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -2,7 +2,7 @@ package dialect import ( "fmt" - "time" + "reflect" ) type postgres struct { @@ -16,36 +16,38 @@ func (s *postgres) SupportLastInsertId() bool { return false } -func (d *postgres) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "timestamp with time zone" - case bool: +func (d *postgres) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "integer" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64: + case reflect.Float32, reflect.Float64: return "numeric" - case []byte: - return "bytea" - case string: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) - } else { - return "text" + } + return "text" + case reflect.Struct: + if value.Type() == timeType { + return "timestamp with time zone" } default: - panic("Invalid sql type for postgres") + if _, ok := value.Interface().([]byte); ok { + return "bytea" + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -func (s *postgres) PrimaryKeyTag(column interface{}, size int) string { - switch column.(type) { - case int, int8, int16, int32, uint, uint8, uint16, uint32: +func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "serial PRIMARY KEY" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigserial PRIMARY KEY" default: panic("Invalid primary key type") diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index 4d16c3a8..ae54e603 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -2,7 +2,7 @@ package dialect import ( "fmt" - "time" + "reflect" ) type sqlite3 struct{} @@ -15,33 +15,41 @@ func (s *sqlite3) SupportLastInsertId() bool { return true } -func (s *sqlite3) SqlTag(column interface{}, size int) string { - switch column.(type) { - case time.Time: - return "datetime" - case bool: +func (s *sqlite3) SqlTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Bool: return "bool" - case int, int8, int16, int32, uint, uint8, uint16, uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: return "integer" - case int64, uint64: + case reflect.Int64, reflect.Uint64: return "bigint" - case float32, float64: + case reflect.Float32, reflect.Float64: return "real" - case []byte: - return "blob" - case string: + case reflect.String: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } else { return "text" } + case reflect.Struct: + if value.Type() == timeType { + return "datetime" + } default: - panic("Invalid sql type for sqlite3") + if _, ok := value.Interface().([]byte); ok { + return "blob" + } } + panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string { - return "INTEGER PRIMARY KEY" +func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string { + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64: + return "INTEGER PRIMARY KEY" + default: + panic("Invalid primary key type") + } } func (s *sqlite3) ReturningStr(key string) (str string) { diff --git a/main_test.go b/main_test.go index 1df4558a..f244c03d 100644 --- a/main_test.go +++ b/main_test.go @@ -22,6 +22,17 @@ type IgnoredEmbedStruct struct { Name string } +type Num int64 + +func (i *Num) Scan(src interface{}) error { + v := reflect.ValueOf(src) + if v.Kind() != reflect.Int64 { + return errors.New("Cannot scan NamedInt from " + v.String()) + } + *i = Num(v.Int()) + return nil +} + type User struct { Id int64 // Id: Primary key Age int64 @@ -42,6 +53,7 @@ type User struct { PasswordHash []byte IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` + UserNum Num } type CreditCard struct { @@ -156,7 +168,7 @@ func init() { t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00") t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00") t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00") - db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()}) + db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now(), UserNum: Num(111)}) db.Save(&User{Name: "2", Age: 20, Birthday: t2}) db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 24, Birthday: t4}) @@ -181,6 +193,22 @@ func TestFirstAndLast(t *testing.T) { } } +func TestSaveCustomType(t *testing.T) { + var user, user1 User + db.First(&user, "name = ?", "1") + if user.UserNum != Num(111) { + t.Errorf("UserNum should be saved correctly") + } + + user.UserNum = Num(222) + db.Save(&user) + + db.First(&user1, "name = ?", "1") + if user1.UserNum != Num(222) { + t.Errorf("UserNum should be updated correctly") + } +} + func TestPrecision(t *testing.T) { f := 35.03554004971999 user := User{Name: "Precision", Latitude: f} diff --git a/scope_private.go b/scope_private.go index 7212af2b..c2e5627e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -309,26 +309,24 @@ func (scope *Scope) sqlTagForField(field *Field) (tag string) { value := field.Value reflectValue := reflect.ValueOf(value) - if field.IsScanner() { - value = reflectValue.Field(0).Interface() - } - switch reflectValue.Kind() { case reflect.Slice: if _, ok := value.([]byte); !ok { return } case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { + if field.IsScanner() { + reflectValue = reflectValue.Field(0) + } else if !field.IsTime() { return } } if len(tag) == 0 { if field.isPrimaryKey { - tag = scope.Dialect().PrimaryKeyTag(value, size) + tag = scope.Dialect().PrimaryKeyTag(reflectValue, size) } else { - tag = scope.Dialect().SqlTag(value, size) + tag = scope.Dialect().SqlTag(reflectValue, size) } }