mirror of https://github.com/go-gorm/gorm.git
Fix tests with mysql, postgres
This commit is contained in:
parent
af080e6773
commit
f7f633590f
|
@ -4,7 +4,12 @@ import (
|
|||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func RegisterDefaultCallbacks(db *gorm.DB) {
|
||||
type Config struct {
|
||||
LastInsertIDReversed bool
|
||||
WithReturning bool
|
||||
}
|
||||
|
||||
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||
enableTransaction := func(db *gorm.DB) bool {
|
||||
return !db.SkipDefaultTransaction
|
||||
}
|
||||
|
@ -13,7 +18,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
|
|||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||
createCallback.Register("gorm:create", Create)
|
||||
createCallback.Register("gorm:create", Create(config))
|
||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||
createCallback.Register("gorm:after_create", AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
|
@ -43,32 +44,113 @@ func BeforeCreate(db *gorm.DB) {
|
|||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
if config.WithReturning {
|
||||
return CreateWithReturning
|
||||
} else {
|
||||
return func(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Name: db.Statement.Table},
|
||||
})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil {
|
||||
if insertID, err := result.LastInsertId(); err == nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID--
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID++
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CreateWithReturning(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Name: db.Statement.Table},
|
||||
})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil {
|
||||
if insertID, err := result.LastInsertId(); err == nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID--
|
||||
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
||||
db.Statement.WriteString(" RETURNING ")
|
||||
|
||||
var (
|
||||
idx int
|
||||
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
||||
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
||||
)
|
||||
|
||||
for dbName, field := range sch.FieldsWithDefaultDBValue {
|
||||
if idx != 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
fields[idx] = field
|
||||
db.Statement.WriteQuoted(dbName)
|
||||
idx++
|
||||
}
|
||||
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for rows.Next() {
|
||||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
}
|
||||
case reflect.Struct:
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
db.RowsAffected++
|
||||
}
|
||||
case reflect.Struct:
|
||||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
err = rows.Scan(values...)
|
||||
}
|
||||
}
|
||||
}
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector {
|
|||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
|
||||
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func Open(dsn string) gorm.Dialector {
|
|||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
|
||||
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -25,7 +25,9 @@ func Open(dsn string) gorm.Dialector {
|
|||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
WithReturning: true,
|
||||
})
|
||||
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
|
||||
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
|
||||
if os.Getenv("GORM_DSN") != "" {
|
||||
dsn = os.Getenv("GORM_DSN")
|
||||
}
|
||||
|
|
|
@ -22,7 +22,9 @@ func Open(dsn string) gorm.Dialector {
|
|||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -32,15 +32,15 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
|||
|
||||
// check fields
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64},
|
||||
{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
|
||||
{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
|
||||
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
|
||||
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
||||
{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint},
|
||||
{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64},
|
||||
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
||||
{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int},
|
||||
{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint},
|
||||
{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64},
|
||||
{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64},
|
||||
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
|
||||
}
|
||||
|
||||
|
@ -83,7 +83,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
|||
JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{
|
||||
{
|
||||
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
|
||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true,
|
||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
|
||||
},
|
||||
{
|
||||
Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String,
|
||||
|
@ -97,11 +97,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
|||
JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{
|
||||
{
|
||||
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
|
||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true,
|
||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
|
||||
},
|
||||
{
|
||||
Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint,
|
||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true,
|
||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
|
||||
},
|
||||
}},
|
||||
References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}},
|
||||
|
@ -124,7 +124,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
|||
|
||||
// check fields
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true},
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64},
|
||||
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
||||
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
||||
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
|
||||
|
|
|
@ -91,6 +91,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||
writer.WriteString(" AS ")
|
||||
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
|
||||
}
|
||||
case string:
|
||||
stmt.DB.Dialector.QuoteTo(writer, v)
|
||||
default:
|
||||
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ services:
|
|||
ports:
|
||||
- 9920:5432
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- POSTGRES_DB=gorm
|
||||
- POSTGRES_USER=gorm
|
||||
- POSTGRES_PASSWORD=gorm
|
||||
|
|
|
@ -37,7 +37,7 @@ func TestCreate(t *testing.T, db *gorm.DB) {
|
|||
}
|
||||
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
t.Errorf("errors happened when create: %v", err)
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
if user.ID == 0 {
|
||||
|
@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) {
|
|||
}}
|
||||
|
||||
if err := db.Create(&users).Error; err != nil {
|
||||
t.Fatal("errors happened when create users: %v", err)
|
||||
t.Fatalf("errors happened when create users: %v", err)
|
||||
}
|
||||
|
||||
t.Run("First", func(t *testing.T) {
|
||||
|
@ -195,11 +195,11 @@ func TestUpdate(t *testing.T, db *gorm.DB) {
|
|||
}
|
||||
|
||||
if err := db.Create(&users).Error; err != nil {
|
||||
t.Errorf("errors happened when create: %v", err)
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
} else if user.ID == 0 {
|
||||
t.Errorf("user's primary value should not zero, %v", user.ID)
|
||||
t.Fatalf("user's primary value should not zero, %v", user.ID)
|
||||
} else if user.UpdatedAt.IsZero() {
|
||||
t.Errorf("user's updated at should not zero, %v", user.UpdatedAt)
|
||||
t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt)
|
||||
}
|
||||
lastUpdatedAt = user.UpdatedAt
|
||||
|
||||
|
@ -297,7 +297,7 @@ func TestDelete(t *testing.T, db *gorm.DB) {
|
|||
|
||||
for _, user := range users {
|
||||
if user.ID == 0 {
|
||||
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
|
||||
t.Fatalf("user's primary key should has value after create, got : %v", user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue