forked from mirror/gorm
Add sqlite migration tests
This commit is contained in:
parent
215f5e7765
commit
6d58b62fd4
|
@ -8,10 +8,13 @@ import (
|
|||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Select{})
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Select{})
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
|
||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
}
|
||||
|
||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
fmt.Println(err)
|
||||
fmt.Println(result)
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
package callbacks
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
} else {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,10 +6,12 @@ import (
|
|||
)
|
||||
|
||||
func RowQuery(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Select{})
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Select{})
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
|
||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
}
|
||||
|
||||
if _, ok := db.Get("rows"); ok {
|
||||
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
|
|
@ -222,8 +222,7 @@ func (db *DB) Unscoped() (tx *DB) {
|
|||
|
||||
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
stmt := tx.Statement
|
||||
stmt.SQL = strings.Builder{}
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(stmt)
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package clause
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Expression expression interface
|
||||
type Expression interface {
|
||||
|
@ -22,7 +24,7 @@ type Expr struct {
|
|||
func (expr Expr) Build(builder Builder) {
|
||||
sql := expr.SQL
|
||||
for _, v := range expr.Vars {
|
||||
sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1)
|
||||
sql = strings.Replace(sql, "?", builder.AddVar(v), 1)
|
||||
}
|
||||
builder.Write(sql)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestExpr(t *testing.T) {
|
||||
results := []struct {
|
||||
SQL string
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{{
|
||||
SQL: "create table ? (? ?, ? ?)",
|
||||
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
|
||||
Result: "create table `users` (`id` int, `name` text)",
|
||||
}}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
|
||||
if stmt.SQL.String() != result.Result {
|
||||
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
|||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)",
|
||||
stmt.Table, `%"`+name+`" %`, `%`+name+` %`,
|
||||
"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
|
|
|
@ -28,8 +28,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
|||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
|
@ -44,20 +45,20 @@ func (dialector Dialector) QuoteChars() [2]byte {
|
|||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "NUMERIC"
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement {
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
return "INTEGER"
|
||||
return "integer"
|
||||
}
|
||||
case schema.Float:
|
||||
return "REAL"
|
||||
return "real"
|
||||
case schema.String, schema.Time:
|
||||
return "TEXT"
|
||||
return "text"
|
||||
case schema.Bytes:
|
||||
return "BLOB"
|
||||
return "blob"
|
||||
}
|
||||
|
||||
return ""
|
||||
|
|
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
)
|
||||
|
@ -166,6 +167,8 @@ func (db *DB) Rollback() (tx *DB) {
|
|||
|
||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
tx.callbacks.Raw().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
|
1
go.mod
1
go.mod
|
@ -5,4 +5,5 @@ go 1.13
|
|||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.1.1
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
)
|
||||
|
|
|
@ -18,7 +18,8 @@ type Migrator struct {
|
|||
|
||||
// Config schema config
|
||||
type Config struct {
|
||||
DB *gorm.DB
|
||||
CreateIndexAfterCreateTable bool
|
||||
DB *gorm.DB
|
||||
gorm.Dialector
|
||||
}
|
||||
|
||||
|
@ -80,9 +81,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||
}
|
||||
|
||||
// create join table
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
if !m.DB.Migrator().HasTable(joinValue) {
|
||||
defer m.DB.Migrator().CreateTable(joinValue)
|
||||
if rel.JoinTable != nil {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
if !m.DB.Migrator().HasTable(joinValue) {
|
||||
defer m.DB.Migrator().CreateTable(joinValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -140,8 +143,12 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||
}
|
||||
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
createTableSQL += "INDEX ? ?,"
|
||||
values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||
if m.CreateIndexAfterCreateTable {
|
||||
m.DB.Migrator().CreateIndex(value, idx.Name)
|
||||
} else {
|
||||
createTableSQL += "INDEX ? ?,"
|
||||
values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
|
@ -152,9 +159,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||
}
|
||||
|
||||
// create join table
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
if !m.DB.Migrator().HasTable(joinValue) {
|
||||
defer m.DB.Migrator().CreateTable(joinValue)
|
||||
if rel.JoinTable != nil {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
if !m.DB.Migrator().HasTable(joinValue) {
|
||||
defer m.DB.Migrator().CreateTable(joinValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -302,7 +311,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
|
|||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -326,14 +335,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
|||
err := fmt.Errorf("failed to create constraint with name %v", name)
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
for _, cc := range checkConstraints {
|
||||
if err = m.CreateIndex(value, cc.Name); err != nil {
|
||||
if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
|
||||
if err = m.CreateIndex(value, constraint.Name); err != nil {
|
||||
if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string {
|
|||
|
||||
// RelationshipFKName generate fk name for relation
|
||||
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
|
||||
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table)
|
||||
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name))
|
||||
}
|
||||
|
||||
// CheckerName generate checker name
|
||||
|
|
|
@ -339,7 +339,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
|||
}
|
||||
}
|
||||
|
||||
if constraint.ReferenceSchema == nil {
|
||||
if rel.JoinTable != nil || constraint.ReferenceSchema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -152,8 +152,11 @@ func (stmt *Statement) AddVar(vars ...interface{}) string {
|
|||
stmt.Vars = append(stmt.Vars, v.Value)
|
||||
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
|
||||
}
|
||||
case clause.Column:
|
||||
case clause.Column, clause.Table:
|
||||
placeholders.WriteString(stmt.Quote(v))
|
||||
case clause.Expr:
|
||||
placeholders.WriteString(v.SQL)
|
||||
stmt.Vars = append(stmt.Vars, v.Vars...)
|
||||
case []interface{}:
|
||||
if len(v) > 0 {
|
||||
placeholders.WriteByte('(')
|
||||
|
|
|
@ -2,6 +2,7 @@ package tests
|
|||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
type DummyDialector struct {
|
||||
|
@ -11,7 +12,7 @@ func (DummyDialector) Initialize(*gorm.DB) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (DummyDialector) Migrator() gorm.Migrator {
|
||||
func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -22,3 +23,7 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
|
|||
func (DummyDialector) QuoteChars() [2]byte {
|
||||
return [2]byte{'`', '`'} // `name`
|
||||
}
|
||||
|
||||
func (DummyDialector) DataTypeOf(*schema.Field) string {
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -9,11 +9,21 @@ import (
|
|||
func TestMigrate(t *testing.T, db *gorm.DB) {
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}}
|
||||
|
||||
db.AutoMigrate(allModels...)
|
||||
for _, m := range allModels {
|
||||
if db.Migrator().HasTable(m) {
|
||||
if err := db.Migrator().DropTable(m); err != nil {
|
||||
t.Errorf("Failed to drop table, got error %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(allModels...); err != nil {
|
||||
t.Errorf("Failed to auto migrate, but got error %v", err)
|
||||
}
|
||||
|
||||
for _, m := range allModels {
|
||||
if !db.Migrator().HasTable(m) {
|
||||
t.Errorf("Failed to create table for %+v", m)
|
||||
t.Errorf("Failed to create table for %#v", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue