From 749ca37eb0bdb149dbdc8fa7a47c39cf708f51ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 19:23:32 +0800 Subject: [PATCH] Add sql builder test --- callbacks/query.go | 188 ++++++++++++++++++++------------------ callbacks/row.go | 6 +- tests/sql_builder_test.go | 82 +++++++++++++++++ 3 files changed, 184 insertions(+), 92 deletions(-) create mode 100644 tests/sql_builder_test.go diff --git a/callbacks/query.go b/callbacks/query.go index 6edfee0b..9f96fd1a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -19,93 +19,7 @@ func Query(db *gorm.DB) { } if db.Statement.SQL.String() == "" { - clauseSelect := clause.Select{} - - if len(db.Statement.Selects) > 0 { - for _, name := range db.Statement.Selects { - if f := db.Statement.Schema.LookUpField(name); f != nil { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: f.DBName, - }) - } else { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: name, - Raw: true, - }) - } - } - } - - // inline joins - if len(db.Statement.Joins) != 0 { - joins := []clause.Join{} - - if len(db.Statement.Selects) == 0 { - for _, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: db.Statement.Table, - Name: dbName, - }) - } - } - - for name, conds := range db.Statement.Joins { - if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { - tableAliasName := relation.Name - - for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) - } - - var exprs []clause.Expression - for _, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs = append(exprs, clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - }) - } else { - if ref.PrimaryValue == "" { - exprs = append(exprs, clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - }) - } else { - exprs = append(exprs, clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) - } - } - } - - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) - } else { - joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, - }) - } - } - - db.Statement.AddClause(clause.From{Joins: joins}) - } else { - db.Statement.AddClauseIfNotExists(clause.From{}) - } - - if len(clauseSelect.Columns) > 0 { - db.Statement.AddClause(clauseSelect) - } else { - db.Statement.AddClauseIfNotExists(clauseSelect) - } - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + BuildQuerySQL(db) } rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -118,6 +32,106 @@ func Query(db *gorm.DB) { gorm.Scan(rows, db, false) } +func BuildQuerySQL(db *gorm.DB) { + clauseSelect := clause.Select{} + + if len(db.Statement.Selects) > 0 { + for _, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: name, + Raw: true, + }) + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: f.DBName, + }) + } else { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: name, + Raw: true, + }) + } + } + } + + // inline joins + if len(db.Statement.Joins) != 0 { + joins := []clause.Join{} + + if len(db.Statement.Selects) == 0 { + for _, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: db.Statement.Table, + Name: dbName, + }) + } + } + + for name, conds := range db.Statement.Joins { + if db.Statement.Schema == nil { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: name, Vars: conds}, + }) + } else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + tableAliasName := relation.Name + + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } + + var exprs []clause.Expression + for _, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + }) + } else { + if ref.PrimaryValue == "" { + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + }) + } else { + exprs = append(exprs, clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: name, Vars: conds}, + }) + } + } + + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + if len(clauseSelect.Columns) > 0 { + db.Statement.AddClause(clauseSelect) + } else { + db.Statement.AddClauseIfNotExists(clauseSelect) + } + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") +} + func Preload(db *gorm.DB) { if len(db.Statement.Preloads) > 0 { preloadMap := map[string][]string{} diff --git a/callbacks/row.go b/callbacks/row.go index b84cf694..004a89d5 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -2,15 +2,11 @@ package callbacks import ( "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/clause" ) func RowQuery(db *gorm.DB) { if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) - - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + BuildQuerySQL(db) } if _, ok := db.Get("rows"); ok { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go new file mode 100644 index 00000000..4cd40c7a --- /dev/null +++ b/tests/sql_builder_test.go @@ -0,0 +1,82 @@ +package tests_test + +import ( + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestRow(t *testing.T) { + user1 := User{Name: "RowUser1", Age: 1} + user2 := User{Name: "RowUser2", Age: 10} + user3 := User{Name: "RowUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() + + var age int64 + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 10 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } +} + +func TestRows(t *testing.T) { + user1 := User{Name: "RowsUser1", Age: 1} + user2 := User{Name: "RowsUser2", Age: 10} + user3 := User{Name: "RowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + count := 0 + for rows.Next() { + var name string + var age int64 + rows.Scan(&name, &age) + count++ + } + + if count != 2 { + t.Errorf("Should found two records") + } +} + +func TestRaw(t *testing.T) { + user1 := User{Name: "ExecRawSqlUser1", Age: 1} + user2 := User{Name: "ExecRawSqlUser2", Age: 10} + user3 := User{Name: "ExecRawSqlUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Email string + } + + var results []result + DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&results) + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { + t.Errorf("Raw with scan") + } + + rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() + count := 0 + for rows.Next() { + count++ + } + if count != 1 { + t.Errorf("Raw with Rows should find one record with name 3") + } + + DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { + t.Error("Raw sql to update records") + } +}