From f7f633590fefb3a503a4cbda894787d8a11b2540 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Mar 2020 13:05:22 +0800 Subject: [PATCH] Fix tests with mysql, postgres --- callbacks/callbacks.go | 9 ++- callbacks/create.go | 110 +++++++++++++++++++++++++---- dialects/mssql/mssql.go | 2 +- dialects/mysql/mysql.go | 2 +- dialects/postgres/postgres.go | 4 +- dialects/postgres/postgres_test.go | 2 +- dialects/sqlite/sqlite.go | 4 +- schema/schema_test.go | 16 ++--- statement.go | 2 + tests/docker-compose.yml | 1 + tests/tests.go | 12 ++-- 11 files changed, 129 insertions(+), 35 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 0a48ada6..1985aec2 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -4,7 +4,12 @@ import ( "github.com/jinzhu/gorm" ) -func RegisterDefaultCallbacks(db *gorm.DB) { +type Config struct { + LastInsertIDReversed bool + WithReturning bool +} + +func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { enableTransaction := func(db *gorm.DB) bool { return !db.SkipDefaultTransaction } @@ -13,7 +18,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) - createCallback.Register("gorm:create", Create) + createCallback.Register("gorm:create", Create(config)) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callbacks/create.go b/callbacks/create.go index 42dcda27..3f6a81e4 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -5,6 +5,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func BeforeCreate(db *gorm.DB) { @@ -43,32 +44,113 @@ func BeforeCreate(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) { } -func Create(db *gorm.DB) { +func Create(config *Config) func(db *gorm.DB) { + if config.WithReturning { + return CreateWithReturning + } else { + return func(db *gorm.DB) { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + if db.Statement.Schema != nil { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } + } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } + } else { + db.AddError(err) + } + } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func CreateWithReturning(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{ Table: clause.Table{Name: db.Statement.Table}, }) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - if db.Statement.Schema != nil { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- + if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { + db.Statement.WriteString(" RETURNING ") + + var ( + idx int + fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) + values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) + ) + + for dbName, field := range sch.FieldsWithDefaultDBValue { + if idx != 0 { + db.Statement.WriteByte(',') + } + + fields[idx] = field + db.Statement.WriteQuoted(dbName) + idx++ + } + + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + defer rows.Close() + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + db.RowsAffected++ + } + case reflect.Struct: + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + + if rows.Next() { + err = rows.Scan(values...) } } } - db.RowsAffected, _ = result.RowsAffected() + + if err != nil { + db.AddError(err) + } } else { - db.AddError(err) + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index e5bc7dd2..ad6782c7 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) return } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index af796847..7b8f0491 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -24,7 +24,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) db.ConnPool, err = sql.Open("mysql", dialector.DSN) return } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 7589025d..73a19e9d 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -25,7 +25,9 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + WithReturning: true, + }) db.ConnPool, err = sql.Open("postgres", dialector.DSN) return } diff --git a/dialects/postgres/postgres_test.go b/dialects/postgres/postgres_test.go index a1252d92..2185c19c 100644 --- a/dialects/postgres/postgres_test.go +++ b/dialects/postgres/postgres_test.go @@ -16,7 +16,7 @@ var ( ) func init() { - dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" if os.Getenv("GORM_DSN") != "" { dsn = os.Getenv("GORM_DSN") } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index c4837463..51829b17 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -22,7 +22,9 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks - callbacks.RegisterDefaultCallbacks(db) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + LastInsertIDReversed: true, + }) db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) return } diff --git a/schema/schema_test.go b/schema/schema_test.go index ce225010..7d13e614 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -32,15 +32,15 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, - {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, + {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, - {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, - {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, + {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64}, {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } @@ -83,7 +83,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ { Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, }, { Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, @@ -97,11 +97,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ { Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, }, { Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, - Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, }, }}, References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, @@ -124,7 +124,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, diff --git a/statement.go b/statement.go index 298a4c56..e45bd8bb 100644 --- a/statement.go +++ b/statement.go @@ -91,6 +91,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteString(" AS ") stmt.DB.Dialector.QuoteTo(writer, v.Alias) } + case string: + stmt.DB.Dialector.QuoteTo(writer, v) default: stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) } diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 6bf3fadf..05e0956e 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -15,6 +15,7 @@ services: ports: - 9920:5432 environment: + - TZ=Asia/Shanghai - POSTGRES_DB=gorm - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm diff --git a/tests/tests.go b/tests/tests.go index c26d743e..aa48f699 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -37,7 +37,7 @@ func TestCreate(t *testing.T, db *gorm.DB) { } if err := db.Create(&user).Error; err != nil { - t.Errorf("errors happened when create: %v", err) + t.Fatalf("errors happened when create: %v", err) } if user.ID == 0 { @@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) { }} if err := db.Create(&users).Error; err != nil { - t.Fatal("errors happened when create users: %v", err) + t.Fatalf("errors happened when create users: %v", err) } t.Run("First", func(t *testing.T) { @@ -195,11 +195,11 @@ func TestUpdate(t *testing.T, db *gorm.DB) { } if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create: %v", err) + t.Fatalf("errors happened when create: %v", err) } else if user.ID == 0 { - t.Errorf("user's primary value should not zero, %v", user.ID) + t.Fatalf("user's primary value should not zero, %v", user.ID) } else if user.UpdatedAt.IsZero() { - t.Errorf("user's updated at should not zero, %v", user.UpdatedAt) + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) } lastUpdatedAt = user.UpdatedAt @@ -297,7 +297,7 @@ func TestDelete(t *testing.T, db *gorm.DB) { for _, user := range users { if user.ID == 0 { - t.Errorf("user's primary key should has value after create, got : %v", user.ID) + t.Fatalf("user's primary key should has value after create, got : %v", user.ID) } }