gorm/dialect.go

148 lines
4.8 KiB
Go

package gorm
import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
)
// Dialect contains the behaviors that differ across SQL databases.
type Dialect interface {
// GetName returns the dialect's name.
GetName() string
// SetDB sets the db used by the dialect.
SetDB(db SQLCommon)
// BindVar returns the placeholder for an actual value in a SQL statement; in many databases it is "?", while Postgres uses numbered placeholders such as $1.
BindVar(i int) string
// Quote quotes a field name so that using a reserved word as a field name does not cause SQL parsing errors.
Quote(key string) string
// DataTypeOf returns the SQL type for the given field.
DataTypeOf(field *StructField) string
// HasIndex reports whether the table has the named index.
HasIndex(tableName string, indexName string) bool
// HasForeignKey reports whether the table has the named foreign key.
HasForeignKey(tableName string, foreignKeyName string) bool
// RemoveIndex removes the named index from the table.
RemoveIndex(tableName string, indexName string) error
// HasTable reports whether the table exists.
HasTable(tableName string) bool
// HasColumn reports whether the table has the named column.
HasColumn(tableName string, columnName string) bool
// ModifyColumn changes the type of the named column to typ.
ModifyColumn(tableName string, columnName string, typ string) error
// LimitAndOffsetSQL returns the generated SQL for the given limit and offset; it is dialect-specific because mssql is a special case.
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
// SelectFromDummyTable returns the clause used when selecting plain values. For most databases `SELECT values` just works; mysql needs `SELECT value FROM DUAL`.
SelectFromDummyTable() string
// LastInsertIDOutputInterstitial returns the SQL fragment used to obtain the last inserted ID; most databases support LastInsertId, but mssql needs to use `OUTPUT`.
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
// LastInsertIDReturningSuffix returns the SQL suffix used to obtain the last inserted ID; most databases support LastInsertId, but postgres needs to use `RETURNING`.
LastInsertIDReturningSuffix(tableName, columnName string) string
// DefaultValueStr returns a dialect-specific SQL string.
// NOTE(review): presumably the clause used to insert a row consisting only of default values — confirm against the implementations.
DefaultValueStr() string
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference.
BuildKeyName(kind, tableName string, fields ...string) string
// NormalizeIndexAndColumn returns a valid index name and column name, depending on the dialect.
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
// CurrentDatabase returns the current database name.
CurrentDatabase() string
}
// dialectsMap is the package-level registry of dialects, keyed by dialect
// name. RegisterDialect writes to it; newDialect and GetDialect read from it.
// NOTE(review): no locking is visible around this map in the code shown here —
// confirm that registration only happens before concurrent use.
var dialectsMap = map[string]Dialect{}
// newDialect returns a fresh Dialect for name with db attached via SetDB.
//
// When name is registered, a new zero-valued instance of the registered
// dialect's type is created through reflection, so the registered value itself
// is never handed out or mutated. A dialect registered by value (a non-pointer
// type) is handled as well: a pointer to a new value of that type is returned.
// When name is not registered, a notice is printed to standard output and a
// commonDialect is returned instead.
func newDialect(name string, db SQLCommon) Dialect {
	registered, ok := dialectsMap[name]
	if !ok {
		fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
		fallback := &commonDialect{}
		fallback.SetDB(db)
		return fallback
	}

	typ := reflect.TypeOf(registered)
	// Only unwrap pointer types. Calling Elem on a struct type panics, and on
	// map/slice types it would yield the element type rather than the dialect.
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}

	// reflect.New yields a pointer to a zero value of typ; its method set
	// includes both value-receiver and pointer-receiver methods.
	dialect := reflect.New(typ).Interface().(Dialect)
	dialect.SetDB(db)
	return dialect
}
// RegisterDialect registers dialect under name in the package-level dialect
// map, replacing any dialect previously registered under the same name.
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect looks up the dialect registered under name. The boolean result
// reports whether a dialect is registered under that name; when it is false
// the returned Dialect is nil. The registered value itself is returned, not a
// copy.
func GetDialect(name string) (dialect Dialect, ok bool) {
	registered, found := dialectsMap[name]
	return registered, found
}
// ParseFieldStructForDialect derives, from a struct field and its tag
// settings, the pieces used to build a column definition. It is declared as a
// package-level variable, so it can be reassigned.
//
// It returns:
//   - fieldValue: a zero value of the field's type with all pointer levels
//     removed. When no SQL type has been determined and that type is a struct
//     whose pointer implements sql.Scanner, the first struct field is followed
//     repeatedly to reach the wrapped value.
//   - sqlType: the TYPE tag setting, replaced by GormDataType(dialect) when the
//     field's type implements that method; empty when neither is available.
//   - size: the SIZE tag setting converted with strconv.Atoi (0 when the
//     setting is not numeric), or 255 when no SIZE setting is present.
//   - additionalType: the NOT NULL and UNIQUE tag values followed by optional
//     " DEFAULT <value>" and " COMMENT <value>" parts, with surrounding
//     whitespace trimmed.
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
	fieldType := field.Struct.Type
	dataType, _ := field.TagSettingsGet("TYPE")

	// Strip every level of pointer indirection from the declared type.
	for fieldType.Kind() == reflect.Ptr {
		fieldType = fieldType.Elem()
	}

	// Work with a zero value of the underlying type.
	fieldValue = reflect.Indirect(reflect.New(fieldType))

	// A GormDataType method on the field's type takes precedence over the tag.
	if typed, ok := fieldValue.Interface().(interface {
		GormDataType(Dialect) string
	}); ok {
		dataType = typed.GormDataType(dialect)
	}

	// Without an explicit type, unwrap Scanner structs by following their
	// first field. The NumField guard prevents Field(0) from panicking on a
	// Scanner struct that declares no fields.
	if dataType == "" {
		for fieldValue.Kind() == reflect.Struct && fieldValue.NumField() > 0 {
			if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); !isScanner {
				break
			}
			fieldValue = fieldValue.Field(0)
		}
	}

	// Size defaults to 255 when no SIZE setting is present.
	size = 255
	if raw, ok := field.TagSettingsGet("SIZE"); ok {
		// The conversion error is deliberately ignored: size takes whatever
		// Atoi returned, which is 0 for non-numeric text.
		size, _ = strconv.Atoi(raw)
	}

	// Assemble the extra column attributes from the tag settings.
	notNull, _ := field.TagSettingsGet("NOT NULL")
	unique, _ := field.TagSettingsGet("UNIQUE")
	additionalType = notNull + " " + unique
	if value, ok := field.TagSettingsGet("DEFAULT"); ok {
		additionalType += " DEFAULT " + value
	}
	if value, ok := field.TagSettingsGet("COMMENT"); ok {
		additionalType += " COMMENT " + value
	}

	return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}
// currentDatabaseAndTable splits tableName into a database name and a table
// name. A name containing a "." is split at the first dot: the part before it
// is the database and everything after it is the table. Otherwise the
// dialect's current database is paired with tableName unchanged.
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
	dot := strings.Index(tableName, ".")
	if dot < 0 {
		// Unqualified name: ask the dialect for the database.
		return dialect.CurrentDatabase(), tableName
	}
	return tableName[:dot], tableName[dot+1:]
}