gorm/sql_type.go

177 lines
3.8 KiB
Go
Raw Normal View History

2013-10-26 12:01:15 +04:00
package gorm
import (
"database/sql"
2013-11-10 15:38:28 +04:00
"database/sql/driver"
2013-10-26 12:01:15 +04:00
"fmt"
2013-11-10 15:38:28 +04:00
"reflect"
2013-11-13 20:03:31 +04:00
"strconv"
"strings"
2013-10-26 12:01:15 +04:00
"time"
)
2013-11-13 20:03:31 +04:00
// parseTag parses a gorm `sql` struct tag of the form "key[:value];key[:value]...".
//
// Keys are matched case-insensitively and surrounding whitespace is ignored;
// values keep their original case. Recognised keys:
//   - TYPE:<sql type>   explicit column type, returned as typ;
//   - SIZE:<n>          column size, returned as size (0 if absent or not an integer);
//   - NOT NULL, UNIQUE  flag constraints, returned in additionalTyp.
//
// The special tag "-" is returned verbatim as typ so callers can skip the field.
// additionalTyp is "" when no constraint is present; otherwise it holds the
// constraints separated by single spaces, with no leading/trailing whitespace.
func parseTag(str string) (typ string, additionalTyp string, size int) {
	if str == "-" {
		return str, "", 0
	}
	if str == "" {
		return "", "", 0
	}

	settings := make(map[string]string)
	for _, part := range strings.Split(str, ";") {
		// SplitN keeps any further ':' inside the value instead of discarding the pair.
		kv := strings.SplitN(part, ":", 2)
		key := strings.ToUpper(strings.TrimSpace(kv[0]))
		if len(kv) == 2 {
			settings[key] = strings.TrimSpace(kv[1])
		} else {
			// Bare flags such as "not null" map to themselves.
			settings[key] = key
		}
	}

	if s := settings["SIZE"]; s != "" {
		// Best effort: a malformed size leaves size at 0 (Atoi returns 0 on syntax errors).
		size, _ = strconv.Atoi(s)
	}
	typ = settings["TYPE"]

	var constraints []string
	for _, key := range []string{"NOT NULL", "UNIQUE"} {
		if c := settings[key]; c != "" {
			constraints = append(constraints, c)
		}
	}
	additionalTyp = strings.Join(constraints, " ")
	return typ, additionalTyp, size
}
2013-11-10 15:38:28 +04:00
// formatColumnValue normalises a struct field value into a plain Go value that
// the SQL type mapping can switch on.
//
// A reflect.Value is replaced by the value it holds. A driver.Valuer whose
// underlying type (after dereferencing pointers) is a non-empty struct, such as
// sql.NullString or sql.NullInt64, is replaced by the zero value of its first
// field's type (string, int64, ...). Any other value, including Valuers that
// are not structs, is returned unchanged instead of panicking.
func formatColumnValue(column interface{}) interface{} {
	if v, ok := column.(reflect.Value); ok {
		column = v.Interface()
	}
	if valuer, ok := column.(driver.Valuer); ok {
		// Inspect the type rather than the value so nil pointers are safe.
		t := reflect.TypeOf(valuer)
		for t.Kind() == reflect.Ptr {
			t = t.Elem()
		}
		if t.Kind() == reflect.Struct && t.NumField() > 0 {
			column = reflect.New(t.Field(0).Type).Elem().Interface()
		}
	}
	return column
}
2013-11-13 20:03:31 +04:00
func getPrimaryKeySqlType(adaptor string, column interface{}, tag string) string {
2013-11-10 15:38:28 +04:00
column = formatColumnValue(column)
2013-11-13 20:03:31 +04:00
typ, addational_typ, _ := parseTag(tag)
if len(typ) != 0 {
return typ + addational_typ
}
2013-11-10 15:38:28 +04:00
2013-10-26 13:56:00 +04:00
switch adaptor {
2013-11-04 16:47:45 +04:00
case "sqlite3":
return "INTEGER PRIMARY KEY"
2013-10-26 13:56:00 +04:00
case "mysql":
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
2013-11-13 20:03:31 +04:00
typ = "int" + suffix_str
2013-10-26 13:56:00 +04:00
case int64, uint64:
2013-11-13 20:03:31 +04:00
typ = "bigint" + suffix_str
2013-10-26 13:56:00 +04:00
}
case "postgres":
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
2013-11-13 20:03:31 +04:00
typ = "serial"
2013-10-26 13:56:00 +04:00
case int64, uint64:
2013-11-13 20:03:31 +04:00
typ = "bigserial"
2013-10-26 13:56:00 +04:00
}
2013-11-13 20:03:31 +04:00
default:
panic("unsupported sql adaptor, please submit an issue in github")
2013-10-26 13:56:00 +04:00
}
2013-11-13 20:03:31 +04:00
return typ
2013-10-26 13:56:00 +04:00
}
2013-11-13 20:03:31 +04:00
func getSqlType(adaptor string, column interface{}, tag string) string {
2013-11-10 15:38:28 +04:00
column = formatColumnValue(column)
2013-11-13 20:03:31 +04:00
typ, addational_typ, size := parseTag(tag)
2013-11-10 15:38:28 +04:00
2013-11-13 20:03:31 +04:00
if typ == "-" {
return ""
}
if len(typ) == 0 {
switch adaptor {
case "sqlite3":
switch column.(type) {
case time.Time:
typ = "datetime"
case bool, sql.NullBool:
typ = "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "integer"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "real"
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
default:
panic("invalid sql type")
2013-10-26 12:01:15 +04:00
}
2013-11-13 20:03:31 +04:00
case "mysql":
switch column.(type) {
case time.Time:
typ = "timestamp"
case bool, sql.NullBool:
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "int"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "double"
case []byte:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varbinary(%d)", size)
} else {
typ = "longblob"
}
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "longtext"
}
default:
panic("invalid sql type")
2013-10-26 12:01:15 +04:00
}
2013-11-13 20:03:31 +04:00
case "postgres":
switch column.(type) {
case time.Time:
typ = "timestamp with time zone"
case bool, sql.NullBool:
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "integer"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "double precision"
case []byte:
typ = "bytea"
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
default:
panic("invalid sql type")
2013-10-26 12:01:15 +04:00
}
default:
2013-11-13 20:03:31 +04:00
panic("unsupported sql adaptor, please submit an issue in github")
2013-10-26 12:01:15 +04:00
}
}
2013-11-13 20:03:31 +04:00
if len(addational_typ) > 0 {
typ = typ + " " + addational_typ
}
return typ
2013-10-26 12:01:15 +04:00
}