Refactor tests files

This commit is contained in:
Jinzhu 2020-06-02 10:34:50 +08:00
parent 5790ba9ef4
commit 8bb05a5a69
58 changed files with 184 additions and 1425 deletions

View File

@ -7,7 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/tests"
"gorm.io/gorm/utils/tests"
)
func BenchmarkSelect(b *testing.B) {

View File

@ -9,7 +9,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/tests"
"gorm.io/gorm/utils/tests"
)
var db, _ = gorm.Open(tests.DummyDialector{}, nil)

View File

@ -8,7 +8,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/tests"
"gorm.io/gorm/utils/tests"
)
func TestExpr(t *testing.T) {

View File

@ -1,225 +0,0 @@
package mssql
import (
"reflect"
"sort"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
func Create(db *gorm.DB) {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
setIdentityInsert := false
c := db.Statement.Clauses["ON CONFLICT"]
onConflict, hasConflict := c.Expression.(clause.OnConflict)
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
setIdentityInsert = false
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
_, isZero := field.ValueOf(db.Statement.ReflectValue)
setIdentityInsert = !isZero
case reflect.Slice:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
_, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i))
setIdentityInsert = !isZero
break
}
}
if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) {
setIdentityInsert = true
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString(" ON;")
} else {
setIdentityInsert = false
}
}
if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 {
MergeCreate(db, onConflict)
} else {
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
db.Statement.Build("INSERT")
db.Statement.WriteByte(' ')
db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement))
if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok {
if len(values.Columns) > 0 {
db.Statement.WriteByte('(')
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column)
}
db.Statement.WriteByte(')')
outputInserted(db)
db.Statement.WriteString(" VALUES ")
for idx, value := range values.Values {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteByte('(')
db.Statement.AddVar(db.Statement, value...)
db.Statement.WriteByte(')')
}
db.Statement.WriteString(";")
} else {
db.Statement.WriteString("DEFAULT VALUES;")
}
}
}
if setIdentityInsert {
db.Statement.WriteString("SET IDENTITY_INSERT ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString(" OFF;")
}
}
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
defer rows.Close()
if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
sortedKeys := []string{}
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
sortedKeys = append(sortedKeys, field.DBName)
}
sort.Strings(sortedKeys)
returnningFields := make([]*schema.Field, len(sortedKeys))
for idx, key := range sortedKeys {
returnningFields[idx] = db.Statement.Schema.LookUpField(key)
}
values := make([]interface{}, len(returnningFields))
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for rows.Next() {
for idx, field := range returnningFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
case reflect.Struct:
for idx, field := range returnningFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
}
}
} else {
db.AddError(err)
}
}
}
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
values := callbacks.ConvertToCreateValues(db.Statement)
db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
db.Statement.WriteString(" USING (VALUES")
for idx, value := range values.Values {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteByte('(')
db.Statement.AddVar(db.Statement, value...)
db.Statement.WriteByte(')')
}
db.Statement.WriteString(") AS source (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column.Name)
}
db.Statement.WriteString(") ON ")
var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "source", Name: field.DBName},
})
}
where.Build(db.Statement)
if len(onConflict.DoUpdates) > 0 {
db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}
db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(column.Name)
}
db.Statement.WriteString(") VALUES (")
for idx, column := range values.Columns {
if idx > 0 {
db.Statement.WriteByte(',')
}
db.Statement.WriteQuoted(clause.Column{
Table: "source",
Name: column.Name,
})
}
db.Statement.WriteString(")")
outputInserted(db)
db.Statement.WriteString(";")
}
func outputInserted(db *gorm.DB) {
if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
sortedKeys := []string{}
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
sortedKeys = append(sortedKeys, field.DBName)
}
sort.Strings(sortedKeys)
db.Statement.WriteString(" OUTPUT")
for idx, key := range sortedKeys {
if idx > 0 {
db.Statement.WriteString(",")
}
db.Statement.WriteString(" INSERTED.")
db.Statement.AddVar(db.Statement, clause.Column{Name: key})
}
}
}

View File

