mirror of https://github.com/go-gorm/gorm.git
Refactor dialect
This commit is contained in:
parent
e159ca1914
commit
d92c5db9e7
|
@ -71,29 +71,32 @@ func createCallback(scope *Scope) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
returningKey := "*"
|
var (
|
||||||
primaryField := scope.PrimaryField()
|
returningColumn = "*"
|
||||||
|
quotedTableName = scope.QuotedTableName()
|
||||||
|
primaryField = scope.PrimaryField()
|
||||||
|
)
|
||||||
|
|
||||||
if primaryField != nil {
|
if primaryField != nil {
|
||||||
returningKey = scope.Quote(primaryField.DBName)
|
returningColumn = scope.Quote(primaryField.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastInsertIdReturningSuffix := scope.Dialect().LastInsertIdReturningSuffix(quotedTableName, returningColumn)
|
||||||
|
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
|
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", quotedTableName, lastInsertIdReturningSuffix))
|
||||||
scope.QuotedTableName(),
|
|
||||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
strings.Join(columns, ","),
|
strings.Join(columns, ","),
|
||||||
strings.Join(placeholders, ","),
|
strings.Join(placeholders, ","),
|
||||||
scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey),
|
lastInsertIdReturningSuffix,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute create sql
|
// execute create sql
|
||||||
if scope.Dialect().SupportLastInsertId() || primaryField == nil {
|
if lastInsertIdReturningSuffix == "" || primaryField == nil {
|
||||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||||
// set rows affected count
|
// set rows affected count
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
19
dialect.go
19
dialect.go
|
@ -5,21 +5,30 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Dialect interface contains behaviors that differ across SQL database
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
BinVar(i int) string
|
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||||
|
BindVar(i int) string
|
||||||
|
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||||
Quote(key string) string
|
Quote(key string) string
|
||||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
// DataTypeOf return data's sql type
|
||||||
|
DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
|
||||||
|
|
||||||
|
// HasIndex check has index or not
|
||||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||||
|
// RemoveIndex remove index
|
||||||
RemoveIndex(scope *Scope, indexName string)
|
RemoveIndex(scope *Scope, indexName string)
|
||||||
|
// HasTable check has table or not
|
||||||
HasTable(scope *Scope, tableName string) bool
|
HasTable(scope *Scope, tableName string) bool
|
||||||
|
// HasColumn check has column or not
|
||||||
HasColumn(scope *Scope, tableName string, columnName string) bool
|
HasColumn(scope *Scope, tableName string, columnName string) bool
|
||||||
CurrentDatabase(scope *Scope) string
|
|
||||||
|
|
||||||
ReturningStr(tableName, key string) string
|
// LimitAndOffsetSQL return generate SQL with limit and offset, as mssql has special case
|
||||||
LimitAndOffsetSQL(limit, offset int) string
|
LimitAndOffsetSQL(limit, offset int) string
|
||||||
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
SelectFromDummyTable() string
|
SelectFromDummyTable() string
|
||||||
SupportLastInsertId() bool
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
|
LastInsertIdReturningSuffix(tableName, columnName string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDialect(driver string) Dialect {
|
func NewDialect(driver string) Dialect {
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
|
|
||||||
type commonDialect struct{}
|
type commonDialect struct{}
|
||||||
|
|
||||||
func (commonDialect) BinVar(i int) string {
|
func (commonDialect) BindVar(i int) string {
|
||||||
return "$$" // ?
|
return "$$" // ?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ func (commonDialect) Quote(key string) string {
|
||||||
return fmt.Sprintf(`"%s"`, key)
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "BOOLEAN"
|
return "BOOLEAN"
|
||||||
|
@ -55,7 +55,7 @@ func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) st
|
||||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
var (
|
var (
|
||||||
count int
|
count int
|
||||||
databaseName = c.CurrentDatabase(scope)
|
databaseName = c.currentDatabase(scope)
|
||||||
)
|
)
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
|
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
|
||||||
return count > 0
|
return count > 0
|
||||||
|
@ -68,7 +68,7 @@ func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
||||||
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
||||||
var (
|
var (
|
||||||
count int
|
count int
|
||||||
databaseName = c.CurrentDatabase(scope)
|
databaseName = c.currentDatabase(scope)
|
||||||
)
|
)
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
|
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
|
||||||
return count > 0
|
return count > 0
|
||||||
|
@ -77,7 +77,7 @@ func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
||||||
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var (
|
var (
|
||||||
count int
|
count int
|
||||||
databaseName = c.CurrentDatabase(scope)
|
databaseName = c.currentDatabase(scope)
|
||||||
)
|
)
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
||||||
return count > 0
|
return count > 0
|
||||||
|
@ -95,15 +95,11 @@ func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string,
|
||||||
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
|
func (commonDialect) currentDatabase(scope *Scope) (name string) {
|
||||||
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
|
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) ReturningStr(tableName, key string) string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
|
||||||
if limit >= 0 {
|
if limit >= 0 {
|
||||||
sql += fmt.Sprintf(" LIMIT %d", limit)
|
sql += fmt.Sprintf(" LIMIT %d", limit)
|
||||||
|
@ -118,6 +114,6 @@ func (commonDialect) SelectFromDummyTable() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) SupportLastInsertId() bool {
|
func (commonDialect) LastInsertIdReturningSuffix(tableName, columnName string) string {
|
||||||
return true
|
return ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ type mssql struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "bit"
|
return "bit"
|
||||||
|
@ -55,7 +55,7 @@ func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
||||||
var (
|
var (
|
||||||
count int
|
count int
|
||||||
databaseName = s.CurrentDatabase(scope)
|
databaseName = s.currentDatabase(scope)
|
||||||
)
|
)
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
|
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
|
||||||
return count > 0
|
return count > 0
|
||||||
|
@ -64,13 +64,13 @@ func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
||||||
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var (
|
var (
|
||||||
count int
|
count int
|
||||||
databaseName = s.CurrentDatabase(scope)
|
databaseName = s.currentDatabase(scope)
|
||||||
)
|
)
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) CurrentDatabase(scope *Scope) (name string) {
|
func (s mssql) currentDatabase(scope *Scope) (name string) {
|
||||||
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
|
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ func (mysql) Quote(key string) string {
|
||||||
return fmt.Sprintf("`%s`", key)
|
return fmt.Sprintf("`%s`", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
|
@ -60,7 +60,7 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
|
func (s mysql) currentDatabase(scope *Scope) (name string) {
|
||||||
s.RawScanString(scope, &name, "SELECT DATABASE()")
|
s.RawScanString(scope, &name, "SELECT DATABASE()")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,11 +15,11 @@ type postgres struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
func (postgres) BinVar(i int) string {
|
func (postgres) BindVar(i int) string {
|
||||||
return fmt.Sprintf("$%v", i)
|
return fmt.Sprintf("$%v", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
|
@ -80,12 +80,12 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
|
func (s postgres) currentDatabase(scope *Scope) (name string) {
|
||||||
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
|
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) ReturningStr(tableName, key string) string {
|
func (s postgres) LastInsertIdReturningSuffix(tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ type sqlite3 struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "bool"
|
return "bool"
|
||||||
|
@ -65,7 +65,7 @@ func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bo
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
|
func (sqlite3) currentDatabase(scope *Scope) (name string) {
|
||||||
var (
|
var (
|
||||||
ifaces = make([]interface{}, 3)
|
ifaces = make([]interface{}, 3)
|
||||||
pointers = make([]*string, 3)
|
pointers = make([]*string, 3)
|
||||||
|
|
8
main.go
8
main.go
|
@ -453,14 +453,6 @@ func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) CurrentDatabase() string {
|
|
||||||
var (
|
|
||||||
scope = s.clone().NewScope(s.Value)
|
|
||||||
name = s.dialect.CurrentDatabase(scope)
|
|
||||||
)
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddForeignKey Add foreign key to the given scope
|
// AddForeignKey Add foreign key to the given scope
|
||||||
// Example: db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
// Example: db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||||
|
|
|
@ -555,7 +555,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||||
autoIncrease = false
|
autoIncrease = false
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
|
sqlType = scope.Dialect().DataTypeOf(reflectValue, size, autoIncrease)
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.TrimSpace(additionalType) == "" {
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
|
|
@ -621,14 +621,3 @@ func TestSelectWithArrayInput(t *testing.T) {
|
||||||
t.Errorf("Should have selected both age and name")
|
t.Errorf("Should have selected both age and name")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCurrentDatabase(t *testing.T) {
|
|
||||||
databaseName := DB.CurrentDatabase()
|
|
||||||
if err := DB.Error; err != nil {
|
|
||||||
t.Errorf("Problem getting current db name: %s", err)
|
|
||||||
}
|
|
||||||
if databaseName == "" {
|
|
||||||
t.Errorf("Current db name returned empty; this should never happen!")
|
|
||||||
}
|
|
||||||
t.Logf("Got current db name: %v", databaseName)
|
|
||||||
}
|
|
||||||
|
|
2
scope.go
2
scope.go
|
@ -229,7 +229,7 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.SqlVars = append(scope.SqlVars, value)
|
scope.SqlVars = append(scope.SqlVars, value)
|
||||||
return scope.Dialect().BinVar(len(scope.SqlVars))
|
return scope.Dialect().BindVar(len(scope.SqlVars))
|
||||||
}
|
}
|
||||||
|
|
||||||
type tabler interface {
|
type tabler interface {
|
||||||
|
|
|
@ -518,7 +518,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
primaryKeySqlType := field.TagSettings["TYPE"]
|
||||||
if primaryKeySqlType == "" {
|
if primaryKeySqlType == "" {
|
||||||
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
|
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
|
||||||
}
|
}
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||||
|
@ -530,7 +530,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
primaryKeySqlType := field.TagSettings["TYPE"]
|
||||||
if primaryKeySqlType == "" {
|
if primaryKeySqlType == "" {
|
||||||
primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false)
|
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
|
||||||
}
|
}
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||||
|
|
Loading…
Reference in New Issue