mirror of https://github.com/go-gorm/gorm.git
Test SubQuery
This commit is contained in:
parent
1559fe24e5
commit
4e147e1256
|
@ -80,7 +80,7 @@ func (p *processor) Execute(db *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if stmt.Model != nil {
|
if stmt.Model != nil {
|
||||||
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
|
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,36 +63,38 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if !db.DryRun {
|
||||||
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
|
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||||
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
||||||
if insertID, err := result.LastInsertId(); err == nil {
|
if insertID, err := result.LastInsertId(); err == nil {
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if config.LastInsertIDReversed {
|
if config.LastInsertIDReversed {
|
||||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||||
insertID--
|
insertID--
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||||
insertID++
|
insertID++
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
} else {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
|
||||||
} else {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -135,42 +137,44 @@ func CreateWithReturning(db *gorm.DB) {
|
||||||
idx++
|
idx++
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if !db.DryRun {
|
||||||
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
for idx, field := range fields {
|
||||||
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||||
|
}
|
||||||
|
if err := rows.Scan(values...); err != nil {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
|
db.RowsAffected++
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
}
|
}
|
||||||
if err := rows.Scan(values...); err != nil {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
|
||||||
db.RowsAffected++
|
|
||||||
}
|
|
||||||
case reflect.Struct:
|
|
||||||
for idx, field := range fields {
|
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
err = rows.Scan(values...)
|
err = rows.Scan(values...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
|
||||||
} else {
|
} else {
|
||||||
db.AddError(err)
|
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,12 +72,14 @@ func Delete(db *gorm.DB) {
|
||||||
db.Statement.Build("DELETE", "FROM", "WHERE")
|
db.Statement.Build("DELETE", "FROM", "WHERE")
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if !db.DryRun {
|
||||||
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
} else {
|
} else {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,14 +23,16 @@ func Query(db *gorm.DB) {
|
||||||
BuildQuerySQL(db)
|
BuildQuerySQL(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if !db.DryRun {
|
||||||
if err != nil {
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
db.AddError(err)
|
if err != nil {
|
||||||
return
|
db.AddError(err)
|
||||||
}
|
return
|
||||||
defer rows.Close()
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
gorm.Scan(rows, db, false)
|
gorm.Scan(rows, db, false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,12 +85,14 @@ func Update(db *gorm.DB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if !db.DryRun {
|
||||||
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
} else {
|
} else {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,41 +61,43 @@ func Create(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
if !db.DryRun {
|
||||||
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||||
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||||
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
db.RowsAffected++
|
||||||
|
db.AddError(rows.Scan(values...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||||
|
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
if rows.Next() {
|
||||||
db.AddError(rows.Scan(values...))
|
db.RowsAffected++
|
||||||
}
|
db.AddError(rows.Scan(values...))
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
|
||||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
|
||||||
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
|
||||||
|
|
||||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Next() {
|
|
||||||
db.RowsAffected++
|
|
||||||
db.AddError(rows.Scan(values...))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
db.AddError(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
7
gorm.go
7
gorm.go
|
@ -22,6 +22,8 @@ type Config struct {
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
// NowFunc the function to be used when creating a new timestamp
|
// NowFunc the function to be used when creating a new timestamp
|
||||||
NowFunc func() time.Time
|
NowFunc func() time.Time
|
||||||
|
// DryRun generate sql without execute
|
||||||
|
DryRun bool
|
||||||
|
|
||||||
// ClauseBuilders clause builder
|
// ClauseBuilders clause builder
|
||||||
ClauseBuilders map[string]clause.ClauseBuilder
|
ClauseBuilders map[string]clause.ClauseBuilder
|
||||||
|
@ -45,6 +47,7 @@ type DB struct {
|
||||||
|
|
||||||
// Session session config when create session with Session() method
|
// Session session config when create session with Session() method
|
||||||
type Session struct {
|
type Session struct {
|
||||||
|
DryRun bool
|
||||||
WithConditions bool
|
WithConditions bool
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
|
@ -120,6 +123,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
tx.clone = 3
|
tx.clone = 3
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.DryRun {
|
||||||
|
tx.Config.DryRun = true
|
||||||
|
}
|
||||||
|
|
||||||
if config.Logger != nil {
|
if config.Logger != nil {
|
||||||
tx.Config.Logger = config.Logger
|
tx.Config.Logger = config.Logger
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,8 +22,10 @@ func isPrintable(s []byte) bool {
|
||||||
|
|
||||||
var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||||
|
|
||||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, vars ...interface{}) string {
|
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||||
var convertParams func(interface{}, int)
|
var convertParams func(interface{}, int)
|
||||||
|
var vars = make([]interface{}, len(avars))
|
||||||
|
copy(vars, avars)
|
||||||
|
|
||||||
convertParams = func(v interface{}, idx int) {
|
convertParams = func(v interface{}, idx int) {
|
||||||
switch v := v.(type) {
|
switch v := v.(type) {
|
||||||
|
|
14
statement.go
14
statement.go
|
@ -157,6 +157,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||||
} else {
|
} else {
|
||||||
writer.WriteString("(NULL)")
|
writer.WriteString("(NULL)")
|
||||||
}
|
}
|
||||||
|
case *DB:
|
||||||
|
result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement
|
||||||
|
writer.WriteString(result.SQL.String())
|
||||||
|
stmt.Vars = append(stmt.Vars, result.Vars...)
|
||||||
default:
|
default:
|
||||||
switch rv := reflect.ValueOf(v); rv.Kind() {
|
switch rv := reflect.ValueOf(v); rv.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
|
@ -226,7 +230,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
|
||||||
case clause.Expression:
|
case clause.Expression:
|
||||||
conds = append(conds, v)
|
conds = append(conds, v)
|
||||||
case *DB:
|
case *DB:
|
||||||
if v.Statement == nil {
|
if v.Statement != nil {
|
||||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||||
conds = append(conds, cs.Expression)
|
conds = append(conds, cs.Expression)
|
||||||
}
|
}
|
||||||
|
@ -367,7 +371,9 @@ func (stmt *Statement) reinit() {
|
||||||
// })
|
// })
|
||||||
|
|
||||||
// stmt.Schema = nil
|
// stmt.Schema = nil
|
||||||
stmt.SQL.Reset()
|
if !stmt.DB.DryRun {
|
||||||
stmt.Vars = nil
|
stmt.SQL.Reset()
|
||||||
stmt.NamedVars = nil
|
stmt.Vars = nil
|
||||||
|
stmt.NamedVars = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -349,3 +349,89 @@ func TestSearchWithMap(t *testing.T) {
|
||||||
t.Errorf("Search all records with inline multiple value map")
|
t.Errorf("Search all records with inline multiple value map")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSubQuery(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
{Name: "subquery_1", Age: 10},
|
||||||
|
{Name: "subquery_2", Age: 20},
|
||||||
|
{Name: "subquery_3", Age: 30},
|
||||||
|
{Name: "subquery_4", Age: 40},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil {
|
||||||
|
t.Fatalf("got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != 4 {
|
||||||
|
t.Errorf("Four users should be found, instead found %d", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB.
|
||||||
|
Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users)
|
||||||
|
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Errorf("Two users should be found, instead found %d", len(users))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubQueryWithRaw(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
{Name: "subquery_raw_1", Age: 10},
|
||||||
|
{Name: "subquery_raw_2", Age: 20},
|
||||||
|
{Name: "subquery_raw_3", Age: 30},
|
||||||
|
{Name: "subquery_raw_4", Age: 40},
|
||||||
|
}
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
err := DB.Raw("select count(*) from (?) tmp",
|
||||||
|
DB.Table("users").
|
||||||
|
Select("name").
|
||||||
|
Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
|
||||||
|
Group("name"),
|
||||||
|
).Count(&count).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected to get no errors, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("Row count must be 1, instead got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Raw("select count(*) from (?) tmp",
|
||||||
|
DB.Table("users").
|
||||||
|
Select("name").
|
||||||
|
Where("name LIKE ?", "subquery_raw%").
|
||||||
|
Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
|
||||||
|
Group("name"),
|
||||||
|
).Count(&count).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected to get no errors, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("Row count must be 2, instead got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubQueryWithHaving(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
{Name: "subquery_having_1", Age: 10},
|
||||||
|
{Name: "subquery_having_2", Age: 20},
|
||||||
|
{Name: "subquery_having_3", Age: 30},
|
||||||
|
{Name: "subquery_having_4", Age: 40},
|
||||||
|
}
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var results []User
|
||||||
|
DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB.
|
||||||
|
Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results)
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("Two user group should be found, instead found %d", len(results))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue