Add sqlite migration tests

This commit is contained in:
Jinzhu 2020-02-22 20:57:29 +08:00
parent 215f5e7765
commit 6d58b62fd4
16 changed files with 117 additions and 41 deletions

View File

@ -8,10 +8,13 @@ import (
)
func Query(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{})
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err)
fmt.Println(result)

View File

@ -1,11 +1,14 @@
package callbacks
import "github.com/jinzhu/gorm"
import (
"github.com/jinzhu/gorm"
)
func RawExec(db *gorm.DB) {
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.RowsAffected, _ = result.RowsAffected()
if err != nil {
db.AddError(err)
} else {
db.RowsAffected, _ = result.RowsAffected()
}
}

View File

@ -6,10 +6,12 @@ import (
)
func RowQuery(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{})
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}
if _, ok := db.Get("rows"); ok {
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)

View File

@ -222,8 +222,7 @@ func (db *DB) Unscoped() (tx *DB) {
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
stmt := tx.Statement
stmt.SQL = strings.Builder{}
clause.Expr{SQL: sql, Vars: values}.Build(stmt)
tx.Statement.SQL = strings.Builder{}
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
return
}

View File

@ -1,6 +1,8 @@
package clause
import "strings"
import (
"strings"
)
// Expression expression interface
type Expression interface {
@ -22,7 +24,7 @@ type Expr struct {
func (expr Expr) Build(builder Builder) {
sql := expr.SQL
for _, v := range expr.Vars {
sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1)
sql = strings.Replace(sql, "?", builder.AddVar(v), 1)
}
builder.Write(sql)
}

35
clause/expression_test.go Normal file
View File

@ -0,0 +1,35 @@
package clause_test
import (
"fmt"
"sync"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/tests"
)
func TestExpr(t *testing.T) {
results := []struct {
SQL string
Result string
Vars []interface{}
}{{
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
Result: "create table `users` (`id` int, `name` text)",
}}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
})
}
}

View File

@ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
}
return m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)",
stmt.Table, `%"`+name+`" %`, `%`+name+` %`,
"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
).Row().Scan(&count)
})
return count > 0

View File

@ -28,8 +28,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
@ -44,20 +45,20 @@ func (dialector Dialector) QuoteChars() [2]byte {
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "NUMERIC"
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement {
// https://www.sqlite.org/autoinc.html
return "INTEGER PRIMARY KEY AUTOINCREMENT"
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
return "INTEGER"
return "integer"
}
case schema.Float:
return "REAL"
return "real"
case schema.String, schema.Time:
return "TEXT"
return "text"
case schema.Bytes:
return "BLOB"
return "blob"
}
return ""

View File

@ -2,6 +2,7 @@ package gorm
import (
"database/sql"
"strings"
"github.com/jinzhu/gorm/clause"
)
@ -166,6 +167,8 @@ func (db *DB) Rollback() (tx *DB) {
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
tx.callbacks.Raw().Execute(tx)
return
}

1
go.mod
View File

@ -5,4 +5,5 @@ go 1.13
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.1
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
)

View File

@ -18,7 +18,8 @@ type Migrator struct {
// Config schema config
type Config struct {
DB *gorm.DB
CreateIndexAfterCreateTable bool
DB *gorm.DB
gorm.Dialector
}
@ -80,9 +81,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
// create join table
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
if rel.JoinTable != nil {
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
}
}
}
return nil
@ -140,8 +143,12 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
for _, idx := range stmt.Schema.ParseIndexes() {
createTableSQL += "INDEX ? ?,"
values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
if m.CreateIndexAfterCreateTable {
m.DB.Migrator().CreateIndex(value, idx.Name)
} else {
createTableSQL += "INDEX ? ?,"
values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
@ -152,9 +159,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
// create join table
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
if rel.JoinTable != nil {
joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
}
}
}
@ -302,7 +311,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
@ -326,14 +335,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
err := fmt.Errorf("failed to create constraint with name %v", name)
if field := stmt.Schema.LookUpField(name); field != nil {
for _, cc := range checkConstraints {
if err = m.CreateIndex(value, cc.Name); err != nil {
if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil {
return err
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
if err = m.CreateIndex(value, constraint.Name); err != nil {
if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil {
return err
}
}

View File

@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string {
// RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table)
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name))
}
// CheckerName generate checker name

View File

@ -339,7 +339,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
}
}
if constraint.ReferenceSchema == nil {
if rel.JoinTable != nil || constraint.ReferenceSchema == nil {
return nil
}

View File

@ -152,8 +152,11 @@ func (stmt *Statement) AddVar(vars ...interface{}) string {
stmt.Vars = append(stmt.Vars, v.Value)
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
}
case clause.Column:
case clause.Column, clause.Table:
placeholders.WriteString(stmt.Quote(v))
case clause.Expr:
placeholders.WriteString(v.SQL)
stmt.Vars = append(stmt.Vars, v.Vars...)
case []interface{}:
if len(v) > 0 {
placeholders.WriteByte('(')

View File

@ -2,6 +2,7 @@ package tests
import (
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
)
type DummyDialector struct {
@ -11,7 +12,7 @@ func (DummyDialector) Initialize(*gorm.DB) error {
return nil
}
func (DummyDialector) Migrator() gorm.Migrator {
func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator {
return nil
}
@ -22,3 +23,7 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
func (DummyDialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}
func (DummyDialector) DataTypeOf(*schema.Field) string {
return ""
}

View File

@ -9,11 +9,21 @@ import (
func TestMigrate(t *testing.T, db *gorm.DB) {
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}}
db.AutoMigrate(allModels...)
for _, m := range allModels {
if db.Migrator().HasTable(m) {
if err := db.Migrator().DropTable(m); err != nil {
t.Errorf("Failed to drop table, got error %v", err)
}
}
}
if err := db.AutoMigrate(allModels...); err != nil {
t.Errorf("Failed to auto migrate, but got error %v", err)
}
for _, m := range allModels {
if !db.Migrator().HasTable(m) {
t.Errorf("Failed to create table for %+v", m)
t.Errorf("Failed to create table for %#v", m)
}
}
}