Fix tests with mysql, postgres

This commit is contained in:
Jinzhu 2020-03-12 13:05:22 +08:00
parent af080e6773
commit f7f633590f
11 changed files with 129 additions and 35 deletions

View File

@ -4,7 +4,12 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
func RegisterDefaultCallbacks(db *gorm.DB) { type Config struct {
LastInsertIDReversed bool
WithReturning bool
}
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
enableTransaction := func(db *gorm.DB) bool { enableTransaction := func(db *gorm.DB) bool {
return !db.SkipDefaultTransaction return !db.SkipDefaultTransaction
} }
@ -13,7 +18,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
createCallback.Register("gorm:create", Create) createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
createCallback.Register("gorm:after_create", AfterCreate) createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)

View File

@ -5,6 +5,7 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
) )
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
@ -43,32 +44,113 @@ func BeforeCreate(db *gorm.DB) {
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
} }
func Create(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning {
return CreateWithReturning
} else {
return func(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table},
})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
if db.Statement.Schema != nil {
if insertID, err := result.LastInsertId(); err == nil {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
insertID--
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
insertID++
}
}
case reflect.Struct:
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
} else {
db.AddError(err)
}
}
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func CreateWithReturning(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{ db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table}, Table: clause.Table{Name: db.Statement.Table},
}) })
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
if db.Statement.Schema != nil { db.Statement.WriteString(" RETURNING ")
if insertID, err := result.LastInsertId(); err == nil {
switch db.Statement.ReflectValue.Kind() { var (
case reflect.Slice, reflect.Array: idx int
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
insertID-- )
for dbName, field := range sch.FieldsWithDefaultDBValue {
if idx != 0 {
db.Statement.WriteByte(',')
}
fields[idx] = field
db.Statement.WriteQuoted(dbName)
idx++
}
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
defer rows.Close()
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for rows.Next() {
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
} }
case reflect.Struct: if err := rows.Scan(values...); err != nil {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) db.AddError(err)
}
db.RowsAffected++
}
case reflect.Struct:
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
if rows.Next() {
err = rows.Scan(values...)
} }
} }
} }
db.RowsAffected, _ = result.RowsAffected()
if err != nil {
db.AddError(err)
}
} else { } else {
db.AddError(err) if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
} }
} }

View File

@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
return return
} }

View File

@ -24,7 +24,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
db.ConnPool, err = sql.Open("mysql", dialector.DSN) db.ConnPool, err = sql.Open("mysql", dialector.DSN)
return return
} }

View File

@ -25,7 +25,9 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
WithReturning: true,
})
db.ConnPool, err = sql.Open("postgres", dialector.DSN) db.ConnPool, err = sql.Open("postgres", dialector.DSN)
return return
} }

View File

@ -16,7 +16,7 @@ var (
) )
func init() { func init() {
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
if os.Getenv("GORM_DSN") != "" { if os.Getenv("GORM_DSN") != "" {
dsn = os.Getenv("GORM_DSN") dsn = os.Getenv("GORM_DSN")
} }

View File

@ -22,7 +22,9 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
LastInsertIDReversed: true,
})
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
return return
} }

View File

@ -32,15 +32,15 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
// check fields // check fields
fields := []schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64},
{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64},
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64},
{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64},
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
} }
@ -83,7 +83,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{
{ {
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
}, },
{ {
Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String,
@ -97,11 +97,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{
{ {
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
}, },
{ {
Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint,
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
}, },
}}, }},
References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}},
@ -124,7 +124,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
// check fields // check fields
fields := []schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true}, {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},

View File

@ -91,6 +91,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteString(" AS ") writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias) stmt.DB.Dialector.QuoteTo(writer, v.Alias)
} }
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
default: default:
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
} }

View File

@ -15,6 +15,7 @@ services:
ports: ports:
- 9920:5432 - 9920:5432
environment: environment:
- TZ=Asia/Shanghai
- POSTGRES_DB=gorm - POSTGRES_DB=gorm
- POSTGRES_USER=gorm - POSTGRES_USER=gorm
- POSTGRES_PASSWORD=gorm - POSTGRES_PASSWORD=gorm

View File

@ -37,7 +37,7 @@ func TestCreate(t *testing.T, db *gorm.DB) {
} }
if err := db.Create(&user).Error; err != nil { if err := db.Create(&user).Error; err != nil {
t.Errorf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
} }
if user.ID == 0 { if user.ID == 0 {
@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) {
}} }}
if err := db.Create(&users).Error; err != nil { if err := db.Create(&users).Error; err != nil {
t.Fatal("errors happened when create users: %v", err) t.Fatalf("errors happened when create users: %v", err)
} }
t.Run("First", func(t *testing.T) { t.Run("First", func(t *testing.T) {
@ -195,11 +195,11 @@ func TestUpdate(t *testing.T, db *gorm.DB) {
} }
if err := db.Create(&users).Error; err != nil { if err := db.Create(&users).Error; err != nil {
t.Errorf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
} else if user.ID == 0 { } else if user.ID == 0 {
t.Errorf("user's primary value should not zero, %v", user.ID) t.Fatalf("user's primary value should not zero, %v", user.ID)
} else if user.UpdatedAt.IsZero() { } else if user.UpdatedAt.IsZero() {
t.Errorf("user's updated at should not zero, %v", user.UpdatedAt) t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt)
} }
lastUpdatedAt = user.UpdatedAt lastUpdatedAt = user.UpdatedAt
@ -297,7 +297,7 @@ func TestDelete(t *testing.T, db *gorm.DB) {
for _, user := range users { for _, user := range users {
if user.ID == 0 { if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID) t.Fatalf("user's primary key should has value after create, got : %v", user.ID)
} }
} }