@ -1,142 +0,0 @@
package mssql
import (
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
)
type Migrator struct {
migrator.Migrator
}
func (m Migrator) HasTable(value interface{}) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?",
stmt.Table, m.CurrentDatabase(),
).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable string
if v, ok := oldName.(string); ok {
oldTable = v
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(oldName); err == nil {
oldTable = stmt.Table
} else {
return err
}
}
if v, ok := newName.(string); ok {
newTable = v
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(newName); err == nil {
newTable = stmt.Table
} else {
return err
}
}
return m.DB.Exec(
"sp_rename @objname = ?, @newname = ?;",
clause.Table{Name: oldTable}, clause.Table{Name: newTable},
).Error
}
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?",
currentDatabase, stmt.Table, name,
).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil {
oldName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName
}
return m.DB.Exec(
"sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';",
fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName},
).Error
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Raw(
"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)",
name, stmt.Table,
).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
"sp_rename @objname = ?, @newname = ?, @objtype = 'INDEX';",
fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName},
).Error
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw(
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`,
name, stmt.Table, m.CurrentDatabase(),
).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
return
}

View File

@ -1,127 +0,0 @@
package mssql
import (
"database/sql"
"fmt"
"regexp"
"strconv"
_ "github.com/denisenkom/go-mssqldb"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Dialector struct {
DSN string
}
func (dialector Dialector) Name() string {
return "mssql"
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
db.Callback().Create().Replace("gorm:create", Create)
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"LIMIT": func(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Offset > 0 {
builder.WriteString("OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
builder.WriteString("ROWS")
}
if limit.Limit > 0 {
if limit.Offset == 0 {
builder.WriteString(" OFFSET 0 ROWS")
}
builder.WriteString(" FETCH NEXT ")
builder.WriteString(strconv.Itoa(limit.Limit))
builder.WriteString(" ROWS ONLY")
}
}
},
}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteString("@p")
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('"')
writer.WriteString(str)
writer.WriteByte('"')
}
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "bit"
case schema.Int, schema.Uint:
var sqlType string
switch {
case field.Size < 16:
sqlType = "smallint"
case field.Size < 31:
sqlType = "int"
default:
sqlType = "bigint"
}
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
return sqlType + " IDENTITY(1,1)"
}
return sqlType
case schema.Float:
return "float"
case schema.String:
size := field.Size
if field.PrimaryKey && size == 0 {
size = 256
}
if size > 0 && size <= 4000 {
return fmt.Sprintf("nvarchar(%d)", size)
}
return "nvarchar(MAX)"
case schema.Time:
return "datetimeoffset"
case schema.Bytes:
return "varbinary(MAX)"
}
return ""
}

View File

@ -1,58 +0,0 @@
package mysql
import (
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
)
type Migrator struct {
migrator.Migrator
}
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec(
"ALTER TABLE ? MODIFY COLUMN ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
tx.Exec("SET FOREIGN_KEY_CHECKS = 1;")
return nil
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if chk.Name == name {
return m.DB.Exec(
"ALTER TABLE ? DROP CHECK ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: name},
).Error
}
}
return m.DB.Exec(
"ALTER TABLE ? DROP FOREIGN KEY ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: name},
).Error
})
}

View File

@ -1,169 +0,0 @@
package mysql
import (
"database/sql"
"fmt"
"math"
_ "github.com/go-sql-driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Dialector struct {
DSN string
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
return "mysql"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"ON CONFLICT": func(c clause.Clause, builder clause.Builder) {
if onConflict, ok := c.Expression.(clause.OnConflict); ok {
builder.WriteString("ON DUPLICATE KEY UPDATE ")
if len(onConflict.DoUpdates) == 0 {
if s := builder.(*gorm.Statement).Schema; s != nil {
var column clause.Column
onConflict.DoNothing = false
if s.PrioritizedPrimaryField != nil {
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
} else {
for _, field := range s.FieldsByDBName {
column = clause.Column{Name: field.DBName}
break
}
}
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
}
}
onConflict.DoUpdates.Build(builder)
} else {
c.Build(builder)
}
},
"VALUES": func(c clause.Clause, builder clause.Builder) {
if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
builder.WriteString("VALUES()")
return
}
c.Build(builder)
},
}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
writer.WriteString(str)
writer.WriteByte('`')
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "boolean"
case schema.Int, schema.Uint:
sqlType := "int"
switch {
case field.Size <= 8:
sqlType = "tinyint"
case field.Size <= 16:
sqlType = "smallint"
case field.Size <= 32:
sqlType = "int"
default:
sqlType = "bigint"
}
if field.DataType == schema.Uint {
sqlType += " unsigned"
}
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
sqlType += " AUTO_INCREMENT"
}
return sqlType
case schema.Float:
if field.Size <= 32 {
return "float"
}
return "double"
case schema.String:
size := field.Size
if size == 0 {
if field.PrimaryKey || field.HasDefaultValue {
size = 256
}
}
if size >= 65536 && size <= int(math.Pow(2, 24)) {
return "mediumtext"
} else if size > int(math.Pow(2, 24)) || size <= 0 {
return "longtext"
}
return fmt.Sprintf("varchar(%d)", size)
case schema.Time:
precision := ""
if field.Precision == 0 {
field.Precision = 3
}
if field.Precision > 0 {
precision = fmt.Sprintf("(%d)", field.Precision)
}
if field.NotNull || field.PrimaryKey {
return "datetime" + precision
}
return "datetime" + precision + " NULL"
case schema.Bytes:
if field.Size > 0 && field.Size < 65536 {
return fmt.Sprintf("varbinary(%d)", field.Size)
}
if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
return "mediumblob"
}
return "longblob"
}
return ""
}

View File

@ -1,139 +0,0 @@
package postgres
import (
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Migrator struct {
migrator.Migrator
}
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
}
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Raw(
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name,
).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
createIndexSQL += " ON ??"
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
}
return fmt.Errorf("failed to create index with name %v", name)
})
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
"ALTER INDEX ? RENAME TO ?",
clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
})
}
func (m Migrator) HasTable(value interface{}) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count)
})
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
return nil
}
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
name := field
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?",
stmt.Table, name,
).Row().Scan(&count)
})
return count > 0
}

View File

@ -1,102 +0,0 @@
package postgres
import (
"database/sql"
"fmt"
"regexp"
"strconv"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/lib/pq"
)
type Dialector struct {
DSN string
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
return "postgres"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
WithReturning: true,
})
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
return
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('$')
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('"')
writer.WriteString(str)
writer.WriteByte('"')
}
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "boolean"
case schema.Int, schema.Uint:
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
switch {
case field.Size < 16:
return "smallserial"
case field.Size < 31:
return "serial"
default:
return "bigserial"
}
} else {
switch {
case field.Size < 16:
return "smallint"
case field.Size < 31:
return "integer"
default:
return "bigint"
}
}
case schema.Float:
return "decimal"
case schema.String:
if field.Size > 0 {
return fmt.Sprintf("varchar(%d)", field.Size)
}
return "text"
case schema.Time:
return "timestamptz"
case schema.Bytes:
return "bytea"
}
return ""
}

View File

@ -1,211 +0,0 @@
package sqlite
import (
"fmt"
"regexp"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Migrator struct {
migrator.Migrator
}
func (m Migrator) HasTable(value interface{}) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) HasColumn(value interface{}, name string) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
return m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
var (
createSQL string
newTableName = stmt.Table + "__temp"
)
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createSQL = reg.ReplaceAllString(createSQL, "?")
var columns []string
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
for _, columnType := range columnTypes {
columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
}
createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table)
return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error
} else {
return err
}
} else {
return fmt.Errorf("failed to alter field with name %v", name)
}
})
}
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
var (
createSQL string
newTableName = stmt.Table + "__temp"
)
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createSQL = reg.ReplaceAllString(createSQL, "")
var columns []string
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
for _, columnType := range columnTypes {
if columnType.Name() != name {
columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
}
}
createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table)
return m.DB.Exec(createSQL).Error
} else {
return err
}
})
}
func (m Migrator) CreateConstraint(interface{}, string) error {
return gorm.ErrNotImplemented
}
func (m Migrator) DropConstraint(interface{}, string) error {
return gorm.ErrNotImplemented
}
func (m Migrator) CurrentDatabase() (name string) {
var null interface{}
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
}
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
createIndexSQL += " ON ??"
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
}
return fmt.Errorf("failed to create index with name %v", name)
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
).Row().Scan(&count)
return nil
})
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
})
}

View File

@ -1,80 +0,0 @@
package sqlite
import (
"database/sql"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/mattn/go-sqlite3"
)
type Dialector struct {
DSN string
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
return "sqlite"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
LastInsertIDReversed: true,
})
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
return
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
writer.WriteString(str)
writer.WriteByte('`')
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement {
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
return "integer"
}
case schema.Float:
return "real"
case schema.String:
return "text"
case schema.Time:
return "datetime"
case schema.Bytes:
return "blob"
}
return ""
}

6
go.mod
View File

@ -3,12 +3,6 @@ module gorm.io/gorm
go 1.14
require (
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
github.com/go-sql-driver/mysql v1.5.0
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.1.1
github.com/mattn/go-sqlite3 v2.0.1+incompatible
gorm.io/gorm v1.9.12
)

View File

