Finish dialects

This commit is contained in:
Jinzhu 2013-11-14 18:59:11 +08:00
parent ac4b2e2a42
commit c1c13129f1
11 changed files with 275 additions and 235 deletions

View File

@ -28,7 +28,7 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
UpdatedAt time.Time // Time of record is updated, will be updated automatically
DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more
Email []Email // Embedded structs
Emails []Email // Embedded structs
BillingAddress Address // Embedded struct
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
ShippingAddress Address // Embedded struct
@ -125,7 +125,7 @@ user := User{
Name: "jinzhu",
BillingAddress: Address{Address1: "Billing Address - Address 1"},
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
Email: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
}
db.Save(&user)

View File

@ -28,10 +28,6 @@ type Chain struct {
unscoped bool
}
func (s *Chain) driver() string {
return s.d.driver
}
func (s *Chain) err(err error) error {
if err != nil {
s.Errors = append(s.Errors, err)

View File

@ -1,23 +1,22 @@
package dialect
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
SqlTag(column interface{}, size int) string
PrimaryKeyTag(column interface{}, size int) string
ReturningStr(key string) string
}
func NewDialect(driver string) *Dialect {
func NewDialect(driver string) Dialect {
var d Dialect
switch driver {
case "postgres":
d = postgres{}
d = &postgres{}
case "mysql":
d = mysql{}
d = &mysql{}
case "sqlite3":
d = sqlite3{}
d = &sqlite3{}
}
return &d
return d
}
type mysql struct{}
type postgres struct{}
type sqlite3 struct{}

View File

@ -1 +1,62 @@
package dialect
import (
"database/sql"
"fmt"
"time"
)
type mysql struct{}
func (s *mysql) BinVar(i int) string {
return "?"
}
func (s *mysql) SupportLastInsertId() bool {
return true
}
func (d *mysql) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "timestamp"
case bool, sql.NullBool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int"
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64, sql.NullFloat64:
return "double"
case []byte:
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
}
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "longtext"
}
default:
panic("Invalid sql type for mysql")
}
}
func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int" + suffix_str
case int64, uint64:
return "bigint" + suffix_str
default:
panic("Invalid primary key type")
}
}
func (s *mysql) ReturningStr(key string) (str string) {
return
}

View File

@ -1 +1,58 @@
package dialect
import (
"database/sql"
"fmt"
"time"
)
type postgres struct {
}
func (s *postgres) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (s *postgres) SupportLastInsertId() bool {
return false
}
func (d *postgres) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "timestamp with time zone"
case bool, sql.NullBool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64, sql.NullFloat64:
return "double precision"
case []byte:
return "bytea"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
}
default:
panic("Invalid sql type for postgres")
}
}
func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "sehrial"
case int64, uint64:
return "bigserial"
default:
panic("Invalid primary key type")
}
}
func (s *postgres) ReturningStr(key string) (str string) {
return fmt.Sprintf("RETURNING \"%v\"", key)
}

View File

@ -1 +1,48 @@
package dialect
import (
"database/sql"
"fmt"
"time"
)
type sqlite3 struct{}
func (s *sqlite3) BinVar(i int) string {
return "?"
}
func (s *sqlite3) SupportLastInsertId() bool {
return true
}
func (s *sqlite3) SqlTag(column interface{}, size int) string {
switch column.(type) {
case time.Time:
return "datetime"
case bool, sql.NullBool:
return "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64, sql.NullFloat64:
return "real"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
}
default:
panic("Invalid sql type for sqlite3")
}
}
func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
return "INTEGER PRIMARY KEY"
}
func (s *sqlite3) ReturningStr(key string) (str string) {
return
}

14
do.go
View File

@ -60,11 +60,7 @@ func (s *Do) setModel(value interface{}) *Do {
func (s *Do) addToVars(value interface{}) string {
s.sqlVars = append(s.sqlVars, value)
if s.chain.driver() == "postgres" {
return fmt.Sprintf("$%d", len(s.sqlVars))
} else {
return "?"
}
return s.chain.d.dialect.BinVar(len(s.sqlVars))
}
func (s *Do) exec(sqls ...string) (err error) {
@ -102,7 +98,7 @@ func (s *Do) prepareCreateSql() {
s.tableName(),
strings.Join(columns, ","),
strings.Join(sqls, ","),
s.model.returningStr(),
s.chain.d.dialect.ReturningStr(s.model.primaryKeyDb()),
)
return
}
@ -178,13 +174,13 @@ func (s *Do) create() (i interface{}) {
var id interface{}
now := time.Now()
if s.chain.driver() == "postgres" {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} else {
if s.chain.d.dialect.SupportLastInsertId() {
if sql_result, err := s.db.Exec(s.sql, s.sqlVars...); s.err(err) == nil {
id, err = sql_result.LastInsertId()
s.err(err)
}
} else {
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
}
s.chain.slog(s.sql, now, s.sqlVars...)

View File

@ -12,7 +12,7 @@ func init() {
type DB struct {
db sql_common
dialect *dialect.Dialect
dialect dialect.Dialect
logger Logger
log_mode bool
}

View File

@ -3,11 +3,11 @@ package gorm
import (
"database/sql"
"errors"
"fmt"
"go/ast"
"reflect"
"regexp"
"strconv"
"time"
)
@ -115,7 +115,6 @@ func (m *Model) fields(operation string) (fields []Field) {
}
}
tag_value := p.Tag.Get(tagIdentifier)
if is_time {
field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName
@ -130,10 +129,14 @@ func (m *Model) fields(operation string) (fields []Field) {
value.Set(reflect.ValueOf(time.Now()))
}
}
}
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
field.Value = value.Interface()
if is_time {
field.SqlType = m.getSqlTag(field, p)
} else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, tag_value)
field.SqlType = m.getSqlTag(field, p)
} else {
field_value := reflect.Indirect(value)
@ -148,7 +151,7 @@ func (m *Model) fields(operation string) (fields []Field) {
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if is_scanner {
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
field.SqlType = m.getSqlTag(field, p)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
@ -162,11 +165,10 @@ func (m *Model) fields(operation string) (fields []Field) {
}
}
default:
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
field.SqlType = m.getSqlTag(field, p)
}
}
field.Value = value.Interface()
fields = append(fields, field)
}
}
@ -313,13 +315,6 @@ func (m *Model) callMethod(method string) {
return
}
func (m *Model) returningStr() (str string) {
if m.do.chain.driver() == "postgres" {
str = fmt.Sprintf("RETURNING \"%v\"", m.primaryKeyDb())
}
return
}
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
data := reflect.Indirect(reflect.ValueOf(out))
setFieldValue(data.FieldByName(snakeToUpperCamel(name)), value)
@ -343,23 +338,24 @@ func (m *Model) afterAssociations() (fields []Field) {
return
}
func setFieldValue(field reflect.Value, value interface{}) bool {
if field.IsValid() && field.CanAddr() {
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if str, ok := value.(string); ok {
value, _ = strconv.Atoi(str)
}
field.SetInt(reflect.ValueOf(value).Int())
default:
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else {
field.Set(reflect.ValueOf(value))
}
}
return true
func (m *Model) getSqlTag(field Field, struct_field reflect.StructField) string {
column := getInterfaceValue(field.Value)
typ, addational_typ, size := parseSqlTag(struct_field.Tag.Get(tagIdentifier))
if typ == "-" {
return ""
}
return false
if len(typ) == 0 {
if field.IsPrimaryKey {
typ = m.do.chain.d.dialect.PrimaryKeyTag(column, size)
} else {
typ = m.do.chain.d.dialect.SqlTag(column, size)
}
}
if len(addational_typ) > 0 {
typ = typ + " " + addational_typ
}
return typ
}

View File

@ -1,176 +0,0 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
func parseTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str
} else if str != "" {
tags := strings.Split(str, ";")
m := make(map[string]string)
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.Trim(strings.ToUpper(v[0]), " ")
if len(v) == 2 {
m[k] = v[1]
} else {
m[k] = k
}
}
if len(m["SIZE"]) > 0 {
size, _ = strconv.Atoi(m["SIZE"])
}
if len(m["TYPE"]) > 0 {
typ = m["TYPE"]
}
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
}
return
}
func formatColumnValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
}
if valuer, ok := interface{}(column).(driver.Valuer); ok {
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
}
return column
}
func getPrimaryKeySqlType(adaptor string, column interface{}, tag string) string {
column = formatColumnValue(column)
typ, addational_typ, _ := parseTag(tag)
if len(typ) != 0 {
return typ + addational_typ
}
switch adaptor {
case "sqlite3":
return "INTEGER PRIMARY KEY"
case "mysql":
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "int" + suffix_str
case int64, uint64:
typ = "bigint" + suffix_str
}
case "postgres":
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "serial"
case int64, uint64:
typ = "bigserial"
}
default:
panic("unsupported sql adaptor, please submit an issue in github")
}
return typ
}
func getSqlType(adaptor string, column interface{}, tag string) string {
column = formatColumnValue(column)
typ, addational_typ, size := parseTag(tag)
if typ == "-" {
return ""
}
if len(typ) == 0 {
switch adaptor {
case "sqlite3":
switch column.(type) {
case time.Time:
typ = "datetime"
case bool, sql.NullBool:
typ = "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "integer"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "real"
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
default:
panic("invalid sql type")
}
case "mysql":
switch column.(type) {
case time.Time:
typ = "timestamp"
case bool, sql.NullBool:
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "int"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "double"
case []byte:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varbinary(%d)", size)
} else {
typ = "longblob"
}
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "longtext"
}
default:
panic("invalid sql type")
}
case "postgres":
switch column.(type) {
case time.Time:
typ = "timestamp with time zone"
case bool, sql.NullBool:
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "integer"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "double precision"
case []byte:
typ = "bytea"
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
default:
panic("invalid sql type")
}
default:
panic("unsupported sql adaptor, please submit an issue in github")
}
}
if len(addational_typ) > 0 {
typ = typ + " " + addational_typ
}
return typ
}

View File

@ -2,7 +2,10 @@ package gorm
import (
"bytes"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"strconv"
"fmt"
@ -63,3 +66,64 @@ func getInterfaceAsString(value interface{}) (str string, err error) {
}
return
}
func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str
} else if str != "" {
tags := strings.Split(str, ";")
m := make(map[string]string)
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.Trim(strings.ToUpper(v[0]), " ")
if len(v) == 2 {
m[k] = v[1]
} else {
m[k] = k
}
}
if len(m["SIZE"]) > 0 {
size, _ = strconv.Atoi(m["SIZE"])
}
if len(m["TYPE"]) > 0 {
typ = m["TYPE"]
}
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
}
return
}
func getInterfaceValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
}
if valuer, ok := interface{}(column).(driver.Valuer); ok {
column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface()
}
return column
}
func setFieldValue(field reflect.Value, value interface{}) bool {
if field.IsValid() && field.CanAddr() {
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if str, ok := value.(string); ok {
value, _ = strconv.Atoi(str)
}
field.SetInt(reflect.ValueOf(value).Int())
default:
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else {
field.Set(reflect.ValueOf(value))
}
}
return true
}
return false
}