From dffc2713f010c4253b61adae61810e27044ab157 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2020 10:02:20 +0800 Subject: [PATCH] Add mores tests for query --- chainable_api.go | 12 ++- statement.go | 21 ++-- tests/query_test.go | 197 ++++++++++++++++++++++++++++++++++- tests/scanner_valuer_test.go | 41 ++++++++ tests/sql_builder_test.go | 42 ++++++++ 5 files changed, 299 insertions(+), 14 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index afcdccd2..6fa605c6 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -111,21 +111,27 @@ func (db *DB) Omit(columns ...string) (tx *DB) { // Where add conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) + if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: conds}) + } return } // Not add NOT conditions func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) + if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) + } return } // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) + if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}}) + } return } diff --git a/statement.go b/statement.go index 444d5c37..aa7d193c 100644 --- a/statement.go +++ b/statement.go @@ -204,12 +204,15 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { // BuildCondtion build condition func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) { if sql, ok := query.(string); ok { - if i, err := strconv.Atoi(sql); err == nil { - query = i - } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { - return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} - } else if len(args) == 1 { - return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} + // if it is a number, then treats it as primary key + if _, err := strconv.Atoi(sql); err != nil { + if sql == "" && len(args) == 0 { + return + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} + } else if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} + } } } @@ -267,14 +270,12 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con } } } + } else if len(conds) == 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } } } - if len(conds) == 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) - } - return } diff --git a/tests/query_test.go b/tests/query_test.go index a4fe1243..6efadc8e 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1,12 +1,14 @@ package tests_test import ( + "fmt" "reflect" "sort" "strconv" "testing" "time" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" ) @@ -115,8 +117,14 @@ func TestPluck(t *testing.T) { t.Errorf("got error when pluck name: %v", err) } + var names2 []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { + t.Errorf("got error when pluck name: %v", err) + } + AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2))) + var ids []int - if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil { + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { t.Errorf("got error when pluck id: %v", err) } @@ -133,6 +141,21 @@ func TestPluck(t *testing.T) { } } +func TestSelect(t *testing.T) { + user := User{Name: "SelectUser1"} + DB.Save(&user) + + var result User + DB.Where("name = ?", user.Name).Select("name").Find(&result) + if result.ID != 0 { + t.Errorf("Should not have ID because only selected name, %+v", result.ID) + } + + if user.Name != result.Name { + t.Errorf("Should have user Name when selected it") + } +} + func TestPluckWithSelect(t *testing.T) { users := []User{ {Name: "pluck_with_select_1", Age: 25}, @@ -151,3 +174,175 @@ func TestPluckWithSelect(t *testing.T) { AssertEqual(t, userAges, []int{26, 27}) } + +func TestSelectWithVariables(t *testing.T) { + DB.Save(&User{Name: "select_with_variables"}) + + rows, _ := DB.Table("users").Where("name = ?", "select_with_variables").Select("? as fake", gorm.Expr("name")).Rows() + + if !rows.Next() { + t.Errorf("Should have returned at least one row") + } else { + columns, _ := rows.Columns() + AssertEqual(t, columns, []string{"fake"}) + } + + rows.Close() +} + +func TestSelectWithArrayInput(t *testing.T) { + DB.Save(&User{Name: "select_with_array", Age: 42}) + + var user User + DB.Select([]string{"name", "age"}).Where("age = 42 AND name = ?", "select_with_array").First(&user) + + if user.Name != "select_with_array" || user.Age != 42 { + t.Errorf("Should have selected both age and name") + } +} + +func TestCustomizedTypePrimaryKey(t *testing.T) { + type ID uint + type CustomizedTypePrimaryKey struct { + ID ID + Name string + } + + DB.Migrator().DropTable(&CustomizedTypePrimaryKey{}) + if err := DB.AutoMigrate(&CustomizedTypePrimaryKey{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + p1 := CustomizedTypePrimaryKey{Name: "p1"} + p2 := CustomizedTypePrimaryKey{Name: "p2"} + p3 := CustomizedTypePrimaryKey{Name: "p3"} + DB.Create(&p1) + DB.Create(&p2) + DB.Create(&p3) + + var p CustomizedTypePrimaryKey + + if err := DB.First(&p, p2.ID).Error; err != nil { + t.Errorf("No error should returns, but got %v", err) + } + + AssertEqual(t, p, p2) + + if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { + t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) + } + + AssertEqual(t, p, p2) +} + +func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { + type AddressByZipCode struct { + ZipCode string `gorm:"primary_key"` + Address string + } + + DB.Migrator().DropTable(&AddressByZipCode{}) + if err := DB.AutoMigrate(&AddressByZipCode{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + address := AddressByZipCode{ZipCode: "00501", Address: "Holtsville"} + DB.Create(&address) + + var result AddressByZipCode + DB.First(&result, "00501") + + AssertEqual(t, result, address) +} + +func TestSearchWithEmptyChain(t *testing.T) { + user := User{Name: "search_with_empty_chain", Age: 1} + DB.Create(&user) + + var result User + if DB.Where("").Where("").First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty strings") + } + + if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty struct") + } + + if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty map") + } +} + +func TestLimit(t *testing.T) { + users := []User{ + {Name: "LimitUser1", Age: 1}, + {Name: "LimitUser2", Age: 10}, + {Name: "LimitUser3", Age: 20}, + {Name: "LimitUser4", Age: 10}, + {Name: "LimitUser5", Age: 20}, + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) + + if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { + t.Errorf("Limit should works") + } +} + +func TestOffset(t *testing.T) { + for i := 0; i < 20; i++ { + DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) + } + var users1, users2, users3, users4 []User + + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work") + } +} + +func TestSearchWithMap(t *testing.T) { + users := []User{ + *GetUser("map_search_user1", Config{}), + *GetUser("map_search_user2", Config{}), + *GetUser("map_search_user3", Config{}), + *GetUser("map_search_user4", Config{Company: true}), + } + + DB.Create(&users) + + var user User + DB.First(&user, map[string]interface{}{"name": users[0].Name}) + CheckUser(t, user, users[0]) + + DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) + CheckUser(t, user, users[1]) + + var results []User + DB.Where(map[string]interface{}{"name": users[2].Name}).Find(&results) + if len(results) != 1 { + t.Fatalf("Search all records with inline map") + } + + CheckUser(t, results[0], users[2]) + + var results2 []User + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": nil}) + if len(results2) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&results2, map[string]interface{}{"name": users[0].Name, "company_id": nil}) + if len(results2) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": users[3].CompanyID}) + if len(results2) != 1 { + t.Errorf("Search all records with inline multiple value map") + } +} diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 04c91ab2..9f91b5d8 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -50,6 +50,47 @@ func TestScannerValuer(t *testing.T) { AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs") } +func TestScannerValuerWithFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + } + + var result ScannerValuerStruct + tx := DB.Where(data).FirstOrCreate(&result) + + if tx.RowsAffected != 1 { + t.Errorf("RowsAffected should be 1 after create some record") + } + + if tx.Error != nil { + t.Errorf("Should not raise any error, but got %v", tx.Error) + } + + AssertObjEqual(t, result, data, "Name", "Gender", "Age") + + if err := DB.Where(data).Assign(ScannerValuerStruct{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&result).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if result.Age.Int64 != 18 { + t.Errorf("should update age to 18") + } + + var result2 ScannerValuerStruct + if err := DB.First(&result2, result.ID).Error; err != nil { + t.Errorf("got error %v when query with %v", err, result.ID) + } + + AssertObjEqual(t, result2, result, "ID", "CreatedAt", "UpdatedAt", "Name", "Gender", "Age") +} + func TestInvalidValuer(t *testing.T) { DB.Migrator().DropTable(&ScannerValuerStruct{}) if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 4cd40c7a..0aed82a2 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -80,3 +80,45 @@ func TestRaw(t *testing.T) { t.Error("Raw sql to update records") } } + +func TestRowsWithGroup(t *testing.T) { + users := []User{ + {Name: "having_user_1", Age: 1}, + {Name: "having_user_2", Age: 10}, + {Name: "having_user_1", Age: 20}, + {Name: "having_user_1", Age: 30}, + } + + DB.Create(&users) + + rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN ?", []string{users[0].Name, users[1].Name}).Rows() + if err != nil { + t.Fatalf("got error %v", err) + } + + defer rows.Close() + for rows.Next() { + var name string + var total int64 + rows.Scan(&name, &total) + + if name == users[0].Name && total != 3 { + t.Errorf("Should have one user having name %v", users[0].Name) + } else if name == users[1].Name && total != 1 { + t.Errorf("Should have two users having name %v", users[1].Name) + } + } +} + +func TestQueryRaw(t *testing.T) { + users := []*User{ + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + } + DB.Create(&users) + + var user User + DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) + CheckUser(t, user, *users[1]) +}