From f0a442adff91e70a5f85cb50b4dc27bd3c189714 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 23 May 2020 23:50:48 +0800 Subject: [PATCH] Refactor tests --- callbacks/helper.go | 4 +- finisher_api.go | 3 + logger/sql.go | 2 +- tests/associations.go | 73 ----- tests/associations_test.go | 24 ++ tests/create.go | 188 ------------- tests/delete.go | 64 ----- tests/delete_test.go | 48 ++++ tests/group_by.go | 62 ----- tests/group_by_test.go | 57 ++++ tests/joins.go | 81 ------ tests/joins_test.go | 55 ++++ tests/{migrate.go => migrate_test.go} | 12 +- tests/query.go | 95 ------- tests/query_test.go | 82 ++++++ tests/update.go | 382 -------------------------- tests/update_test.go | 226 +++++++++++++++ tests/utils.go | 232 +++++++++++++++- 18 files changed, 734 insertions(+), 956 deletions(-) delete mode 100644 tests/associations.go create mode 100644 tests/associations_test.go delete mode 100644 tests/create.go delete mode 100644 tests/delete.go create mode 100644 tests/delete_test.go delete mode 100644 tests/group_by.go create mode 100644 tests/group_by_test.go delete mode 100644 tests/joins.go create mode 100644 tests/joins_test.go rename tests/{migrate.go => migrate_test.go} (67%) delete mode 100644 tests/query.go create mode 100644 tests/query_test.go delete mode 100644 tests/update.go create mode 100644 tests/update_test.go diff --git a/callbacks/helper.go b/callbacks/helper.go index 092c9c37..43e90b8a 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -10,10 +10,12 @@ import ( // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} + notRestricted := false // select columns for _, column := range stmt.Selects { if column == "*" { + notRestricted = true for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } @@ -51,7 +53,7 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo } } - return results, len(stmt.Selects) > 0 + return results, !notRestricted && len(stmt.Selects) > 0 } // ConvertMapToValuesForCreate convert map to values diff --git a/finisher_api.go b/finisher_api.go index 9e29e327..1b2a7e29 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -35,6 +35,9 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.AddClause(where) } + if len(tx.Statement.Selects) == 0 { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } tx.callbacks.Update().Execute(tx) return } diff --git a/logger/sql.go b/logger/sql.go index 9c0f54d7..219ae301 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -53,7 +53,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } else { rv := reflect.ValueOf(v) - if !rv.IsValid() { + if !rv.IsValid() || rv.IsNil() { vars[idx] = "NULL" } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) diff --git a/tests/associations.go b/tests/associations.go deleted file mode 100644 index 7e93e81e..00000000 --- a/tests/associations.go +++ /dev/null @@ -1,73 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestAssociations(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - - TestBelongsToAssociations(t, db) -} - -func TestBelongsToAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User, old User) { - if old.Company.Name != "" { - if user.CompanyID == nil { - t.Errorf("Company's foreign key should be saved") - } else { - var company Company - db.First(&company, "id = ?", *user.CompanyID) - if company.Name != old.Company.Name { - t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) - } else if user.Company.Name != old.Company.Name { - t.Errorf("Company's name should be same, expects: %v, got %v", old.Company.Name, user.Company.Name) - } - } - } else if user.CompanyID != nil { - t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) - } - - if old.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - db.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != user.Manager.Name { - t.Errorf("Manager's name should be same") - } else if user.Manager.Name != old.Manager.Name { - t.Errorf("Manager's name should be same") - } - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } - } - - t.Run("BelongsTo", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - Company: Company{Name: "company-belongs-to-association"}, - Manager: &User{Name: "manager-belongs-to-association"}, - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user, user) - - var user2 User - db.Find(&user2, "id = ?", user.ID) - db.Model(&user2).Association("Company").Find(&user2.Company) - user2.Manager = &User{} - db.Model(&user2).Association("Manager").Find(user2.Manager) - check(t, user2, user) - }) -} diff --git a/tests/associations_test.go b/tests/associations_test.go new file mode 100644 index 00000000..dc88ee03 --- /dev/null +++ b/tests/associations_test.go @@ -0,0 +1,24 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestAssociationForBelongsTo(t *testing.T) { + var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Company").Find(&user2.Company) + user2.Manager = &User{} + DB.Model(&user2).Association("Manager").Find(user2.Manager) + CheckUser(t, user2, user) +} diff --git a/tests/create.go b/tests/create.go deleted file mode 100644 index 6e5dd2c5..00000000 --- a/tests/create.go +++ /dev/null @@ -1,188 +0,0 @@ -package tests - -import ( - "strconv" - "testing" - "time" -) - -type Config struct { - Account bool - Pets int - Toys int - Company bool - Manager bool - Team int - Languages int - Friends int -} - -func GetUser(name string, config Config) *User { - var ( - birthday = time.Now() - user = User{ - Name: name, - Age: 18, - Birthday: &birthday, - } - ) - - if config.Account { - user.Account = Account{Number: name + "_account"} - } - - for i := 0; i < config.Pets; i++ { - user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) - } - - for i := 0; i < config.Toys; i++ { - user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) - } - - if config.Company { - user.Company = Company{Name: "company-" + name} - } - - if config.Manager { - user.Manager = GetUser(name+"_manager", Config{}) - } - - for i := 0; i < config.Team; i++ { - user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) - } - - for i := 0; i < config.Languages; i++ { - name := name + "_locale_" + strconv.Itoa(i+1) - language := Language{Code: name, Name: name} - DB.Create(&language) - user.Languages = append(user.Languages, language) - } - - for i := 0; i < config.Friends; i++ { - user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) - } - - return &user -} - -func CheckPet(t *testing.T, pet Pet, expect Pet) { - if pet.ID != 0 { - var newPet Pet - if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { - t.Fatalf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") - } - } - - AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") - - AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") - - if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { - t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) - } -} - -func CheckUser(t *testing.T, user User, expect User) { - if user.ID != 0 { - var newUser User - if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { - t.Fatalf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - } - - AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - - t.Run("Account", func(t *testing.T) { - AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") - - if user.Account.Number != "" { - if !user.Account.UserID.Valid { - t.Errorf("Account's foreign key should be saved") - } else { - var account Account - DB.First(&account, "user_id = ?", user.ID) - AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") - } - } - }) - - t.Run("Pets", func(t *testing.T) { - if len(user.Pets) != len(expect.Pets) { - t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) - } - - for idx, pet := range user.Pets { - if pet == nil || expect.Pets[idx] == nil { - t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) - } else { - CheckPet(t, *pet, *expect.Pets[idx]) - } - } - }) - - t.Run("Toys", func(t *testing.T) { - if len(user.Toys) != len(expect.Toys) { - t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) - } - - for idx, toy := range user.Toys { - if toy.OwnerType != "users" { - t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) - } - - AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") - } - }) - - t.Run("Company", func(t *testing.T) { - AssertObjEqual(t, user.Company, expect.Company, "ID", "Name") - }) - - t.Run("Manager", func(t *testing.T) { - if user.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - DB.First(&manager, "id = ?", *user.ManagerID) - AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } - }) - - t.Run("Team", func(t *testing.T) { - if len(user.Team) != len(expect.Team) { - t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) - } - - for idx, team := range user.Team { - AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - }) - - t.Run("Languages", func(t *testing.T) { - if len(user.Languages) != len(expect.Languages) { - t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) - } - - for idx, language := range user.Languages { - AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") - } - }) - - t.Run("Friends", func(t *testing.T) { - if len(user.Friends) != len(expect.Friends) { - t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) - } - - for idx, friend := range user.Friends { - AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - } - }) -} diff --git a/tests/delete.go b/tests/delete.go deleted file mode 100644 index 45701ff0..00000000 --- a/tests/delete.go +++ /dev/null @@ -1,64 +0,0 @@ -package tests - -import ( - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestDelete(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Delete", func(t *testing.T) { - var users = []User{{ - Name: "find", - Age: 1, - Birthday: Now(), - }, { - Name: "find", - Age: 2, - Birthday: Now(), - }, { - Name: "find", - Age: 3, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create: %v", err) - } - - for _, user := range users { - if user.ID == 0 { - t.Fatalf("user's primary key should has value after create, got : %v", user.ID) - } - } - - if err := db.Delete(&users[1]).Error; err != nil { - t.Errorf("errors happened when delete: %v", err) - } - - var result User - if err := db.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { - t.Errorf("should returns record not found error, but got %v", err) - } - - for _, user := range []User{users[0], users[2]} { - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("no error should returns when query %v, but got %v", user.ID, err) - } - } - - if err := db.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { - t.Errorf("should returns missing WHERE clause while deleting error") - } - - for _, user := range []User{users[0], users[2]} { - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("no error should returns when query %v, but got %v", user.ID, err) - } - } - }) -} diff --git a/tests/delete_test.go b/tests/delete_test.go new file mode 100644 index 00000000..8be072d3 --- /dev/null +++ b/tests/delete_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "errors" + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestDelete(t *testing.T) { + var users = []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("user's primary key should has value after create, got : %v", user.ID) + } + } + + if err := DB.Delete(&users[1]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + var result User + if err := DB.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } + + for _, user := range []User{users[0], users[2]} { + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } + + for _, user := range []User{users[0], users[2]} { + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } +} diff --git a/tests/group_by.go b/tests/group_by.go deleted file mode 100644 index b0bb4155..00000000 --- a/tests/group_by.go +++ /dev/null @@ -1,62 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestGroupBy(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("GroupBy", func(t *testing.T) { - var users = []User{{ - Name: "groupby", - Age: 10, - Birthday: Now(), - }, { - Name: "groupby", - Age: 20, - Birthday: Now(), - }, { - Name: "groupby", - Age: 30, - Birthday: Now(), - }, { - Name: "groupby1", - Age: 110, - Birthday: Now(), - }, { - Name: "groupby1", - Age: 220, - Birthday: Now(), - }, { - Name: "groupby1", - Age: 330, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create: %v", err) - } - - var name string - var total int - if err := db.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if name != "groupby" || total != 60 { - t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) - } - - if err := db.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if name != "groupby1" || total != 660 { - t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) - } - }) -} diff --git a/tests/group_by_test.go b/tests/group_by_test.go new file mode 100644 index 00000000..66a733aa --- /dev/null +++ b/tests/group_by_test.go @@ -0,0 +1,57 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestGroupBy(t *testing.T) { + var users = []User{{ + Name: "groupby", + Age: 10, + Birthday: Now(), + }, { + Name: "groupby", + Age: 20, + Birthday: Now(), + }, { + Name: "groupby", + Age: 30, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 110, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 220, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 330, + Birthday: Now(), + }} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + var name string + var total int + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby1" || total != 660 { + t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) + } +} diff --git a/tests/joins.go b/tests/joins.go deleted file mode 100644 index 86f9f104..00000000 --- a/tests/joins.go +++ /dev/null @@ -1,81 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestJoins(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}, &Account{}, &Company{}) - db.AutoMigrate(&User{}, &Account{}, &Company{}) - - check := func(t *testing.T, oldUser, newUser User) { - if newUser.Company.ID != oldUser.Company.ID { - t.Errorf("Company is not equal when load with joins, loaded company id: %v", newUser.Company.ID) - } - - if newUser.Manager == nil || newUser.Manager.ID != oldUser.Manager.ID { - t.Errorf("Manager is not equal when load with joins: loaded manager: %+v", newUser.Manager) - } - - if newUser.Account.ID != oldUser.Account.ID { - t.Errorf("Account is not equal when load with joins, loaded account id: %v, expect: %v", newUser.Account.ID, oldUser.Account.ID) - } - } - - t.Run("Joins", func(t *testing.T) { - user := User{ - Name: "joins-1", - Company: Company{Name: "company"}, - Manager: &User{Name: "manager"}, - Account: Account{Number: "account-has-one-association"}, - } - - db.Create(&user) - - var user2 User - if err := db.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { - t.Fatalf("Failed to load with joins, got error: %v", err) - } - - check(t, user, user2) - }) - - t.Run("JoinsForSlice", func(t *testing.T) { - users := []User{{ - Name: "slice-joins-1", - Company: Company{Name: "company"}, - Manager: &User{Name: "manager"}, - Account: Account{Number: "account-has-one-association"}, - }, { - Name: "slice-joins-2", - Company: Company{Name: "company2"}, - Manager: &User{Name: "manager2"}, - Account: Account{Number: "account-has-one-association2"}, - }, { - Name: "slice-joins-3", - Company: Company{Name: "company3"}, - Manager: &User{Name: "manager3"}, - Account: Account{Number: "account-has-one-association3"}, - }} - - db.Create(&users) - - var users2 []User - if err := db.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.name LIKE ?", "slice-joins%").Error; err != nil { - t.Fatalf("Failed to load with joins, got error: %v", err) - } else if len(users2) != len(users) { - t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) - } - - for _, u2 := range users2 { - for _, u := range users { - if u.Name == u2.Name { - check(t, u, u2) - continue - } - } - } - }) -} diff --git a/tests/joins_test.go b/tests/joins_test.go new file mode 100644 index 00000000..556130ee --- /dev/null +++ b/tests/joins_test.go @@ -0,0 +1,55 @@ +package tests_test + +import ( + "sort" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestJoins(t *testing.T) { + user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true}) + + DB.Create(&user) + + var user2 User + if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } + + CheckUser(t, user2, user) +} + +func TestJoinsForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-joins-1", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-2", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-3", Config{Company: true, Manager: true, Account: true}), + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + CheckUser(t, user, users2[idx]) + } +} diff --git a/tests/migrate.go b/tests/migrate_test.go similarity index 67% rename from tests/migrate.go rename to tests/migrate_test.go index fa8a89e8..917fba75 100644 --- a/tests/migrate.go +++ b/tests/migrate_test.go @@ -1,28 +1,28 @@ -package tests +package tests_test import ( "math/rand" "testing" "time" - "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" ) -func TestMigrate(t *testing.T, db *gorm.DB) { +func TestMigrate(t *testing.T) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - if err := db.Migrator().DropTable(allModels...); err != nil { + if err := DB.Migrator().DropTable(allModels...); err != nil { t.Errorf("Failed to drop table, got error %v", err) } - if err := db.AutoMigrate(allModels...); err != nil { + if err := DB.AutoMigrate(allModels...); err != nil { t.Errorf("Failed to auto migrate, but got error %v", err) } for _, m := range allModels { - if !db.Migrator().HasTable(m) { + if !DB.Migrator().HasTable(m) { t.Errorf("Failed to create table for %#v", m) } } diff --git a/tests/query.go b/tests/query.go deleted file mode 100644 index 5eabfb48..00000000 --- a/tests/query.go +++ /dev/null @@ -1,95 +0,0 @@ -package tests - -import ( - "reflect" - "strconv" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestFind(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Find", func(t *testing.T) { - var users = []User{{ - Name: "find", - Age: 1, - Birthday: Now(), - }, { - Name: "find", - Age: 2, - Birthday: Now(), - }, { - Name: "find", - Age: 3, - Birthday: Now(), - }} - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create users: %v", err) - } - - t.Run("First", func(t *testing.T) { - var first User - if err := db.Where("name = ?", "find").First(&first).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - AssertObjEqual(t, first, users[0], "Name", "Age", "Birthday") - } - }) - - t.Run("Last", func(t *testing.T) { - var last User - if err := db.Where("name = ?", "find").Last(&last).Error; err != nil { - t.Errorf("errors happened when query last: %v", err) - } else { - AssertObjEqual(t, last, users[2], "Name", "Age", "Birthday") - } - }) - - var all []User - if err := db.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { - t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) - } else { - for idx, user := range users { - t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { - AssertObjEqual(t, all[idx], user, "Name", "Age", "Birthday") - }) - } - } - - t.Run("FirstMap", func(t *testing.T) { - var first = map[string]interface{}{} - if err := db.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := db.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) - AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) - }) - } - } - }) - - var allMap = []map[string]interface{}{} - if err := db.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for idx, user := range users { - t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := db.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(user)) - AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) - }) - } - }) - } - } - }) -} diff --git a/tests/query_test.go b/tests/query_test.go new file mode 100644 index 00000000..4388066f --- /dev/null +++ b/tests/query_test.go @@ -0,0 +1,82 @@ +package tests_test + +import ( + "reflect" + "strconv" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestFind(t *testing.T) { + var users = []User{ + *GetUser("find", Config{}), + *GetUser("find", Config{}), + *GetUser("find", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + t.Run("First", func(t *testing.T) { + var first User + if err := DB.Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + CheckUser(t, first, users[0]) + } + }) + + t.Run("Last", func(t *testing.T) { + var last User + if err := DB.Where("name = ?", "find").Last(&last).Error; err != nil { + t.Errorf("errors happened when query last: %v", err) + } else { + CheckUser(t, last, users[2]) + } + }) + + var all []User + if err := DB.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { + t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) + } else { + for idx, user := range users { + t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, all[idx], user) + }) + } + } + + t.Run("FirstMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + var allMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } +} diff --git a/tests/update.go b/tests/update.go deleted file mode 100644 index 82a2dc8b..00000000 --- a/tests/update.go +++ /dev/null @@ -1,382 +0,0 @@ -package tests - -import ( - "fmt" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -func TestUpdate(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&User{}) - db.AutoMigrate(&User{}) - - t.Run("Update", func(t *testing.T) { - var ( - users = []*User{{ - Name: "update-before", - Age: 1, - Birthday: Now(), - }, { - Name: "update", - Age: 18, - Birthday: Now(), - }, { - Name: "update-after", - Age: 1, - Birthday: Now(), - }} - user = users[1] - lastUpdatedAt time.Time - ) - - checkUpdatedTime := func(name string, n time.Time) { - if n.UnixNano() == lastUpdatedAt.UnixNano() { - t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) - } - lastUpdatedAt = n - } - - checkOtherData := func(name string) { - var beforeUser, afterUser User - if err := db.Where("id = ?", users[0].ID).First(&beforeUser).Error; err != nil { - t.Errorf("errors happened when query before user: %v", err) - } - t.Run(name, func(t *testing.T) { - AssertObjEqual(t, beforeUser, users[0], "Name", "Age", "Birthday") - }) - - if err := db.Where("id = ?", users[2].ID).First(&afterUser).Error; err != nil { - t.Errorf("errors happened when query after user: %v", err) - } - t.Run(name, func(t *testing.T) { - AssertObjEqual(t, afterUser, users[2], "Name", "Age", "Birthday") - }) - } - - if err := db.Create(&users).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } else if user.ID == 0 { - t.Fatalf("user's primary value should not zero, %v", user.ID) - } else if user.UpdatedAt.IsZero() { - t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) - } - lastUpdatedAt = user.UpdatedAt - - if err := db.Model(user).Update("Age", 10).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 10 { - t.Errorf("Age should equals to 10, but got %v", user.Age) - } - checkUpdatedTime("Update", user.UpdatedAt) - checkOtherData("Update") - - var result User - if err := db.Where("id = ?", user.ID).First(&result).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result, user, "Name", "Age", "Birthday") - } - - values := map[string]interface{}{"Active": true, "age": 5} - if err := db.Model(user).Updates(values).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 5 { - t.Errorf("Age should equals to 5, but got %v", user.Age) - } else if user.Active != true { - t.Errorf("Active should be true, but got %v", user.Active) - } - checkUpdatedTime("Updates with map", user.UpdatedAt) - checkOtherData("Updates with map") - - var result2 User - if err := db.Where("id = ?", user.ID).First(&result2).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result2, user, "Name", "Age", "Birthday") - } - - if err := db.Model(user).Updates(User{Age: 2}).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 2 { - t.Errorf("Age should equals to 2, but got %v", user.Age) - } - checkUpdatedTime("Updates with struct", user.UpdatedAt) - checkOtherData("Updates with struct") - - var result3 User - if err := db.Where("id = ?", user.ID).First(&result3).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result3, user, "Name", "Age", "Birthday") - } - - user.Active = false - user.Age = 1 - if err := db.Save(user).Error; err != nil { - t.Errorf("errors happened when update: %v", err) - } else if user.Age != 1 { - t.Errorf("Age should equals to 1, but got %v", user.Age) - } else if user.Active != false { - t.Errorf("Active should equals to false, but got %v", user.Active) - } - checkUpdatedTime("Save", user.UpdatedAt) - checkOtherData("Save") - - var result4 User - if err := db.Where("id = ?", user.ID).First(&result4).Error; err != nil { - t.Errorf("errors happened when query: %v", err) - } else { - AssertObjEqual(t, result4, user, "Name", "Age", "Birthday") - } - - TestUpdateAssociations(t, db) - }) -} - -func TestUpdateAssociations(t *testing.T, db *gorm.DB) { - db.Migrator().DropTable(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - db.Migrator().AutoMigrate(&Account{}, &Company{}, &Pet{}, &Toy{}, &Language{}) - - TestUpdateBelongsToAssociations(t, db) - TestUpdateHasOneAssociations(t, db) - TestUpdateHasManyAssociations(t, db) - TestUpdateMany2ManyAssociations(t, db) -} - -func TestUpdateBelongsToAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - if user.Company.Name != "" { - if user.CompanyID == nil { - t.Errorf("Company's foreign key should be saved") - } else { - var company Company - db.First(&company, "id = ?", *user.CompanyID) - if company.Name != user.Company.Name { - t.Errorf("Company's name should be same") - } - } - } else if user.CompanyID != nil { - t.Errorf("Company should not be created for zero value, got: %+v", user.CompanyID) - } - - if user.Manager != nil { - if user.ManagerID == nil { - t.Errorf("Manager's foreign key should be saved") - } else { - var manager User - db.First(&manager, "id = ?", *user.ManagerID) - if manager.Name != user.Manager.Name { - t.Errorf("Manager's name should be same") - } - } - } else if user.ManagerID != nil { - t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) - } - } - - t.Run("BelongsTo", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Company = Company{Name: "company-belongs-to-association"} - user.Manager = &User{Name: "manager-belongs-to-association"} - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - check(t, user) - }) -} - -func TestUpdateHasOneAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - if user.Account.ID == 0 { - t.Errorf("Account should be saved") - } else if user.Account.UserID.Int64 != int64(user.ID) { - t.Errorf("Account's foreign key should be saved") - } else { - var account Account - db.First(&account, "id = ?", user.Account.ID) - if account.Number != user.Account.Number { - t.Errorf("Account's number should be sme") - } - } - } - - t.Run("HasOne", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Account = Account{Number: "account-has-one-association"} - - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - check(t, user) - }) - - checkPet := func(t *testing.T, pet Pet) { - if pet.Toy.OwnerID != fmt.Sprint(pet.ID) || pet.Toy.OwnerType != "pets" { - t.Errorf("Failed to create polymorphic has one association - toy owner id %v, owner type %v", pet.Toy.OwnerID, pet.Toy.OwnerType) - } else { - var toy Toy - db.First(&toy, "owner_id = ? and owner_type = ?", pet.Toy.OwnerID, pet.Toy.OwnerType) - if toy.Name != pet.Toy.Name { - t.Errorf("Failed to query saved polymorphic has one association") - } - } - } - - t.Run("PolymorphicHasOne", func(t *testing.T) { - var pet = Pet{ - Name: "create", - } - - if err := db.Create(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} - - if err := db.Save(&pet).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - checkPet(t, pet) - }) -} - -func TestUpdateHasManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - for _, pet := range user.Pets { - if pet.ID == 0 { - t.Errorf("Pet's foreign key should be saved") - } - - var result Pet - db.First(&result, "id = ?", pet.ID) - if result.Name != pet.Name { - t.Errorf("Pet's name should be same") - } else if result.UserID != user.ID { - t.Errorf("Pet's foreign key should be saved") - } - } - } - - t.Run("HasMany", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - check(t, user) - }) - - checkToy := func(t *testing.T, user User) { - for idx, toy := range user.Toys { - if toy.ID == 0 { - t.Fatalf("Failed to create toy #%v", idx) - } - - var result Toy - db.First(&result, "id = ?", toy.ID) - if result.Name != toy.Name { - t.Errorf("Failed to query saved toy") - } else if result.OwnerID != fmt.Sprint(user.ID) || result.OwnerType != "users" { - t.Errorf("Failed to save relation") - } - } - } - - t.Run("PolymorphicHasMany", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when update: %v", err) - } - - checkToy(t, user) - }) -} - -func TestUpdateMany2ManyAssociations(t *testing.T, db *gorm.DB) { - check := func(t *testing.T, user User) { - for _, language := range user.Languages { - var result Language - db.First(&result, "code = ?", language.Code) - // TODO - // if result.Name != language.Name { - // t.Errorf("Language's name should be same") - // } - } - - for _, f := range user.Friends { - if f.ID == 0 { - t.Errorf("Friend's foreign key should be saved") - } - - var result User - db.First(&result, "id = ?", f.ID) - if result.Name != f.Name { - t.Errorf("Friend's name should be same") - } - } - } - - t.Run("Many2Many", func(t *testing.T) { - var user = User{ - Name: "create", - Age: 18, - Birthday: Now(), - } - - if err := db.Create(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} - user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} - - if err := db.Save(&user).Error; err != nil { - t.Fatalf("errors happened when create: %v", err) - } - - check(t, user) - }) -} diff --git a/tests/update_test.go b/tests/update_test.go new file mode 100644 index 00000000..10835f97 --- /dev/null +++ b/tests/update_test.go @@ -0,0 +1,226 @@ +package tests_test + +import ( + "testing" + "time" + + . "github.com/jinzhu/gorm/tests" +) + +func TestUpdate(t *testing.T) { + var ( + users = []*User{ + GetUser("update-1", Config{}), + GetUser("update-2", Config{}), + GetUser("update-3", Config{}), + } + user = users[1] + lastUpdatedAt time.Time + ) + + checkUpdatedTime := func(name string, n time.Time) { + if n.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) + } + lastUpdatedAt = n + } + + checkOtherData := func(name string) { + var first, last User + if err := DB.Where("id = ?", users[0].ID).First(&first).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } + CheckUser(t, first, *users[0]) + + if err := DB.Where("id = ?", users[2].ID).First(&last).Error; err != nil { + t.Errorf("errors happened when query after user: %v", err) + } + CheckUser(t, last, *users[2]) + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if user.ID == 0 { + t.Fatalf("user's primary value should not zero, %v", user.ID) + } else if user.UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) + } + lastUpdatedAt = user.UpdatedAt + + if err := DB.Model(user).Update("Age", 10).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 10 { + t.Errorf("Age should equals to 10, but got %v", user.Age) + } + checkUpdatedTime("Update", user.UpdatedAt) + checkOtherData("Update") + + var result User + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result, *user) + } + + values := map[string]interface{}{"Active": true, "age": 5} + if err := DB.Model(user).Updates(values).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 5 { + t.Errorf("Age should equals to 5, but got %v", user.Age) + } else if user.Active != true { + t.Errorf("Active should be true, but got %v", user.Active) + } + checkUpdatedTime("Updates with map", user.UpdatedAt) + checkOtherData("Updates with map") + + var result2 User + if err := DB.Where("id = ?", user.ID).First(&result2).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result2, *user) + } + + if err := DB.Model(user).Updates(User{Age: 2}).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 2 { + t.Errorf("Age should equals to 2, but got %v", user.Age) + } + checkUpdatedTime("Updates with struct", user.UpdatedAt) + checkOtherData("Updates with struct") + + var result3 User + if err := DB.Where("id = ?", user.ID).First(&result3).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result3, *user) + } + + user.Active = false + user.Age = 1 + if err := DB.Save(user).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 1 { + t.Errorf("Age should equals to 1, but got %v", user.Age) + } else if user.Active != false { + t.Errorf("Active should equals to false, but got %v", user.Active) + } + checkUpdatedTime("Save", user.UpdatedAt) + checkOtherData("Save") + + var result4 User + if err := DB.Where("id = ?", user.ID).First(&result4).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result4, *user) + } +} + +func TestUpdateBelongsTo(t *testing.T) { + var user = *GetUser("update-belongs-to", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Company = Company{Name: "company-belongs-to-association"} + user.Manager = &User{Name: "manager-belongs-to-association"} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} + +func TestUpdateHasOne(t *testing.T) { + var user = *GetUser("update-has-one", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Account = Account{Number: "account-has-one-association"} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Account").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + t.Run("Polymorphic", func(t *testing.T) { + var pet = Pet{Name: "create"} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} + + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + }) +} + +func TestUpdateHasManyAssociations(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Pets").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + t.Run("Polymorphic", func(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Toys").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + }) +} + +func TestUpdateMany2ManyAssociations(t *testing.T) { + var user = *GetUser("update-many2many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} + for _, lang := range user.Languages { + DB.Create(&lang) + } + user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} diff --git a/tests/utils.go b/tests/utils.go index cb4e4fcc..001d77e9 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -2,10 +2,74 @@ package tests import ( "reflect" + "sort" + "strconv" + "strings" "testing" "time" + + "github.com/jinzhu/gorm/utils" ) +type Config struct { + Account bool + Pets int + Toys int + Company bool + Manager bool + Team int + Languages int + Friends int +} + +func GetUser(name string, config Config) *User { + var ( + birthday = time.Now() + user = User{ + Name: name, + Age: 18, + Birthday: &birthday, + } + ) + + if config.Account { + user.Account = Account{Number: name + "_account"} + } + + for i := 0; i < config.Pets; i++ { + user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) + } + + for i := 0; i < config.Toys; i++ { + user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) + } + + if config.Company { + user.Company = Company{Name: "company-" + name} + } + + if config.Manager { + user.Manager = GetUser(name+"_manager", Config{}) + } + + for i := 0; i < config.Team; i++ { + user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) + } + + for i := 0; i < config.Languages; i++ { + name := name + "_locale_" + strconv.Itoa(i+1) + language := Language{Code: name, Name: name} + DB.Create(&language) + user.Languages = append(user.Languages, language) + } + + for i := 0; i < config.Friends; i++ { + user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) + } + + return &user +} + func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() @@ -21,11 +85,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) { isEqual := func() { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" + if curTime.Format(format) != expect.(time.Time).Format(format) { - t.Errorf("expect: %v, got %v", expect.(time.Time).Format(format), curTime.Format(format)) + t.Errorf("%v: expect: %v, got %v", utils.FileWithLineNum(), expect.(time.Time).Format(format), curTime.Format(format)) } } else if got != expect { - t.Errorf("expect: %#v, got %#v", expect, got) + t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) } } @@ -34,7 +99,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { - t.Errorf("expect: %+v, got %+v", expect, got) + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) return } @@ -55,3 +120,164 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } } } + +func CheckPet(t *testing.T, pet Pet, expect Pet) { + if pet.ID != 0 { + var newPet Pet + if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + } + } + + AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + + AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") + + if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) + } +} + +func CheckUser(t *testing.T, user User, expect User) { + if user.ID != 0 { + var newUser User + if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } + + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + + t.Run("Account", func(t *testing.T) { + AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + + if user.Account.Number != "" { + if !user.Account.UserID.Valid { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + DB.First(&account, "user_id = ?", user.ID) + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + } + } + }) + + t.Run("Pets", func(t *testing.T) { + if len(user.Pets) != len(expect.Pets) { + t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + } + + sort.Slice(user.Pets, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + + sort.Slice(expect.Pets, func(i, j int) bool { + return expect.Pets[i].ID > expect.Pets[j].ID + }) + + for idx, pet := range user.Pets { + if pet == nil || expect.Pets[idx] == nil { + t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) + } else { + CheckPet(t, *pet, *expect.Pets[idx]) + } + } + }) + + t.Run("Toys", func(t *testing.T) { + if len(user.Toys) != len(expect.Toys) { + t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + } + + sort.Slice(user.Toys, func(i, j int) bool { + return user.Toys[i].ID > user.Toys[j].ID + }) + + sort.Slice(expect.Toys, func(i, j int) bool { + return expect.Toys[i].ID > expect.Toys[j].ID + }) + + for idx, toy := range user.Toys { + if toy.OwnerType != "users" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) + } + + AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") + } + }) + + t.Run("Company", func(t *testing.T) { + AssertObjEqual(t, user.Company, expect.Company, "ID", "Name") + }) + + t.Run("Manager", func(t *testing.T) { + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + DB.First(&manager, "id = ?", *user.ManagerID) + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + }) + + t.Run("Team", func(t *testing.T) { + if len(user.Team) != len(expect.Team) { + t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + } + + sort.Slice(user.Team, func(i, j int) bool { + return user.Team[i].ID > user.Team[j].ID + }) + + sort.Slice(expect.Team, func(i, j int) bool { + return expect.Team[i].ID > expect.Team[j].ID + }) + + for idx, team := range user.Team { + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) + + t.Run("Languages", func(t *testing.T) { + if len(user.Languages) != len(expect.Languages) { + t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + } + + sort.Slice(user.Languages, func(i, j int) bool { + return strings.Compare(user.Languages[i].Code, user.Languages[j].Code) > 0 + }) + + sort.Slice(expect.Languages, func(i, j int) bool { + return strings.Compare(expect.Languages[i].Code, expect.Languages[j].Code) > 0 + }) + for idx, language := range user.Languages { + AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") + } + }) + + t.Run("Friends", func(t *testing.T) { + if len(user.Friends) != len(expect.Friends) { + t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + } + + sort.Slice(user.Friends, func(i, j int) bool { + return user.Friends[i].ID > user.Friends[j].ID + }) + + sort.Slice(expect.Friends, func(i, j int) bool { + return expect.Friends[i].ID > expect.Friends[j].ID + }) + + for idx, friend := range user.Friends { + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) +}