@ -9,7 +9,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/tests"
"gorm.io/gorm/utils/tests"
)
func TestFieldValuerAndSetter(t *testing.T) {

View File

@ -5,7 +5,7 @@ import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/tests"
"gorm.io/gorm/utils/tests"
)
type User struct {

View File

@ -7,7 +7,7 @@ import (
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/tests"
"gorm.io/gorm/utils/tests"
)
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {

View File

@ -5,7 +5,7 @@ import (
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/tests"
"gorm.io/gorm/utils/tests"
)
func TestParseSchema(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestBelongsToAssociation(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestHasManyAssociation(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestHasOneAssociation(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestMany2ManyAssociation(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) {

View File

@ -4,7 +4,7 @@ import (
"fmt"
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestCount(t *testing.T) {

View File

@ -6,7 +6,7 @@ import (
"github.com/jinzhu/now"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestCreate(t *testing.T) {

View File

@ -3,8 +3,6 @@ package tests_test
import (
"testing"
"time"
. "gorm.io/gorm/tests"
)
func TestCustomizeColumn(t *testing.T) {

View File

@ -5,7 +5,7 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestDelete(t *testing.T) {

View File

@ -4,7 +4,6 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
)
func TestEmbeddedStruct(t *testing.T) {

14
tests/go.mod Normal file
View File

@ -0,0 +1,14 @@
module gorm.io/gorm/tests
go 1.14
require (
github.com/jinzhu/now v1.1.1
gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286
gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8
gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30
gorm.io/gorm v1.9.12
)
replace gorm.io/gorm => ../

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestGroupBy(t *testing.T) {

View File

@ -1,17 +1,13 @@
package tests
package tests_test
import (
"database/sql/driver"
"fmt"
"go/ast"
"reflect"
"sort"
"strconv"
"strings"
"testing"
"time"
"gorm.io/gorm/utils"
. "gorm.io/gorm/utils/tests"
)
type Config struct {
@ -73,101 +69,6 @@ func GetUser(name string, config Config) *User {
return &user
}
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
for _, name := range names {
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
t.Run(name, func(t *testing.T) {
AssertEqual(t, got, expect)
})
}
}
func AssertEqual(t *testing.T, got, expect interface{}) {
if !reflect.DeepEqual(got, expect) {
isEqual := func() {
if curTime, ok := got.(time.Time); ok {
format := "2006-01-02T15:04:05Z07:00"
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) {
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime)
}
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
}
}
if fmt.Sprint(got) == fmt.Sprint(expect) {
return
}
if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
}
if valuer, ok := got.(driver.Valuer); ok {
got, _ = valuer.Value()
}
if valuer, ok := expect.(driver.Valuer); ok {
expect, _ = valuer.Value()
}
if got != nil {
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
}
if expect != nil {
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
}
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
}
if reflect.ValueOf(got).Kind() == reflect.Slice {
if reflect.ValueOf(expect).Kind() == reflect.Slice {
if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
for i := 0; i < reflect.ValueOf(got).Len(); i++ {
name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i)
t.Run(name, func(t *testing.T) {
AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
})
}
} else {
name := reflect.ValueOf(got).Type().Elem().Name()
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
}
return
}
}
if reflect.ValueOf(got).Kind() == reflect.Struct {
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
field := reflect.ValueOf(got).Field(i)
t.Run(fieldStruct.Name, func(t *testing.T) {
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
})
}
}
return
}
}
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
isEqual()
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
isEqual()
}
}
}
func CheckPet(t *testing.T, pet Pet, expect Pet) {
if pet.ID != 0 {
var newPet Pet

View File

@ -6,7 +6,6 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
)
type Product struct {

View File

@ -5,7 +5,6 @@ import (
"time"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
)
type Person struct {

View File

@ -5,7 +5,7 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestJoins(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestMain(m *testing.M) {

View File

@ -7,7 +7,7 @@ import (
"time"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestMigrate(t *testing.T) {

View File

@ -4,8 +4,6 @@ import (
"reflect"
"sort"
"testing"
. "gorm.io/gorm/tests"
)
type Blog struct {
@ -36,8 +34,8 @@ func compareTags(tags []Tag, contents []string) bool {
}
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags")
@ -125,8 +123,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
}
func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags")
@ -246,8 +244,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
}
func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags")

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
type Hamster struct {

View File

@ -3,8 +3,6 @@ package tests_test
import (
"testing"
"time"
. "gorm.io/gorm/tests"
)
type Animal struct {

View File

@ -7,7 +7,6 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
)
func toJSONString(v interface{}) []byte {
@ -691,8 +690,8 @@ func TestNestedPreload12(t *testing.T) {
}
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
type (

View File

@ -6,7 +6,7 @@ import (
"testing"
"gorm.io/gorm/clause"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestNestedPreload(t *testing.T) {

View File

@ -9,7 +9,7 @@ import (
"time"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestFind(t *testing.T) {

View File

@ -6,7 +6,7 @@ import (
"strings"
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestScan(t *testing.T) {

View File

@ -11,7 +11,7 @@ import (
"time"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestScannerValuer(t *testing.T) {

View File

@ -4,7 +4,7 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func NameIn1And2(d *gorm.DB) *gorm.DB {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestSoftDelete(t *testing.T) {

View File

@ -4,7 +4,7 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestRow(t *testing.T) {

View File

@ -18,8 +18,13 @@ for dialect in "${dialects[@]}" ; do
if [ "$GORM_VERBOSE" = "" ]
then
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./...
cd tests
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./...
else
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./...
cd tests
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./...
fi
cd ..
fi
done

View File

@ -1,4 +1,4 @@
package tests
package tests_test
import (
"log"
@ -7,12 +7,13 @@ import (
"path/filepath"
"time"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/dialects/mssql"
"gorm.io/gorm/dialects/mysql"
"gorm.io/gorm/dialects/postgres"
"gorm.io/gorm/dialects/sqlite"
"gorm.io/gorm/logger"
. "gorm.io/gorm/utils/tests"
)
var DB *gorm.DB
@ -40,17 +41,17 @@ func OpenTestConnection() (db *gorm.DB, err error) {
dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
}
db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{})
case "mssql":
case "sqlserver":
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
// CREATE DATABASE gorm;
// USE gorm;
// CREATE USER gorm FROM LOGIN gorm;
// sp_changedbowner 'gorm';
log.Println("testing mssql...")
log.Println("testing sqlserver...")
if dbDSN == "" {
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
}
db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{})
db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{})
default:
log.Println("testing sqlite3...")
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
@ -90,8 +91,3 @@ func RunMigrations() {
}
}
}
func Now() *time.Time {
now := time.Now()
return &now
}

View File

@ -6,7 +6,7 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestTransaction(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestUpdateBelongsTo(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestUpdateHasManyAssociations(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestUpdateHasOne(t *testing.T) {

View File

@ -3,7 +3,7 @@ package tests_test
import (
"testing"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestUpdateMany2ManyAssociations(t *testing.T) {

View File

@ -8,7 +8,7 @@ import (
"time"
"gorm.io/gorm"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestUpdate(t *testing.T) {

View File

@ -5,7 +5,7 @@ import (
"time"
"gorm.io/gorm/clause"
. "gorm.io/gorm/tests"
. "gorm.io/gorm/utils/tests"
)
func TestUpsert(t *testing.T) {

112
utils/tests/utils.go Normal file
View File

@ -0,0 +1,112 @@
package tests
import (
"database/sql/driver"
"fmt"
"go/ast"
"reflect"
"testing"
"time"
"gorm.io/gorm/utils"
)
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
for _, name := range names {
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
t.Run(name, func(t *testing.T) {
AssertEqual(t, got, expect)
})
}
}
func AssertEqual(t *testing.T, got, expect interface{}) {
if !reflect.DeepEqual(got, expect) {
isEqual := func() {
if curTime, ok := got.(time.Time); ok {
format := "2006-01-02T15:04:05Z07:00"
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) {
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime)
}
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
}
}
if fmt.Sprint(got) == fmt.Sprint(expect) {
return
}
if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
}
if valuer, ok := got.(driver.Valuer); ok {
got, _ = valuer.Value()
}
if valuer, ok := expect.(driver.Valuer); ok {
expect, _ = valuer.Value()
}
if got != nil {
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
}
if expect != nil {
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
}
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
}
if reflect.ValueOf(got).Kind() == reflect.Slice {
if reflect.ValueOf(expect).Kind() == reflect.Slice {
if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
for i := 0; i < reflect.ValueOf(got).Len(); i++ {
name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i)
t.Run(name, func(t *testing.T) {
AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
})
}
} else {
name := reflect.ValueOf(got).Type().Elem().Name()
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
}
return
}
}
if reflect.ValueOf(got).Kind() == reflect.Struct {
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
field := reflect.ValueOf(got).Field(i)
t.Run(fieldStruct.Name, func(t *testing.T) {
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
})
}
}
return
}
}
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
isEqual()
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
isEqual()
}
}
}
func Now() *time.Time {
now := time.Now()
return &now
}