mirror of https://github.com/go-gorm/gorm.git
Fix tests with mysql, postgres
This commit is contained in:
parent
af080e6773
commit
f7f633590f
|
@ -4,7 +4,12 @@ import (
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterDefaultCallbacks(db *gorm.DB) {
|
type Config struct {
|
||||||
|
LastInsertIDReversed bool
|
||||||
|
WithReturning bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||||
enableTransaction := func(db *gorm.DB) bool {
|
enableTransaction := func(db *gorm.DB) bool {
|
||||||
return !db.SkipDefaultTransaction
|
return !db.SkipDefaultTransaction
|
||||||
}
|
}
|
||||||
|
@ -13,7 +18,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
|
||||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
|
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||||
createCallback.Register("gorm:create", Create)
|
createCallback.Register("gorm:create", Create(config))
|
||||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
|
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||||
createCallback.Register("gorm:after_create", AfterCreate)
|
createCallback.Register("gorm:after_create", AfterCreate)
|
||||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BeforeCreate(db *gorm.DB) {
|
func BeforeCreate(db *gorm.DB) {
|
||||||
|
@ -43,7 +44,11 @@ func BeforeCreate(db *gorm.DB) {
|
||||||
func SaveBeforeAssociations(db *gorm.DB) {
|
func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(db *gorm.DB) {
|
func Create(config *Config) func(db *gorm.DB) {
|
||||||
|
if config.WithReturning {
|
||||||
|
return CreateWithReturning
|
||||||
|
} else {
|
||||||
|
return func(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||||
Table: clause.Table{Name: db.Statement.Table},
|
Table: clause.Table{Name: db.Statement.Table},
|
||||||
})
|
})
|
||||||
|
@ -57,13 +62,22 @@ func Create(db *gorm.DB) {
|
||||||
if insertID, err := result.LastInsertId(); err == nil {
|
if insertID, err := result.LastInsertId(); err == nil {
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
if config.LastInsertIDReversed {
|
||||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||||
insertID--
|
insertID--
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||||
|
insertID++
|
||||||
|
}
|
||||||
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
@ -71,6 +85,74 @@ func Create(db *gorm.DB) {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateWithReturning(db *gorm.DB) {
|
||||||
|
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||||
|
Table: clause.Table{Name: db.Statement.Table},
|
||||||
|
})
|
||||||
|
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||||
|
|
||||||
|
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
||||||
|
|
||||||
|
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
||||||
|
db.Statement.WriteString(" RETURNING ")
|
||||||
|
|
||||||
|
var (
|
||||||
|
idx int
|
||||||
|
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
||||||
|
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
||||||
|
)
|
||||||
|
|
||||||
|
for dbName, field := range sch.FieldsWithDefaultDBValue {
|
||||||
|
if idx != 0 {
|
||||||
|
db.Statement.WriteByte(',')
|
||||||
|
}
|
||||||
|
|
||||||
|
fields[idx] = field
|
||||||
|
db.Statement.WriteQuoted(dbName)
|
||||||
|
idx++
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
for rows.Next() {
|
||||||
|
for idx, field := range fields {
|
||||||
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||||
|
}
|
||||||
|
if err := rows.Scan(values...); err != nil {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
|
db.RowsAffected++
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
for idx, field := range fields {
|
||||||
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Next() {
|
||||||
|
err = rows.Scan(values...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func SaveAfterAssociations(db *gorm.DB) {
|
func SaveAfterAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector {
|
||||||
|
|
||||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||||
// register callbacks
|
// register callbacks
|
||||||
callbacks.RegisterDefaultCallbacks(db)
|
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
|
||||||
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
|
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ func Open(dsn string) gorm.Dialector {
|
||||||
|
|
||||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||||
// register callbacks
|
// register callbacks
|
||||||
callbacks.RegisterDefaultCallbacks(db)
|
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
|
||||||
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
|
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,9 @@ func Open(dsn string) gorm.Dialector {
|
||||||
|
|
||||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||||
// register callbacks
|
// register callbacks
|
||||||
callbacks.RegisterDefaultCallbacks(db)
|
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||||
|
WithReturning: true,
|
||||||
|
})
|
||||||
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
|
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
|
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
|
||||||
if os.Getenv("GORM_DSN") != "" {
|
if os.Getenv("GORM_DSN") != "" {
|
||||||
dsn = os.Getenv("GORM_DSN")
|
dsn = os.Getenv("GORM_DSN")
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,9 @@ func Open(dsn string) gorm.Dialector {
|
||||||
|
|
||||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||||
// register callbacks
|
// register callbacks
|
||||||
callbacks.RegisterDefaultCallbacks(db)
|
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||||
|
LastInsertIDReversed: true,
|
||||||
|
})
|
||||||
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
|
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,15 +32,15 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||||
|
|
||||||
// check fields
|
// check fields
|
||||||
fields := []schema.Field{
|
fields := []schema.Field{
|
||||||
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
|
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64},
|
||||||
{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
|
{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
|
||||||
{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
|
{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
|
||||||
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
|
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
|
||||||
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
||||||
{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint},
|
{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64},
|
||||||
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
||||||
{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int},
|
{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64},
|
||||||
{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint},
|
{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64},
|
||||||
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
|
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||||
JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{
|
JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{
|
||||||
{
|
{
|
||||||
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
|
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
|
||||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true,
|
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String,
|
Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String,
|
||||||
|
@ -97,11 +97,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||||
JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{
|
JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{
|
||||||
{
|
{
|
||||||
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
|
Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint,
|
||||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true,
|
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint,
|
Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint,
|
||||||
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true,
|
Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64,
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}},
|
References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}},
|
||||||
|
@ -124,7 +124,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||||
|
|
||||||
// check fields
|
// check fields
|
||||||
fields := []schema.Field{
|
fields := []schema.Field{
|
||||||
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true},
|
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64},
|
||||||
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
||||||
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
||||||
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
|
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
|
||||||
|
|
|
@ -91,6 +91,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||||
writer.WriteString(" AS ")
|
writer.WriteString(" AS ")
|
||||||
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
|
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
|
||||||
}
|
}
|
||||||
|
case string:
|
||||||
|
stmt.DB.Dialector.QuoteTo(writer, v)
|
||||||
default:
|
default:
|
||||||
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ services:
|
||||||
ports:
|
ports:
|
||||||
- 9920:5432
|
- 9920:5432
|
||||||
environment:
|
environment:
|
||||||
|
- TZ=Asia/Shanghai
|
||||||
- POSTGRES_DB=gorm
|
- POSTGRES_DB=gorm
|
||||||
- POSTGRES_USER=gorm
|
- POSTGRES_USER=gorm
|
||||||
- POSTGRES_PASSWORD=gorm
|
- POSTGRES_PASSWORD=gorm
|
||||||
|
|
|
@ -37,7 +37,7 @@ func TestCreate(t *testing.T, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.Create(&user).Error; err != nil {
|
if err := db.Create(&user).Error; err != nil {
|
||||||
t.Errorf("errors happened when create: %v", err)
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.ID == 0 {
|
if user.ID == 0 {
|
||||||
|
@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
if err := db.Create(&users).Error; err != nil {
|
if err := db.Create(&users).Error; err != nil {
|
||||||
t.Fatal("errors happened when create users: %v", err)
|
t.Fatalf("errors happened when create users: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("First", func(t *testing.T) {
|
t.Run("First", func(t *testing.T) {
|
||||||
|
@ -195,11 +195,11 @@ func TestUpdate(t *testing.T, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.Create(&users).Error; err != nil {
|
if err := db.Create(&users).Error; err != nil {
|
||||||
t.Errorf("errors happened when create: %v", err)
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
} else if user.ID == 0 {
|
} else if user.ID == 0 {
|
||||||
t.Errorf("user's primary value should not zero, %v", user.ID)
|
t.Fatalf("user's primary value should not zero, %v", user.ID)
|
||||||
} else if user.UpdatedAt.IsZero() {
|
} else if user.UpdatedAt.IsZero() {
|
||||||
t.Errorf("user's updated at should not zero, %v", user.UpdatedAt)
|
t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt)
|
||||||
}
|
}
|
||||||
lastUpdatedAt = user.UpdatedAt
|
lastUpdatedAt = user.UpdatedAt
|
||||||
|
|
||||||
|
@ -297,7 +297,7 @@ func TestDelete(t *testing.T, db *gorm.DB) {
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
if user.ID == 0 {
|
if user.ID == 0 {
|
||||||
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
|
t.Fatalf("user's primary key should has value after create, got : %v", user.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue