diff --git a/callback_create.go b/callback_create.go index b8725363..0ba3feac 100644 --- a/callback_create.go +++ b/callback_create.go @@ -71,29 +71,32 @@ func createCallback(scope *Scope) { } } - returningKey := "*" - primaryField := scope.PrimaryField() + var ( + returningColumn = "*" + quotedTableName = scope.QuotedTableName() + primaryField = scope.PrimaryField() + ) + if primaryField != nil { - returningKey = scope.Quote(primaryField.DBName) + returningColumn = scope.Quote(primaryField.DBName) } + lastInsertIdReturningSuffix := scope.Dialect().LastInsertIdReturningSuffix(quotedTableName, returningColumn) + if len(columns) == 0 { - scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", - scope.QuotedTableName(), - scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), - )) + scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", quotedTableName, lastInsertIdReturningSuffix)) } else { scope.Raw(fmt.Sprintf( "INSERT INTO %v (%v) VALUES (%v) %v", scope.QuotedTableName(), strings.Join(columns, ","), strings.Join(placeholders, ","), - scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), + lastInsertIdReturningSuffix, )) } // execute create sql - if scope.Dialect().SupportLastInsertId() || primaryField == nil { + if lastInsertIdReturningSuffix == "" || primaryField == nil { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() diff --git a/dialect.go b/dialect.go index 1fa61925..72b6b2aa 100644 --- a/dialect.go +++ b/dialect.go @@ -5,21 +5,30 @@ import ( "reflect" ) +// Dialect interface contains behaviors that differ across SQL database type Dialect interface { - BinVar(i int) string + // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 + BindVar(i int) string + // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name Quote(key string) string - SqlTag(value reflect.Value, size int, autoIncrease bool) string + // DataTypeOf return data's sql type + DataTypeOf(value reflect.Value, size int, autoIncrease bool) string + // HasIndex check has index or not HasIndex(scope *Scope, tableName string, indexName string) bool + // RemoveIndex remove index RemoveIndex(scope *Scope, indexName string) + // HasTable check has table or not HasTable(scope *Scope, tableName string) bool + // HasColumn check has column or not HasColumn(scope *Scope, tableName string, columnName string) bool - CurrentDatabase(scope *Scope) string - ReturningStr(tableName, key string) string + // LimitAndOffsetSQL return generate SQL with limit and offset, as mssql has special case LimitAndOffsetSQL(limit, offset int) string + // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string - SupportLastInsertId() bool + // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` + LastInsertIdReturningSuffix(tableName, columnName string) string } func NewDialect(driver string) Dialect { diff --git a/dialect_common.go b/dialect_common.go index ade7c068..efc8d642 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -8,7 +8,7 @@ import ( type commonDialect struct{} -func (commonDialect) BinVar(i int) string { +func (commonDialect) BindVar(i int) string { return "$$" // ? } @@ -16,7 +16,7 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } -func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "BOOLEAN" @@ -55,7 +55,7 @@ func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) st func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { var ( count int - databaseName = c.CurrentDatabase(scope) + databaseName = c.currentDatabase(scope) ) c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName) return count > 0 @@ -68,7 +68,7 @@ func (commonDialect) RemoveIndex(scope *Scope, indexName string) { func (c commonDialect) HasTable(scope *Scope, tableName string) bool { var ( count int - databaseName = c.CurrentDatabase(scope) + databaseName = c.currentDatabase(scope) ) c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName) return count > 0 @@ -77,7 +77,7 @@ func (c commonDialect) HasTable(scope *Scope, tableName string) bool { func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var ( count int - databaseName = c.CurrentDatabase(scope) + databaseName = c.currentDatabase(scope) ) c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) return count > 0 @@ -95,15 +95,11 @@ func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) } -func (commonDialect) CurrentDatabase(scope *Scope) (name string) { +func (commonDialect) currentDatabase(scope *Scope) (name string) { scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) return } -func (commonDialect) ReturningStr(tableName, key string) string { - return "" -} - func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) { if limit >= 0 { sql += fmt.Sprintf(" LIMIT %d", limit) @@ -118,6 +114,6 @@ func (commonDialect) SelectFromDummyTable() string { return "" } -func (commonDialect) SupportLastInsertId() bool { - return true +func (commonDialect) LastInsertIdReturningSuffix(tableName, columnName string) string { + return "" } diff --git a/dialect_mssql.go b/dialect_mssql.go index 82fba7d1..c3e21c97 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -10,7 +10,7 @@ type mssql struct { commonDialect } -func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bit" @@ -55,7 +55,7 @@ func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { func (s mssql) HasTable(scope *Scope, tableName string) bool { var ( count int - databaseName = s.CurrentDatabase(scope) + databaseName = s.currentDatabase(scope) ) s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName) return count > 0 @@ -64,13 +64,13 @@ func (s mssql) HasTable(scope *Scope, tableName string) bool { func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { var ( count int - databaseName = s.CurrentDatabase(scope) + databaseName = s.currentDatabase(scope) ) s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) return count > 0 } -func (s mssql) CurrentDatabase(scope *Scope) (name string) { +func (s mssql) currentDatabase(scope *Scope) (name string) { s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]") return } diff --git a/dialect_mysql.go b/dialect_mysql.go index b6f9a22b..e334c7a4 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -14,7 +14,7 @@ func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } -func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -60,7 +60,7 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s mysql) CurrentDatabase(scope *Scope) (name string) { +func (s mysql) currentDatabase(scope *Scope) (name string) { s.RawScanString(scope, &name, "SELECT DATABASE()") return } diff --git a/dialect_postgres.go b/dialect_postgres.go index 0b16816c..c4742aec 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -15,11 +15,11 @@ type postgres struct { commonDialect } -func (postgres) BinVar(i int) string { +func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } -func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -80,12 +80,12 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b return count > 0 } -func (s postgres) CurrentDatabase(scope *Scope) (name string) { +func (s postgres) currentDatabase(scope *Scope) (name string) { s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()") return } -func (s postgres) ReturningStr(tableName, key string) string { +func (s postgres) LastInsertIdReturningSuffix(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 82546dbb..e1e35bf7 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -10,7 +10,7 @@ type sqlite3 struct { commonDialect } -func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bool" @@ -65,7 +65,7 @@ func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bo return count > 0 } -func (sqlite3) CurrentDatabase(scope *Scope) (name string) { +func (sqlite3) currentDatabase(scope *Scope) (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) diff --git a/main.go b/main.go index 461329fa..10c1b9be 100644 --- a/main.go +++ b/main.go @@ -453,14 +453,6 @@ func (s *DB) RemoveIndex(indexName string) *DB { return scope.db } -func (s *DB) CurrentDatabase() string { - var ( - scope = s.clone().NewScope(s.Value) - name = s.dialect.CurrentDatabase(scope) - ) - return name -} - // AddForeignKey Add foreign key to the given scope // Example: db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { diff --git a/model_struct.go b/model_struct.go index b47f8534..c81dcd88 100644 --- a/model_struct.go +++ b/model_struct.go @@ -555,7 +555,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string { autoIncrease = false } - sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease) + sqlType = scope.Dialect().DataTypeOf(reflectValue, size, autoIncrease) } if strings.TrimSpace(additionalType) == "" { diff --git a/query_test.go b/query_test.go index a7d5bc0e..b762dee5 100644 --- a/query_test.go +++ b/query_test.go @@ -621,14 +621,3 @@ func TestSelectWithArrayInput(t *testing.T) { t.Errorf("Should have selected both age and name") } } - -func TestCurrentDatabase(t *testing.T) { - databaseName := DB.CurrentDatabase() - if err := DB.Error; err != nil { - t.Errorf("Problem getting current db name: %s", err) - } - if databaseName == "" { - t.Errorf("Current db name returned empty; this should never happen!") - } - t.Logf("Got current db name: %v", databaseName) -} diff --git a/scope.go b/scope.go index 8ee4bdd5..6d9303ec 100644 --- a/scope.go +++ b/scope.go @@ -229,7 +229,7 @@ func (scope *Scope) AddToVars(value interface{}) string { } scope.SqlVars = append(scope.SqlVars, value) - return scope.Dialect().BinVar(len(scope.SqlVars)) + return scope.Dialect().BindVar(len(scope.SqlVars)) } type tabler interface { diff --git a/scope_private.go b/scope_private.go index dc1676e8..d5d384af 100644 --- a/scope_private.go +++ b/scope_private.go @@ -518,7 +518,7 @@ func (scope *Scope) createJoinTable(field *StructField) { value := reflect.Indirect(reflect.New(field.Struct.Type)) primaryKeySqlType := field.TagSettings["TYPE"] if primaryKeySqlType == "" { - primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false) + primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false) } sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) @@ -530,7 +530,7 @@ func (scope *Scope) createJoinTable(field *StructField) { value := reflect.Indirect(reflect.New(field.Struct.Type)) primaryKeySqlType := field.TagSettings["TYPE"] if primaryKeySqlType == "" { - primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false) + primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false) } sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))