forked from mirror/gorm
Refactor tests files
This commit is contained in:
parent
5790ba9ef4
commit
8bb05a5a69
|
@ -7,7 +7,7 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/tests"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/tests"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
var db, _ = gorm.Open(tests.DummyDialector{}, nil)
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/tests"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestExpr(t *testing.T) {
|
||||
|
|
|
@ -1,225 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
setIdentityInsert := false
|
||||
c := db.Statement.Clauses["ON CONFLICT"]
|
||||
onConflict, hasConflict := c.Expression.(clause.OnConflict)
|
||||
|
||||
if field := db.Statement.Schema.PrioritizedPrimaryField; field != nil {
|
||||
setIdentityInsert = false
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
_, isZero := field.ValueOf(db.Statement.ReflectValue)
|
||||
setIdentityInsert = !isZero
|
||||
case reflect.Slice:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
_, isZero := field.ValueOf(db.Statement.ReflectValue.Index(i))
|
||||
setIdentityInsert = !isZero
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if setIdentityInsert && (field.DataType == schema.Int || field.DataType == schema.Uint) {
|
||||
setIdentityInsert = true
|
||||
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
db.Statement.WriteString(" ON;")
|
||||
} else {
|
||||
setIdentityInsert = false
|
||||
}
|
||||
}
|
||||
|
||||
if hasConflict && len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||
MergeCreate(db, onConflict)
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
|
||||
db.Statement.Build("INSERT")
|
||||
db.Statement.WriteByte(' ')
|
||||
|
||||
db.Statement.AddClause(callbacks.ConvertToCreateValues(db.Statement))
|
||||
if values, ok := db.Statement.Clauses["VALUES"].Expression.(clause.Values); ok {
|
||||
if len(values.Columns) > 0 {
|
||||
db.Statement.WriteByte('(')
|
||||
for idx, column := range values.Columns {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
db.Statement.WriteQuoted(column)
|
||||
}
|
||||
db.Statement.WriteByte(')')
|
||||
|
||||
outputInserted(db)
|
||||
|
||||
db.Statement.WriteString(" VALUES ")
|
||||
|
||||
for idx, value := range values.Values {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
db.Statement.WriteByte('(')
|
||||
db.Statement.AddVar(db.Statement, value...)
|
||||
db.Statement.WriteByte(')')
|
||||
}
|
||||
|
||||
db.Statement.WriteString(";")
|
||||
} else {
|
||||
db.Statement.WriteString("DEFAULT VALUES;")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if setIdentityInsert {
|
||||
db.Statement.WriteString("SET IDENTITY_INSERT ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
db.Statement.WriteString(" OFF;")
|
||||
}
|
||||
}
|
||||
|
||||
if !db.DryRun {
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
|
||||
if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||
sortedKeys := []string{}
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
sortedKeys = append(sortedKeys, field.DBName)
|
||||
}
|
||||
sort.Strings(sortedKeys)
|
||||
|
||||
returnningFields := make([]*schema.Field, len(sortedKeys))
|
||||
for idx, key := range sortedKeys {
|
||||
returnningFields[idx] = db.Statement.Schema.LookUpField(key)
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(returnningFields))
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for rows.Next() {
|
||||
for idx, field := range returnningFields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
case reflect.Struct:
|
||||
for idx, field := range returnningFields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||
values := callbacks.ConvertToCreateValues(db.Statement)
|
||||
|
||||
db.Statement.WriteString("MERGE INTO ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
db.Statement.WriteString(" USING (VALUES")
|
||||
for idx, value := range values.Values {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
db.Statement.WriteByte('(')
|
||||
db.Statement.AddVar(db.Statement, value...)
|
||||
db.Statement.WriteByte(')')
|
||||
}
|
||||
|
||||
db.Statement.WriteString(") AS source (")
|
||||
for idx, column := range values.Columns {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
db.Statement.WriteString(") ON ")
|
||||
|
||||
var where clause.Where
|
||||
for _, field := range db.Statement.Schema.PrimaryFields {
|
||||
where.Exprs = append(where.Exprs, clause.Eq{
|
||||
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
|
||||
Value: clause.Column{Table: "source", Name: field.DBName},
|
||||
})
|
||||
}
|
||||
where.Build(db.Statement)
|
||||
|
||||
if len(onConflict.DoUpdates) > 0 {
|
||||
db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
|
||||
onConflict.DoUpdates.Build(db.Statement)
|
||||
}
|
||||
|
||||
db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
|
||||
|
||||
for idx, column := range values.Columns {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
|
||||
db.Statement.WriteString(") VALUES (")
|
||||
|
||||
for idx, column := range values.Columns {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
db.Statement.WriteQuoted(clause.Column{
|
||||
Table: "source",
|
||||
Name: column.Name,
|
||||
})
|
||||
}
|
||||
|
||||
db.Statement.WriteString(")")
|
||||
outputInserted(db)
|
||||
db.Statement.WriteString(";")
|
||||
}
|
||||
|
||||
func outputInserted(db *gorm.DB) {
|
||||
if len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||
sortedKeys := []string{}
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
sortedKeys = append(sortedKeys, field.DBName)
|
||||
}
|
||||
sort.Strings(sortedKeys)
|
||||
|
||||
db.Statement.WriteString(" OUTPUT")
|
||||
for idx, key := range sortedKeys {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteString(",")
|
||||
}
|
||||
db.Statement.WriteString(" INSERTED.")
|
||||
db.Statement.AddVar(db.Statement, clause.Column{Name: key})
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,142 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?",
|
||||
stmt.Table, m.CurrentDatabase(),
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
||||
var oldTable, newTable string
|
||||
if v, ok := oldName.(string); ok {
|
||||
oldTable = v
|
||||
} else {
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if err := stmt.Parse(oldName); err == nil {
|
||||
oldTable = stmt.Table
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := newName.(string); ok {
|
||||
newTable = v
|
||||
} else {
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if err := stmt.Parse(newName); err == nil {
|
||||
newTable = stmt.Table
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"sp_rename @objname = ?, @newname = ?;",
|
||||
clause.Table{Name: oldTable}, clause.Table{Name: newTable},
|
||||
).Error
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
name := field
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?",
|
||||
currentDatabase, stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';",
|
||||
fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)",
|
||||
name, stmt.Table,
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
|
||||
return m.DB.Exec(
|
||||
"sp_rename @objname = ?, @newname = ?, @objtype = 'INDEX';",
|
||||
fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw(
|
||||
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`,
|
||||
name, stmt.Table, m.CurrentDatabase(),
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
|
||||
return
|
||||
}
|
|
@ -1,127 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "mssql"
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
|
||||
db.Callback().Create().Replace("gorm:create", Create)
|
||||
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"LIMIT": func(c clause.Clause, builder clause.Builder) {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
if limit.Offset > 0 {
|
||||
builder.WriteString("OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
builder.WriteString("ROWS")
|
||||
}
|
||||
|
||||
if limit.Limit > 0 {
|
||||
if limit.Offset == 0 {
|
||||
builder.WriteString(" OFFSET 0 ROWS")
|
||||
}
|
||||
builder.WriteString(" FETCH NEXT ")
|
||||
builder.WriteString(strconv.Itoa(limit.Limit))
|
||||
builder.WriteString(" ROWS ONLY")
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteString("@p")
|
||||
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteByte('"')
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('"')
|
||||
}
|
||||
|
||||
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "bit"
|
||||
case schema.Int, schema.Uint:
|
||||
var sqlType string
|
||||
switch {
|
||||
case field.Size < 16:
|
||||
sqlType = "smallint"
|
||||
case field.Size < 31:
|
||||
sqlType = "int"
|
||||
default:
|
||||
sqlType = "bigint"
|
||||
}
|
||||
|
||||
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
|
||||
return sqlType + " IDENTITY(1,1)"
|
||||
}
|
||||
return sqlType
|
||||
case schema.Float:
|
||||
return "float"
|
||||
case schema.String:
|
||||
size := field.Size
|
||||
if field.PrimaryKey && size == 0 {
|
||||
size = 256
|
||||
}
|
||||
if size > 0 && size <= 4000 {
|
||||
return fmt.Sprintf("nvarchar(%d)", size)
|
||||
}
|
||||
return "nvarchar(MAX)"
|
||||
case schema.Time:
|
||||
return "datetimeoffset"
|
||||
case schema.Bytes:
|
||||
return "varbinary(MAX)"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? MODIFY COLUMN ? ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
tx.Exec("SET FOREIGN_KEY_CHECKS = 1;")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
if chk.Name == name {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? DROP CHECK ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: name},
|
||||
).Error
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? DROP FOREIGN KEY ?",
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: name},
|
||||
).Error
|
||||
})
|
||||
}
|
|
@ -1,169 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
|
||||
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"ON CONFLICT": func(c clause.Clause, builder clause.Builder) {
|
||||
if onConflict, ok := c.Expression.(clause.OnConflict); ok {
|
||||
builder.WriteString("ON DUPLICATE KEY UPDATE ")
|
||||
if len(onConflict.DoUpdates) == 0 {
|
||||
if s := builder.(*gorm.Statement).Schema; s != nil {
|
||||
var column clause.Column
|
||||
onConflict.DoNothing = false
|
||||
|
||||
if s.PrioritizedPrimaryField != nil {
|
||||
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
|
||||
} else {
|
||||
for _, field := range s.FieldsByDBName {
|
||||
column = clause.Column{Name: field.DBName}
|
||||
break
|
||||
}
|
||||
}
|
||||
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
|
||||
}
|
||||
}
|
||||
|
||||
onConflict.DoUpdates.Build(builder)
|
||||
} else {
|
||||
c.Build(builder)
|
||||
}
|
||||
},
|
||||
"VALUES": func(c clause.Clause, builder clause.Builder) {
|
||||
if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
|
||||
builder.WriteString("VALUES()")
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteByte('`')
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "boolean"
|
||||
case schema.Int, schema.Uint:
|
||||
sqlType := "int"
|
||||
switch {
|
||||
case field.Size <= 8:
|
||||
sqlType = "tinyint"
|
||||
case field.Size <= 16:
|
||||
sqlType = "smallint"
|
||||
case field.Size <= 32:
|
||||
sqlType = "int"
|
||||
default:
|
||||
sqlType = "bigint"
|
||||
}
|
||||
|
||||
if field.DataType == schema.Uint {
|
||||
sqlType += " unsigned"
|
||||
}
|
||||
|
||||
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
|
||||
sqlType += " AUTO_INCREMENT"
|
||||
}
|
||||
return sqlType
|
||||
case schema.Float:
|
||||
if field.Size <= 32 {
|
||||
return "float"
|
||||
}
|
||||
return "double"
|
||||
case schema.String:
|
||||
size := field.Size
|
||||
if size == 0 {
|
||||
if field.PrimaryKey || field.HasDefaultValue {
|
||||
size = 256
|
||||
}
|
||||
}
|
||||
|
||||
if size >= 65536 && size <= int(math.Pow(2, 24)) {
|
||||
return "mediumtext"
|
||||
} else if size > int(math.Pow(2, 24)) || size <= 0 {
|
||||
return "longtext"
|
||||
}
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
case schema.Time:
|
||||
precision := ""
|
||||
if field.Precision == 0 {
|
||||
field.Precision = 3
|
||||
}
|
||||
|
||||
if field.Precision > 0 {
|
||||
precision = fmt.Sprintf("(%d)", field.Precision)
|
||||
}
|
||||
|
||||
if field.NotNull || field.PrimaryKey {
|
||||
return "datetime" + precision
|
||||
}
|
||||
return "datetime" + precision + " NULL"
|
||||
case schema.Bytes:
|
||||
if field.Size > 0 && field.Size < 65536 {
|
||||
return fmt.Sprintf("varbinary(%d)", field.Size)
|
||||
}
|
||||
|
||||
if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
|
||||
return "mediumblob"
|
||||
}
|
||||
|
||||
return "longblob"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
|
@ -1,139 +0,0 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
createIndexSQL += " ON ??"
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Exec(
|
||||
"ALTER INDEX ? RENAME TO ?",
|
||||
clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
name := field
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?",
|
||||
stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
|
@ -1,102 +0,0 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
WithReturning: true,
|
||||
})
|
||||
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('$')
|
||||
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteByte('"')
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('"')
|
||||
}
|
||||
|
||||
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "boolean"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
|
||||
switch {
|
||||
case field.Size < 16:
|
||||
return "smallserial"
|
||||
case field.Size < 31:
|
||||
return "serial"
|
||||
default:
|
||||
return "bigserial"
|
||||
}
|
||||
} else {
|
||||
switch {
|
||||
case field.Size < 16:
|
||||
return "smallint"
|
||||
case field.Size < 31:
|
||||
return "integer"
|
||||
default:
|
||||
return "bigint"
|
||||
}
|
||||
}
|
||||
case schema.Float:
|
||||
return "decimal"
|
||||
case schema.String:
|
||||
if field.Size > 0 {
|
||||
return fmt.Sprintf("varchar(%d)", field.Size)
|
||||
}
|
||||
return "text"
|
||||
case schema.Time:
|
||||
return "timestamptz"
|
||||
case schema.Bytes:
|
||||
return "bytea"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
|
@ -1,211 +0,0 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
var (
|
||||
createSQL string
|
||||
newTableName = stmt.Table + "__temp"
|
||||
)
|
||||
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
|
||||
|
||||
if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
|
||||
tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
createSQL = reg.ReplaceAllString(createSQL, "?")
|
||||
|
||||
var columns []string
|
||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||
for _, columnType := range columnTypes {
|
||||
columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
|
||||
}
|
||||
|
||||
createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table)
|
||||
return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to alter field with name %v", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
var (
|
||||
createSQL string
|
||||
newTableName = stmt.Table + "__temp"
|
||||
)
|
||||
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL)
|
||||
|
||||
if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil {
|
||||
tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
createSQL = reg.ReplaceAllString(createSQL, "")
|
||||
|
||||
var columns []string
|
||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||
for _, columnType := range columnTypes {
|
||||
if columnType.Name() != name {
|
||||
columns = append(columns, fmt.Sprintf("`%v`", columnType.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table)
|
||||
|
||||
return m.DB.Exec(createSQL).Error
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) CreateConstraint(interface{}, string) error {
|
||||
return gorm.ErrNotImplemented
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(interface{}, string) error {
|
||||
return gorm.ErrNotImplemented
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
var null interface{}
|
||||
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
createIndexSQL += " ON ??"
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var sql string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
|
||||
if sql != "" {
|
||||
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
|
||||
}
|
||||
return fmt.Errorf("failed to find index with name %v", oldName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
|
||||
})
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteByte('`')
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement {
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
return "integer"
|
||||
}
|
||||
case schema.Float:
|
||||
return "real"
|
||||
case schema.String:
|
||||
return "text"
|
||||
case schema.Time:
|
||||
return "datetime"
|
||||
case schema.Bytes:
|
||||
return "blob"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
6
go.mod
6
go.mod
|
@ -3,12 +3,6 @@ module gorm.io/gorm
|
|||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
|
||||
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
|
||||
github.com/go-sql-driver/mysql v1.5.0
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.1.1
|
||||
github.com/lib/pq v1.1.1
|
||||
github.com/mattn/go-sqlite3 v2.0.1+incompatible
|
||||
gorm.io/gorm v1.9.12
|
||||
)
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/tests"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestFieldValuerAndSetter(t *testing.T) {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/tests"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/tests"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/tests"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestParseSchema(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestBelongsToAssociation(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestHasManyAssociation(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestHasOneAssociation(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestMany2ManyAssociation(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) {
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
|
|
|
@ -3,8 +3,6 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
)
|
||||
|
||||
func TestCustomizeColumn(t *testing.T) {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
)
|
||||
|
||||
func TestEmbeddedStruct(t *testing.T) {
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
module gorm.io/gorm/tests
|
||||
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/jinzhu/now v1.1.1
|
||||
gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0
|
||||
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286
|
||||
gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8
|
||||
gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30
|
||||
gorm.io/gorm v1.9.12
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestGroupBy(t *testing.T) {
|
||||
|
|
|
@ -1,17 +1,13 @@
|
|||
package tests
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
@ -73,101 +69,6 @@ func GetUser(name string, config Config) *User {
|
|||
return &user
|
||||
}
|
||||
|
||||
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||
for _, name := range names {
|
||||
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
|
||||
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
||||
t.Run(name, func(t *testing.T) {
|
||||
AssertEqual(t, got, expect)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func AssertEqual(t *testing.T, got, expect interface{}) {
|
||||
if !reflect.DeepEqual(got, expect) {
|
||||
isEqual := func() {
|
||||
if curTime, ok := got.(time.Time); ok {
|
||||
format := "2006-01-02T15:04:05Z07:00"
|
||||
|
||||
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) {
|
||||
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime)
|
||||
}
|
||||
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
|
||||
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
|
||||
}
|
||||
}
|
||||
|
||||
if fmt.Sprint(got) == fmt.Sprint(expect) {
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() {
|
||||
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
|
||||
return
|
||||
}
|
||||
|
||||
if valuer, ok := got.(driver.Valuer); ok {
|
||||
got, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if valuer, ok := expect.(driver.Valuer); ok {
|
||||
expect, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if got != nil {
|
||||
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
|
||||
}
|
||||
|
||||
if expect != nil {
|
||||
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
|
||||
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Kind() == reflect.Slice {
|
||||
if reflect.ValueOf(expect).Kind() == reflect.Slice {
|
||||
if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
|
||||
for i := 0; i < reflect.ValueOf(got).Len(); i++ {
|
||||
name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
|
||||
})
|
||||
}
|
||||
} else {
|
||||
name := reflect.ValueOf(got).Type().Elem().Name()
|
||||
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Kind() == reflect.Struct {
|
||||
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
|
||||
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
|
||||
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
field := reflect.ValueOf(got).Field(i)
|
||||
t.Run(fieldStruct.Name, func(t *testing.T) {
|
||||
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
||||
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
||||
isEqual()
|
||||
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
|
||||
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
|
||||
isEqual()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CheckPet(t *testing.T, pet Pet, expect Pet) {
|
||||
if pet.ID != 0 {
|
||||
var newPet Pet
|
|
@ -6,7 +6,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
)
|
||||
|
||||
type Product struct {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
)
|
||||
|
||||
type Person struct {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestJoins(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
|
|
|
@ -4,8 +4,6 @@ import (
|
|||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
)
|
||||
|
||||
type Blog struct {
|
||||
|
@ -36,8 +34,8 @@ func compareTags(tags []Tag, contents []string) bool {
|
|||
}
|
||||
|
||||
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
|
||||
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
|
||||
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags")
|
||||
|
@ -125,8 +123,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
|
||||
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
|
||||
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags")
|
||||
|
@ -246,8 +244,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
|
||||
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
|
||||
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags")
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type Hamster struct {
|
||||
|
|
|
@ -3,8 +3,6 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
)
|
||||
|
||||
type Animal struct {
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
)
|
||||
|
||||
func toJSONString(v interface{}) []byte {
|
||||
|
@ -691,8 +690,8 @@ func TestNestedPreload12(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "mssql" {
|
||||
t.Skip("skip sqlite, mssql due to it doesn't support multiple primary keys with auto increment")
|
||||
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
|
||||
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
|
||||
}
|
||||
|
||||
type (
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestNestedPreload(t *testing.T) {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestFind(t *testing.T) {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestScan(t *testing.T) {
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestScannerValuer(t *testing.T) {
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestSoftDelete(t *testing.T) {
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestRow(t *testing.T) {
|
||||
|
|
|
@ -18,8 +18,13 @@ for dialect in "${dialects[@]}" ; do
|
|||
if [ "$GORM_VERBOSE" = "" ]
|
||||
then
|
||||
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./...
|
||||
cd tests
|
||||
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 ./...
|
||||
else
|
||||
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./...
|
||||
cd tests
|
||||
DEBUG=false GORM_DIALECT=${dialect} go test $race -count=1 -v ./...
|
||||
fi
|
||||
cd ..
|
||||
fi
|
||||
done
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package tests
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
@ -7,12 +7,13 @@ import (
|
|||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/dialects/mssql"
|
||||
"gorm.io/gorm/dialects/mysql"
|
||||
"gorm.io/gorm/dialects/postgres"
|
||||
"gorm.io/gorm/dialects/sqlite"
|
||||
"gorm.io/gorm/logger"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
@ -40,17 +41,17 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||
dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
|
||||
}
|
||||
db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{})
|
||||
case "mssql":
|
||||
case "sqlserver":
|
||||
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
|
||||
// CREATE DATABASE gorm;
|
||||
// USE gorm;
|
||||
// CREATE USER gorm FROM LOGIN gorm;
|
||||
// sp_changedbowner 'gorm';
|
||||
log.Println("testing mssql...")
|
||||
log.Println("testing sqlserver...")
|
||||
if dbDSN == "" {
|
||||
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||
}
|
||||
db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{})
|
||||
db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{})
|
||||
default:
|
||||
log.Println("testing sqlite3...")
|
||||
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
|
||||
|
@ -90,8 +91,3 @@ func RunMigrations() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Now() *time.Time {
|
||||
now := time.Now()
|
||||
return &now
|
||||
}
|
|
@ -6,7 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestUpdateBelongsTo(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestUpdateHasManyAssociations(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestUpdateHasOne(t *testing.T) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package tests_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestUpdateMany2ManyAssociations(t *testing.T) {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestUpsert(t *testing.T) {
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||
for _, name := range names {
|
||||
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
|
||||
expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
||||
t.Run(name, func(t *testing.T) {
|
||||
AssertEqual(t, got, expect)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func AssertEqual(t *testing.T, got, expect interface{}) {
|
||||
if !reflect.DeepEqual(got, expect) {
|
||||
isEqual := func() {
|
||||
if curTime, ok := got.(time.Time); ok {
|
||||
format := "2006-01-02T15:04:05Z07:00"
|
||||
|
||||
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) {
|
||||
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime)
|
||||
}
|
||||
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
|
||||
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
|
||||
}
|
||||
}
|
||||
|
||||
if fmt.Sprint(got) == fmt.Sprint(expect) {
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() {
|
||||
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
|
||||
return
|
||||
}
|
||||
|
||||
if valuer, ok := got.(driver.Valuer); ok {
|
||||
got, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if valuer, ok := expect.(driver.Valuer); ok {
|
||||
expect, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if got != nil {
|
||||
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
|
||||
}
|
||||
|
||||
if expect != nil {
|
||||
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
|
||||
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Kind() == reflect.Slice {
|
||||
if reflect.ValueOf(expect).Kind() == reflect.Slice {
|
||||
if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
|
||||
for i := 0; i < reflect.ValueOf(got).Len(); i++ {
|
||||
name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
|
||||
})
|
||||
}
|
||||
} else {
|
||||
name := reflect.ValueOf(got).Type().Elem().Name()
|
||||
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Kind() == reflect.Struct {
|
||||
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
|
||||
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
|
||||
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
field := reflect.ValueOf(got).Field(i)
|
||||
t.Run(fieldStruct.Name, func(t *testing.T) {
|
||||
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
||||
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
||||
isEqual()
|
||||
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
|
||||
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
|
||||
isEqual()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Now() *time.Time {
|
||||
now := time.Now()
|
||||
return &now
|
||||
}
|
Loading…
Reference in New Issue