Test SubQuery

This commit is contained in:
Jinzhu 2020-06-01 21:26:23 +08:00
parent 1559fe24e5
commit 4e147e1256
10 changed files with 212 additions and 99 deletions

View File

@ -80,7 +80,7 @@ func (p *processor) Execute(db *DB) {
}
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
db.AddError(err)
}
}

View File

@ -63,6 +63,7 @@ func Create(config *Config) func(db *gorm.DB) {
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
}
if !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
@ -97,6 +98,7 @@ func Create(config *Config) func(db *gorm.DB) {
}
}
}
}
}
func CreateWithReturning(db *gorm.DB) {
@ -135,6 +137,7 @@ func CreateWithReturning(db *gorm.DB) {
idx++
}
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
@ -174,6 +177,7 @@ func CreateWithReturning(db *gorm.DB) {
}
}
}
}
}
func AfterCreate(db *gorm.DB) {

View File

@ -72,6 +72,7 @@ func Delete(db *gorm.DB) {
db.Statement.Build("DELETE", "FROM", "WHERE")
}
if !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
@ -80,6 +81,7 @@ func Delete(db *gorm.DB) {
db.AddError(err)
}
}
}
}
func AfterDelete(db *gorm.DB) {

View File

@ -23,6 +23,7 @@ func Query(db *gorm.DB) {
BuildQuerySQL(db)
}
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
@ -32,6 +33,7 @@ func Query(db *gorm.DB) {
gorm.Scan(rows, db, false)
}
}
}
func BuildQuerySQL(db *gorm.DB) {

View File

@ -85,6 +85,7 @@ func Update(db *gorm.DB) {
return
}
if !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
@ -93,6 +94,7 @@ func Update(db *gorm.DB) {
db.AddError(err)
}
}
}
}
func AfterUpdate(db *gorm.DB) {

View File

@ -61,6 +61,7 @@ func Create(db *gorm.DB) {
}
}
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
@ -97,6 +98,7 @@ func Create(db *gorm.DB) {
} else {
db.AddError(err)
}
}
}
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {

View File

@ -22,6 +22,8 @@ type Config struct {
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
// DryRun generate sql without execute
DryRun bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
@ -45,6 +47,7 @@ type DB struct {
// Session session config when create session with Session() method
type Session struct {
DryRun bool
WithConditions bool
Context context.Context
Logger logger.Interface
@ -120,6 +123,10 @@ func (db *DB) Session(config *Session) *DB {
tx.clone = 3
}
if config.DryRun {
tx.Config.DryRun = true
}
if config.Logger != nil {
tx.Config.Logger = config.Logger
}

View File

@ -22,8 +22,10 @@ func isPrintable(s []byte) bool {
var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string {
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int)
var vars = make([]interface{}, len(avars))
copy(vars, avars)
convertParams = func(v interface{}, idx int) {
switch v := v.(type) {

View File

@ -157,6 +157,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else {
writer.WriteString("(NULL)")
}
case *DB:
result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement
writer.WriteString(result.SQL.String())
stmt.Vars = append(stmt.Vars, result.Vars...)
default:
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array:
@ -226,7 +230,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
case clause.Expression:
conds = append(conds, v)
case *DB:
if v.Statement == nil {
if v.Statement != nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
conds = append(conds, cs.Expression)
}
@ -367,7 +371,9 @@ func (stmt *Statement) reinit() {
// })
// stmt.Schema = nil
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
stmt.NamedVars = nil
}
}

View File

@ -349,3 +349,89 @@ func TestSearchWithMap(t *testing.T) {
t.Errorf("Search all records with inline multiple value map")
}
}
func TestSubQuery(t *testing.T) {
users := []User{
{Name: "subquery_1", Age: 10},
{Name: "subquery_2", Age: 20},
{Name: "subquery_3", Age: 30},
{Name: "subquery_4", Age: 40},
}
DB.Create(&users)
if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil {
t.Fatalf("got error: %v", err)
}
if len(users) != 4 {
t.Errorf("Four users should be found, instead found %d", len(users))
}
DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB.
Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users)
if len(users) != 2 {
t.Errorf("Two users should be found, instead found %d", len(users))
}
}
func TestSubQueryWithRaw(t *testing.T) {
users := []User{
{Name: "subquery_raw_1", Age: 10},
{Name: "subquery_raw_2", Age: 20},
{Name: "subquery_raw_3", Age: 30},
{Name: "subquery_raw_4", Age: 40},
}
DB.Create(&users)
var count int64
err := DB.Raw("select count(*) from (?) tmp",
DB.Table("users").
Select("name").
Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
Group("name"),
).Count(&count).Error
if err != nil {
t.Errorf("Expected to get no errors, but got %v", err)
}
if count != 1 {
t.Errorf("Row count must be 1, instead got %d", count)
}
err = DB.Raw("select count(*) from (?) tmp",
DB.Table("users").
Select("name").
Where("name LIKE ?", "subquery_raw%").
Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
Group("name"),
).Count(&count).Error
if err != nil {
t.Errorf("Expected to get no errors, but got %v", err)
}
if count != 2 {
t.Errorf("Row count must be 2, instead got %d", count)
}
}
func TestSubQueryWithHaving(t *testing.T) {
users := []User{
{Name: "subquery_having_1", Age: 10},
{Name: "subquery_having_2", Age: 20},
{Name: "subquery_having_3", Age: 30},
{Name: "subquery_having_4", Age: 40},
}
DB.Create(&users)
var results []User
DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB.
Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results)
if len(results) != 2 {
t.Errorf("Two user group should be found, instead found %d", len(results))
}
}