2020-01-29 14:22:44 +03:00
|
|
|
package gorm
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"database/sql"
|
2020-01-29 22:03:06 +03:00
|
|
|
"database/sql/driver"
|
2020-01-29 14:22:44 +03:00
|
|
|
"fmt"
|
2020-02-23 14:41:29 +03:00
|
|
|
"reflect"
|
2021-10-08 12:51:27 +03:00
|
|
|
"regexp"
|
2020-07-05 07:23:45 +03:00
|
|
|
"sort"
|
2020-01-29 22:03:06 +03:00
|
|
|
"strconv"
|
2020-01-29 14:22:44 +03:00
|
|
|
"strings"
|
|
|
|
"sync"
|
|
|
|
|
2020-06-02 04:16:07 +03:00
|
|
|
"gorm.io/gorm/clause"
|
2020-09-09 11:26:11 +03:00
|
|
|
"gorm.io/gorm/logger"
|
2020-06-02 04:16:07 +03:00
|
|
|
"gorm.io/gorm/schema"
|
2020-06-30 11:53:54 +03:00
|
|
|
"gorm.io/gorm/utils"
|
2020-01-29 14:22:44 +03:00
|
|
|
)
|
|
|
|
|
|
|
|
// Statement statement
//
// Statement carries the state used while building and executing one SQL
// statement: the target table and model, the clauses collected so far, the
// SQL text being generated and its bind variables. It embeds *DB, through
// which helpers such as AddError, Dialector and ClauseBuilders are reached.
type Statement struct {
	*DB
	// TableExpr, when non-nil, is built instead of Table wherever a
	// clause.Table refers to clause.CurrentTable (see QuoteTo).
	// ParseWithSpecialTableName sets it for table names containing exactly
	// one '.'.
	TableExpr *clause.Expr
	// Table is the table name; Parse fills it from Schema when it is empty.
	Table string
	// Model is the model value for the statement; copied by clone.
	Model interface{}
	// Unscoped flag; copied by clone. NOTE(review): consumers are outside
	// this file section.
	Unscoped bool
	// Dest is the destination value; SetColumn and Changed read and may
	// replace it.
	Dest interface{}
	// ReflectValue is the reflected model/destination value; when it is a
	// slice or array, CurDestIndex selects the element being processed.
	ReflectValue reflect.Value
	// Clauses holds the collected clauses keyed by clause name (AddClause,
	// Build).
	Clauses map[string]clause.Clause
	// BuildClauses presumably lists the clause names to build for this
	// statement — TODO confirm against the callbacks.
	BuildClauses []string
	Distinct     bool
	Selects      []string // selected columns
	Omits        []string // omit columns
	Joins        []join
	// Preloads maps a preload name to its arguments.
	Preloads map[string][]interface{}
	// Settings stores arbitrary per-statement key/value pairs; clone copies
	// every entry.
	Settings sync.Map
	ConnPool ConnPool
	// Schema is the parsed schema of the model (set by Parse).
	Schema *schema.Schema
	// Context is passed to Valuer.GormValue when binding variables (AddVar).
	Context              context.Context
	RaiseErrorOnNotFound bool
	SkipHooks            bool
	// SQL accumulates the generated SQL text (WriteString, WriteByte).
	SQL strings.Builder
	// Vars collects bind variables in the order AddVar emits them.
	Vars []interface{}
	// CurDestIndex is the index of the current element when ReflectValue is
	// a slice or array (used by SetColumn and Changed).
	CurDestIndex int
	// attrs and assigns presumably hold Attrs/Assign arguments for
	// FirstOrInit/FirstOrCreate — TODO confirm.
	attrs   []interface{}
	assigns []interface{}
	// scopes are pending scope functions; clone copies them.
	scopes []func(*DB) *DB
}
|
|
|
|
|
2020-08-23 05:40:37 +03:00
|
|
|
// join describes one entry of Statement.Joins. clone copies the slice of
// joins element by element.
// NOTE(review): how each field is interpreted is decided by code outside
// this file section — confirm against the join builder.
type join struct {
	// Name presumably is an association name or raw join SQL — TODO confirm.
	Name string
	// Conds are the extra arguments supplied with the join.
	Conds []interface{}
	// On holds explicit ON conditions, if any.
	On *clause.Where
}
|
|
|
|
|
2020-03-12 03:39:42 +03:00
|
|
|
// StatementModifier statement modifier interface
|
|
|
|
type StatementModifier interface {
|
|
|
|
ModifyStatement(*Statement)
|
2020-01-30 10:14:48 +03:00
|
|
|
}
|
|
|
|
|
2021-05-31 05:08:06 +03:00
|
|
|
// WriteString write string
|
2020-03-09 12:07:00 +03:00
|
|
|
func (stmt *Statement) WriteString(str string) (int, error) {
|
|
|
|
return stmt.SQL.WriteString(str)
|
2020-01-29 14:22:44 +03:00
|
|
|
}
|
|
|
|
|
2021-05-31 05:08:06 +03:00
|
|
|
// WriteByte write byte
|
2020-03-09 12:07:00 +03:00
|
|
|
func (stmt *Statement) WriteByte(c byte) error {
|
2020-01-30 10:14:48 +03:00
|
|
|
return stmt.SQL.WriteByte(c)
|
|
|
|
}
|
|
|
|
|
2020-03-08 18:30:16 +03:00
|
|
|
// WriteQuoted write quoted value
|
2020-07-16 06:27:04 +03:00
|
|
|
func (stmt *Statement) WriteQuoted(value interface{}) {
|
2020-03-08 18:30:16 +03:00
|
|
|
stmt.QuoteTo(&stmt.SQL, value)
|
2020-01-29 14:22:44 +03:00
|
|
|
}
|
|
|
|
|
2020-03-08 18:30:16 +03:00
|
|
|
// QuoteTo write quoted value to writer
|
2020-06-08 08:45:41 +03:00
|
|
|
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
2021-09-29 09:02:35 +03:00
|
|
|
write := func(raw bool, str string) {
|
|
|
|
if raw {
|
|
|
|
writer.WriteString(str)
|
|
|
|
} else {
|
|
|
|
stmt.DB.Dialector.QuoteTo(writer, str)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-02-02 09:40:44 +03:00
|
|
|
switch v := field.(type) {
|
|
|
|
case clause.Table:
|
2020-02-07 18:45:35 +03:00
|
|
|
if v.Name == clause.CurrentTable {
|
2020-07-17 16:19:11 +03:00
|
|
|
if stmt.TableExpr != nil {
|
|
|
|
stmt.TableExpr.Build(stmt)
|
2020-07-10 16:11:28 +03:00
|
|
|
} else {
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, stmt.Table)
|
2020-07-10 16:11:28 +03:00
|
|
|
}
|
2020-02-04 04:51:19 +03:00
|
|
|
} else {
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, v.Name)
|
2020-02-04 04:51:19 +03:00
|
|
|
}
|
2020-02-04 03:56:15 +03:00
|
|
|
|
2020-02-02 09:40:44 +03:00
|
|
|
if v.Alias != "" {
|
2020-08-29 17:09:07 +03:00
|
|
|
writer.WriteByte(' ')
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, v.Alias)
|
2020-02-02 09:40:44 +03:00
|
|
|
}
|
|
|
|
case clause.Column:
|
|
|
|
if v.Table != "" {
|
2020-02-04 03:56:15 +03:00
|
|
|
if v.Table == clause.CurrentTable {
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, stmt.Table)
|
2020-02-04 03:56:15 +03:00
|
|
|
} else {
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, v.Table)
|
2020-02-04 03:56:15 +03:00
|
|
|
}
|
2020-03-08 18:30:16 +03:00
|
|
|
writer.WriteByte('.')
|
2020-02-02 09:40:44 +03:00
|
|
|
}
|
|
|
|
|
2020-02-04 03:56:15 +03:00
|
|
|
if v.Name == clause.PrimaryKey {
|
2020-07-24 03:32:50 +03:00
|
|
|
if stmt.Schema == nil {
|
2020-07-26 05:03:58 +03:00
|
|
|
stmt.DB.AddError(ErrModelValueRequired)
|
2020-07-24 03:32:50 +03:00
|
|
|
} else if stmt.Schema.PrioritizedPrimaryField != nil {
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
2020-06-02 06:30:21 +03:00
|
|
|
} else if len(stmt.Schema.DBNames) > 0 {
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, stmt.Schema.DBNames[0])
|
2020-02-04 03:56:15 +03:00
|
|
|
}
|
|
|
|
} else {
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, v.Name)
|
2020-02-04 03:56:15 +03:00
|
|
|
}
|
2020-02-05 06:14:58 +03:00
|
|
|
|
2020-02-02 09:40:44 +03:00
|
|
|
if v.Alias != "" {
|
2020-03-08 18:30:16 +03:00
|
|
|
writer.WriteString(" AS ")
|
2021-09-29 09:02:35 +03:00
|
|
|
write(v.Raw, v.Alias)
|
2020-02-02 09:40:44 +03:00
|
|
|
}
|
2020-07-09 04:03:48 +03:00
|
|
|
case []clause.Column:
|
|
|
|
writer.WriteByte('(')
|
|
|
|
for idx, d := range v {
|
|
|
|
if idx > 0 {
|
|
|
|
writer.WriteString(",")
|
|
|
|
}
|
|
|
|
stmt.QuoteTo(writer, d)
|
|
|
|
}
|
|
|
|
writer.WriteByte(')')
|
2021-08-19 09:33:18 +03:00
|
|
|
case clause.Expr:
|
|
|
|
v.Build(stmt)
|
2020-03-12 08:05:22 +03:00
|
|
|
case string:
|
|
|
|
stmt.DB.Dialector.QuoteTo(writer, v)
|
2020-05-14 07:19:12 +03:00
|
|
|
case []string:
|
|
|
|
writer.WriteByte('(')
|
|
|
|
for idx, d := range v {
|
2020-06-08 08:45:41 +03:00
|
|
|
if idx > 0 {
|
2020-05-14 07:19:12 +03:00
|
|
|
writer.WriteString(",")
|
|
|
|
}
|
|
|
|
stmt.DB.Dialector.QuoteTo(writer, d)
|
|
|
|
}
|
|
|
|
writer.WriteByte(')')
|
2020-02-02 09:40:44 +03:00
|
|
|
default:
|
2020-03-08 18:30:16 +03:00
|
|
|
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
2020-02-02 09:40:44 +03:00
|
|
|
}
|
2020-03-08 18:30:16 +03:00
|
|
|
}
|
2020-02-02 09:40:44 +03:00
|
|
|
|
2020-03-08 18:30:16 +03:00
|
|
|
// Quote returns quoted value
|
2020-06-08 08:45:41 +03:00
|
|
|
func (stmt *Statement) Quote(field interface{}) string {
|
2020-03-08 18:30:16 +03:00
|
|
|
var builder strings.Builder
|
|
|
|
stmt.QuoteTo(&builder, field)
|
|
|
|
return builder.String()
|
2020-01-30 10:14:48 +03:00
|
|
|
}
|
|
|
|
|
2021-05-31 05:08:06 +03:00
|
|
|
// AddVar add var
//
// AddVar renders each element of vars into writer, separated by commas.
// Plain values are appended to stmt.Vars and a dialect bind placeholder is
// written. Columns and tables are quoted; expressions build themselves;
// slices expand to a parenthesised list, and an empty slice becomes "(NULL)".
// A *DB is rendered inline as a subquery.
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
	for idx, v := range vars {
		if idx > 0 {
			writer.WriteByte(',')
		}

		switch v := v.(type) {
		case sql.NamedArg:
			// Only the value is recorded; nothing is written to writer here.
			// NOTE(review): presumably the placeholder is emitted by a
			// named-expression builder — confirm.
			stmt.Vars = append(stmt.Vars, v.Value)
		case clause.Column, clause.Table:
			stmt.QuoteTo(writer, v)
		case Valuer:
			// A nil pointer implementing Valuer is bound as NULL rather than
			// calling GormValue on it.
			reflectValue := reflect.ValueOf(v)
			if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
				stmt.AddVar(writer, nil)
			} else {
				stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
			}
		case clause.Expr:
			v.Build(stmt)
		case *clause.Expr:
			v.Build(stmt)
		case clause.Expression:
			v.Build(stmt)
		case driver.Valuer:
			stmt.Vars = append(stmt.Vars, v)
			stmt.DB.Dialector.BindVarTo(writer, stmt, v)
		case []byte:
			// Byte slices are a single value, not a list.
			stmt.Vars = append(stmt.Vars, v)
			stmt.DB.Dialector.BindVarTo(writer, stmt, v)
		case []interface{}:
			if len(v) > 0 {
				writer.WriteByte('(')
				stmt.AddVar(writer, v...)
				writer.WriteByte(')')
			} else {
				writer.WriteString("(NULL)")
			}
		case *DB:
			// Subquery: render v without executing it (DryRun) and with logging
			// discarded.
			subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
			if v.Statement.SQL.Len() > 0 {
				// v already has SQL text: turn its dialect placeholders back
				// into "?" so the text can be rebuilt against this statement's
				// vars. Each placeholder is re-rendered with BindVarTo after
				// appending its var, so placeholders that depend on position
				// (presumably e.g. "$n" styles) are reproduced before the first
				// occurrence is replaced.
				var (
					vars = subdb.Statement.Vars
					sql  = v.Statement.SQL.String()
				)

				subdb.Statement.Vars = make([]interface{}, 0, len(vars))
				for _, vv := range vars {
					subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
					bindvar := strings.Builder{}
					v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
					sql = strings.Replace(sql, bindvar.String(), "?", 1)
				}

				// Rebuild the SQL with numbering continuing from the outer
				// statement's vars.
				subdb.Statement.SQL.Reset()
				subdb.Statement.Vars = stmt.Vars
				if strings.Contains(sql, "@") {
					clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
				} else {
					clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
				}
			} else {
				// No SQL yet: run the query callbacks (dry run) with the outer
				// vars prepended, so placeholders continue the outer numbering.
				subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
				subdb.callbacks.Query().Execute(subdb)
			}

			writer.WriteString(subdb.Statement.SQL.String())
			// The subquery's vars include the outer ones, so they replace stmt.Vars.
			stmt.Vars = subdb.Statement.Vars
		default:
			switch rv := reflect.ValueOf(v); rv.Kind() {
			case reflect.Slice, reflect.Array:
				if rv.Len() == 0 {
					writer.WriteString("(NULL)")
				} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
					// Any byte slice/array type is bound as one value.
					stmt.Vars = append(stmt.Vars, v)
					stmt.DB.Dialector.BindVarTo(writer, stmt, v)
				} else {
					writer.WriteByte('(')
					for i := 0; i < rv.Len(); i++ {
						if i > 0 {
							writer.WriteByte(',')
						}
						stmt.AddVar(writer, rv.Index(i).Interface())
					}
					writer.WriteByte(')')
				}
			default:
				stmt.Vars = append(stmt.Vars, v)
				stmt.DB.Dialector.BindVarTo(writer, stmt, v)
			}
		}
	}
}
|
|
|
|
|
|
|
|
// AddClause add clause
|
2020-02-03 05:40:03 +03:00
|
|
|
func (stmt *Statement) AddClause(v clause.Interface) {
|
2020-03-12 03:39:42 +03:00
|
|
|
if optimizer, ok := v.(StatementModifier); ok {
|
|
|
|
optimizer.ModifyStatement(stmt)
|
2020-06-06 17:52:08 +03:00
|
|
|
} else {
|
2020-06-14 06:46:17 +03:00
|
|
|
name := v.Name()
|
2020-07-16 06:27:04 +03:00
|
|
|
c := stmt.Clauses[name]
|
2020-06-14 06:46:17 +03:00
|
|
|
c.Name = name
|
2020-06-06 17:52:08 +03:00
|
|
|
v.MergeClause(&c)
|
2020-06-14 06:46:17 +03:00
|
|
|
stmt.Clauses[name] = c
|
2020-01-30 10:14:48 +03:00
|
|
|
}
|
2020-01-29 14:22:44 +03:00
|
|
|
}
|
2020-01-29 22:03:06 +03:00
|
|
|
|
2020-02-03 05:40:03 +03:00
|
|
|
// AddClauseIfNotExists add clause if not exists
|
|
|
|
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
2020-06-06 17:52:08 +03:00
|
|
|
if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
|
2020-02-07 18:45:35 +03:00
|
|
|
stmt.AddClause(v)
|
2020-02-03 05:40:03 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-08 06:38:51 +03:00
|
|
|
// BuildCondition build condition
|
2020-11-10 13:38:24 +03:00
|
|
|
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
|
2020-07-10 07:28:24 +03:00
|
|
|
if s, ok := query.(string); ok {
|
2020-06-01 05:02:20 +03:00
|
|
|
// if it is a number, then treats it as primary key
|
2020-07-10 07:28:24 +03:00
|
|
|
if _, err := strconv.Atoi(s); err != nil {
|
|
|
|
if s == "" && len(args) == 0 {
|
2020-11-10 13:38:24 +03:00
|
|
|
return nil
|
2021-10-08 06:16:58 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
|
2020-06-05 16:23:20 +03:00
|
|
|
// looks like a where condition
|
2020-07-10 07:28:24 +03:00
|
|
|
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
2021-10-08 06:16:58 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
if len(args) > 0 && strings.Contains(s, "@") {
|
2020-07-10 07:28:24 +03:00
|
|
|
// looks like a named query
|
|
|
|
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
|
2021-10-08 06:16:58 +03:00
|
|
|
}
|
|
|
|
|
2021-11-23 12:11:52 +03:00
|
|
|
if strings.Contains(strings.TrimSpace(s), " ") {
|
|
|
|
// looks like a where condition
|
|
|
|
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
|
|
|
}
|
|
|
|
|
2021-10-08 06:16:58 +03:00
|
|
|
if len(args) == 1 {
|
2020-07-10 07:28:24 +03:00
|
|
|
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
|
2020-06-01 05:02:20 +03:00
|
|
|
}
|
2020-01-29 22:03:06 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-11-10 13:38:24 +03:00
|
|
|
conds := make([]clause.Expression, 0, 4)
|
2020-01-29 22:03:06 +03:00
|
|
|
args = append([]interface{}{query}, args...)
|
2020-12-30 12:42:27 +03:00
|
|
|
for idx, arg := range args {
|
2020-01-29 22:03:06 +03:00
|
|
|
if valuer, ok := arg.(driver.Valuer); ok {
|
|
|
|
arg, _ = valuer.Value()
|
|
|
|
}
|
|
|
|
|
|
|
|
switch v := arg.(type) {
|
2020-01-30 10:14:48 +03:00
|
|
|
case clause.Expression:
|
2020-05-28 08:12:56 +03:00
|
|
|
conds = append(conds, v)
|
2020-06-19 20:55:30 +03:00
|
|
|
case *DB:
|
|
|
|
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
|
|
|
if where, ok := cs.Expression.(clause.Where); ok {
|
2021-01-20 13:24:05 +03:00
|
|
|
if len(where.Exprs) == 1 {
|
|
|
|
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
2021-03-08 05:46:43 +03:00
|
|
|
where.Exprs[0] = clause.AndConditions(orConds)
|
2021-01-20 13:24:05 +03:00
|
|
|
}
|
|
|
|
}
|
2020-06-19 20:55:30 +03:00
|
|
|
conds = append(conds, clause.And(where.Exprs...))
|
|
|
|
} else if cs.Expression != nil {
|
|
|
|
conds = append(conds, cs.Expression)
|
|
|
|
}
|
|
|
|
}
|
2020-01-29 22:03:06 +03:00
|
|
|
case map[interface{}]interface{}:
|
|
|
|
for i, j := range v {
|
2020-05-28 08:12:56 +03:00
|
|
|
conds = append(conds, clause.Eq{Column: i, Value: j})
|
2020-01-29 22:03:06 +03:00
|
|
|
}
|
|
|
|
case map[string]string:
|
2022-01-06 10:02:53 +03:00
|
|
|
keys := make([]string, 0, len(v))
|
2020-07-05 07:23:45 +03:00
|
|
|
for i := range v {
|
|
|
|
keys = append(keys, i)
|
|
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
|
|
|
|
|
|
for _, key := range keys {
|
|
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
2020-01-29 22:03:06 +03:00
|
|
|
}
|
|
|
|
case map[string]interface{}:
|
2022-01-06 10:02:53 +03:00
|
|
|
keys := make([]string, 0, len(v))
|
2020-07-05 07:23:45 +03:00
|
|
|
for i := range v {
|
|
|
|
keys = append(keys, i)
|
|
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
|
|
|
|
|
|
for _, key := range keys {
|
|
|
|
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
2020-07-05 06:53:10 +03:00
|
|
|
switch reflectValue.Kind() {
|
|
|
|
case reflect.Slice, reflect.Array:
|
2020-09-24 10:00:13 +03:00
|
|
|
if _, ok := v[key].(driver.Valuer); ok {
|
|
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
|
|
|
} else if _, ok := v[key].(Valuer); ok {
|
|
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
|
|
|
} else {
|
2021-04-19 16:03:39 +03:00
|
|
|
// optimize reflect value length
|
2021-04-14 08:00:54 +03:00
|
|
|
valueLen := reflectValue.Len()
|
|
|
|
values := make([]interface{}, valueLen)
|
|
|
|
for i := 0; i < valueLen; i++ {
|
2020-09-24 10:00:13 +03:00
|
|
|
values[i] = reflectValue.Index(i).Interface()
|
|
|
|
}
|
2020-07-05 06:53:10 +03:00
|
|
|
|
2020-09-24 10:00:13 +03:00
|
|
|
conds = append(conds, clause.IN{Column: key, Values: values})
|
|
|
|
}
|
2020-07-05 06:53:10 +03:00
|
|
|
default:
|
2020-07-05 07:23:45 +03:00
|
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
2020-07-05 06:53:10 +03:00
|
|
|
}
|
2020-01-29 22:03:06 +03:00
|
|
|
}
|
|
|
|
default:
|
2021-01-05 16:01:16 +03:00
|
|
|
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
2021-03-10 14:46:59 +03:00
|
|
|
for reflectValue.Kind() == reflect.Ptr {
|
|
|
|
reflectValue = reflectValue.Elem()
|
|
|
|
}
|
|
|
|
|
2021-01-05 16:01:16 +03:00
|
|
|
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
|
|
|
selectedColumns := map[string]bool{}
|
|
|
|
if idx == 0 {
|
|
|
|
for _, v := range args[1:] {
|
|
|
|
if vs, ok := v.(string); ok {
|
|
|
|
selectedColumns[vs] = true
|
2020-05-28 08:12:56 +03:00
|
|
|
}
|
|
|
|
}
|
2021-01-05 16:01:16 +03:00
|
|
|
}
|
|
|
|
restricted := len(selectedColumns) != 0
|
2021-01-05 13:01:51 +03:00
|
|
|
|
2021-01-05 16:01:16 +03:00
|
|
|
switch reflectValue.Kind() {
|
|
|
|
case reflect.Struct:
|
|
|
|
for _, field := range s.Fields {
|
|
|
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
|
|
|
if selected || (!restricted && field.Readable) {
|
|
|
|
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
|
|
|
|
if field.DBName != "" {
|
|
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
|
|
|
} else if field.DataType != "" {
|
|
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
case reflect.Slice, reflect.Array:
|
|
|
|
for i := 0; i < reflectValue.Len(); i++ {
|
2020-05-28 11:10:10 +03:00
|
|
|
for _, field := range s.Fields {
|
2020-12-30 12:42:27 +03:00
|
|
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
|
|
|
if selected || (!restricted && field.Readable) {
|
2021-01-05 16:01:16 +03:00
|
|
|
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
|
2020-08-13 13:09:04 +03:00
|
|
|
if field.DBName != "" {
|
2020-09-03 13:42:32 +03:00
|
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
2020-08-13 13:09:04 +03:00
|
|
|
} else if field.DataType != "" {
|
2020-09-03 13:42:32 +03:00
|
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
2020-06-27 03:04:12 +03:00
|
|
|
}
|
2020-05-28 11:10:10 +03:00
|
|
|
}
|
2020-05-28 08:12:56 +03:00
|
|
|
}
|
|
|
|
}
|
2021-01-05 16:01:16 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if restricted {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
} else if !reflectValue.IsValid() {
|
|
|
|
stmt.AddError(ErrInvalidData)
|
|
|
|
} else if len(conds) == 0 {
|
|
|
|
if len(args) == 1 {
|
|
|
|
switch reflectValue.Kind() {
|
2020-07-05 06:53:10 +03:00
|
|
|
case reflect.Slice, reflect.Array:
|
2021-04-19 16:03:39 +03:00
|
|
|
// optimize reflect value length
|
2021-04-14 08:00:54 +03:00
|
|
|
valueLen := reflectValue.Len()
|
|
|
|
values := make([]interface{}, valueLen)
|
|
|
|
for i := 0; i < valueLen; i++ {
|
2021-01-05 16:01:16 +03:00
|
|
|
values[i] = reflectValue.Index(i).Interface()
|
2020-07-05 06:53:10 +03:00
|
|
|
}
|
|
|
|
|
2021-01-05 16:01:16 +03:00
|
|
|
if len(values) > 0 {
|
|
|
|
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
2020-07-05 06:53:10 +03:00
|
|
|
}
|
2021-01-05 16:01:16 +03:00
|
|
|
return conds
|
2020-07-05 06:53:10 +03:00
|
|
|
}
|
2021-01-05 13:01:51 +03:00
|
|
|
}
|
2021-01-05 16:01:16 +03:00
|
|
|
|
|
|
|
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
2020-05-28 08:12:56 +03:00
|
|
|
}
|
2020-01-29 22:03:06 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-11-10 13:38:24 +03:00
|
|
|
return conds
|
2020-01-29 22:03:06 +03:00
|
|
|
}
|
|
|
|
|
2020-01-30 10:14:48 +03:00
|
|
|
// Build build sql with clauses names
|
2020-02-03 05:40:03 +03:00
|
|
|
func (stmt *Statement) Build(clauses ...string) {
|
2020-02-02 14:32:27 +03:00
|
|
|
var firstClauseWritten bool
|
2020-01-30 10:14:48 +03:00
|
|
|
|
|
|
|
for _, name := range clauses {
|
|
|
|
if c, ok := stmt.Clauses[name]; ok {
|
2020-02-02 14:32:27 +03:00
|
|
|
if firstClauseWritten {
|
2020-01-30 10:14:48 +03:00
|
|
|
stmt.WriteByte(' ')
|
|
|
|
}
|
|
|
|
|
2020-02-02 14:32:27 +03:00
|
|
|
firstClauseWritten = true
|
2020-02-03 05:40:03 +03:00
|
|
|
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
|
2020-05-29 17:34:35 +03:00
|
|
|
b(c, stmt)
|
2020-02-03 05:40:03 +03:00
|
|
|
} else {
|
|
|
|
c.Build(stmt)
|
|
|
|
}
|
2020-01-30 10:14:48 +03:00
|
|
|
}
|
|
|
|
}
|
2020-01-29 22:03:06 +03:00
|
|
|
}
|
2020-02-20 18:04:03 +03:00
|
|
|
|
|
|
|
func (stmt *Statement) Parse(value interface{}) (err error) {
|
2021-10-25 06:26:44 +03:00
|
|
|
return stmt.ParseWithSpecialTableName(value, "")
|
|
|
|
}
|
|
|
|
|
|
|
|
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
|
|
|
|
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
|
2020-07-19 16:30:24 +03:00
|
|
|
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
|
|
|
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
|
|
|
stmt.Table = tables[1]
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2020-02-24 03:51:35 +03:00
|
|
|
stmt.Table = stmt.Schema.Table
|
2020-02-20 18:04:03 +03:00
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
2020-03-09 10:32:55 +03:00
|
|
|
|
2020-05-24 12:24:23 +03:00
|
|
|
func (stmt *Statement) clone() *Statement {
|
|
|
|
newStmt := &Statement{
|
2020-08-13 11:28:21 +03:00
|
|
|
TableExpr: stmt.TableExpr,
|
2020-05-24 12:24:23 +03:00
|
|
|
Table: stmt.Table,
|
|
|
|
Model: stmt.Model,
|
2020-10-26 05:17:25 +03:00
|
|
|
Unscoped: stmt.Unscoped,
|
2020-05-24 12:24:23 +03:00
|
|
|
Dest: stmt.Dest,
|
|
|
|
ReflectValue: stmt.ReflectValue,
|
|
|
|
Clauses: map[string]clause.Clause{},
|
2020-06-05 14:19:08 +03:00
|
|
|
Distinct: stmt.Distinct,
|
2020-05-24 12:24:23 +03:00
|
|
|
Selects: stmt.Selects,
|
|
|
|
Omits: stmt.Omits,
|
|
|
|
Preloads: map[string][]interface{}{},
|
|
|
|
ConnPool: stmt.ConnPool,
|
|
|
|
Schema: stmt.Schema,
|
|
|
|
Context: stmt.Context,
|
|
|
|
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
2020-11-17 12:49:43 +03:00
|
|
|
SkipHooks: stmt.SkipHooks,
|
2020-05-24 12:24:23 +03:00
|
|
|
}
|
|
|
|
|
2021-02-09 12:05:50 +03:00
|
|
|
if stmt.SQL.Len() > 0 {
|
|
|
|
newStmt.SQL.WriteString(stmt.SQL.String())
|
|
|
|
newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
|
|
|
|
newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
|
|
|
|
}
|
|
|
|
|
2020-05-24 12:24:23 +03:00
|
|
|
for k, c := range stmt.Clauses {
|
|
|
|
newStmt.Clauses[k] = c
|
|
|
|
}
|
|
|
|
|
|
|
|
for k, p := range stmt.Preloads {
|
|
|
|
newStmt.Preloads[k] = p
|
|
|
|
}
|
|
|
|
|
2020-08-23 05:40:37 +03:00
|
|
|
if len(stmt.Joins) > 0 {
|
|
|
|
newStmt.Joins = make([]join, len(stmt.Joins))
|
|
|
|
copy(newStmt.Joins, stmt.Joins)
|
2020-05-24 12:24:23 +03:00
|
|
|
}
|
|
|
|
|
2021-02-25 17:01:59 +03:00
|
|
|
if len(stmt.scopes) > 0 {
|
|
|
|
newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
|
|
|
|
copy(newStmt.scopes, stmt.scopes)
|
2021-02-25 13:49:01 +03:00
|
|
|
}
|
|
|
|
|
2020-06-19 07:38:03 +03:00
|
|
|
stmt.Settings.Range(func(k, v interface{}) bool {
|
|
|
|
newStmt.Settings.Store(k, v)
|
|
|
|
return true
|
|
|
|
})
|
|
|
|
|
2020-05-24 12:24:23 +03:00
|
|
|
return newStmt
|
|
|
|
}
|
2020-06-30 11:53:54 +03:00
|
|
|
|
|
|
|
// SetColumn set column's value
|
2020-12-15 05:39:20 +03:00
|
|
|
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
|
|
|
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
|
|
|
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
2020-06-30 11:53:54 +03:00
|
|
|
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
|
|
|
v[name] = value
|
2020-12-15 05:39:20 +03:00
|
|
|
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
|
|
|
|
for _, m := range v {
|
|
|
|
m[name] = value
|
|
|
|
}
|
2020-06-30 11:53:54 +03:00
|
|
|
} else if stmt.Schema != nil {
|
|
|
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
2020-10-27 13:14:36 +03:00
|
|
|
destValue := reflect.ValueOf(stmt.Dest)
|
|
|
|
for destValue.Kind() == reflect.Ptr {
|
|
|
|
destValue = destValue.Elem()
|
|
|
|
}
|
|
|
|
|
|
|
|
if stmt.ReflectValue != destValue {
|
|
|
|
if !destValue.CanAddr() {
|
|
|
|
destValueCanAddr := reflect.New(destValue.Type())
|
|
|
|
destValueCanAddr.Elem().Set(destValue)
|
|
|
|
stmt.Dest = destValueCanAddr.Interface()
|
|
|
|
destValue = destValueCanAddr.Elem()
|
|
|
|
}
|
|
|
|
|
|
|
|
switch destValue.Kind() {
|
|
|
|
case reflect.Struct:
|
|
|
|
field.Set(destValue, value)
|
|
|
|
default:
|
|
|
|
stmt.AddError(ErrInvalidData)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-06-30 17:47:21 +03:00
|
|
|
switch stmt.ReflectValue.Kind() {
|
|
|
|
case reflect.Slice, reflect.Array:
|
2020-12-15 05:39:20 +03:00
|
|
|
if len(fromCallbacks) > 0 {
|
|
|
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
|
|
|
field.Set(stmt.ReflectValue.Index(i), value)
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
|
|
|
|
}
|
2020-06-30 17:47:21 +03:00
|
|
|
case reflect.Struct:
|
2021-05-23 06:21:56 +03:00
|
|
|
if !stmt.ReflectValue.CanAddr() {
|
|
|
|
stmt.AddError(ErrInvalidValue)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2020-06-30 17:47:21 +03:00
|
|
|
field.Set(stmt.ReflectValue, value)
|
|
|
|
}
|
2020-06-30 11:53:54 +03:00
|
|
|
} else {
|
|
|
|
stmt.AddError(ErrInvalidField)
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
stmt.AddError(ErrInvalidField)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Changed check model changed or not when updating
|
|
|
|
func (stmt *Statement) Changed(fields ...string) bool {
|
2020-10-27 13:14:36 +03:00
|
|
|
modelValue := stmt.ReflectValue
|
2020-06-30 17:47:21 +03:00
|
|
|
switch modelValue.Kind() {
|
|
|
|
case reflect.Slice, reflect.Array:
|
|
|
|
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
|
|
|
|
}
|
|
|
|
|
2020-06-30 11:53:54 +03:00
|
|
|
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
|
|
|
changed := func(field *schema.Field) bool {
|
2020-06-30 17:47:21 +03:00
|
|
|
fieldValue, _ := field.ValueOf(modelValue)
|
2020-06-30 11:53:54 +03:00
|
|
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
|
|
|
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
|
|
|
if fv, ok := v[field.Name]; ok {
|
|
|
|
return !utils.AssertEqual(fv, fieldValue)
|
|
|
|
} else if fv, ok := v[field.DBName]; ok {
|
|
|
|
return !utils.AssertEqual(fv, fieldValue)
|
|
|
|
}
|
|
|
|
} else {
|
2020-10-27 13:14:36 +03:00
|
|
|
destValue := reflect.ValueOf(stmt.Dest)
|
|
|
|
for destValue.Kind() == reflect.Ptr {
|
|
|
|
destValue = destValue.Elem()
|
|
|
|
}
|
|
|
|
|
|
|
|
changedValue, zero := field.ValueOf(destValue)
|
|
|
|
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
2020-06-30 11:53:54 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(fields) == 0 {
|
|
|
|
for _, field := range stmt.Schema.FieldsByDBName {
|
|
|
|
if changed(field) {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for _, name := range fields {
|
|
|
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
|
|
|
if changed(field) {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
2021-10-09 05:42:41 +03:00
|
|
|
// nameMatcher extracts the column part of a table-qualified column reference
// such as users.name, `users`.`name` or "t1"."address2". A single optional
// non-word (quote) character may surround each identifier, and submatch 1
// holds the column name. Identifiers may contain lowercase letters, digits
// and underscores; digits are needed for names like t1 or address2, which
// previously failed to match.
var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_0-9]+?)[\W]?\.[\W]?([a-z_0-9]+?)[\W]?$`)
|
2021-10-08 12:51:27 +03:00
|
|
|
|
2020-06-30 11:53:54 +03:00
|
|
|
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
|
|
|
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
|
|
|
results := map[string]bool{}
|
|
|
|
notRestricted := false
|
|
|
|
|
|
|
|
// select columns
|
|
|
|
for _, column := range stmt.Selects {
|
2021-02-18 05:53:29 +03:00
|
|
|
if stmt.Schema == nil {
|
|
|
|
results[column] = true
|
|
|
|
} else if column == "*" {
|
2020-06-30 11:53:54 +03:00
|
|
|
notRestricted = true
|
|
|
|
for _, dbName := range stmt.Schema.DBNames {
|
|
|
|
results[dbName] = true
|
|
|
|
}
|
2021-02-18 05:53:29 +03:00
|
|
|
} else if column == clause.Associations {
|
2020-06-30 11:53:54 +03:00
|
|
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
|
|
|
results[rel.Name] = true
|
|
|
|
}
|
|
|
|
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
|
|
|
results[field.DBName] = true
|
2021-10-08 12:51:27 +03:00
|
|
|
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
|
|
|
|
results[matches[1]] = true
|
2020-06-30 11:53:54 +03:00
|
|
|
} else {
|
|
|
|
results[column] = true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// omit columns
|
|
|
|
for _, omit := range stmt.Omits {
|
2021-02-18 05:53:29 +03:00
|
|
|
if stmt.Schema == nil {
|
|
|
|
results[omit] = false
|
2021-11-08 13:49:49 +03:00
|
|
|
} else if omit == "*" {
|
|
|
|
for _, dbName := range stmt.Schema.DBNames {
|
|
|
|
results[dbName] = false
|
|
|
|
}
|
2021-02-18 05:53:29 +03:00
|
|
|
} else if omit == clause.Associations {
|
|
|
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
|
|
|
results[rel.Name] = false
|
2020-06-30 11:53:54 +03:00
|
|
|
}
|
|
|
|
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
|
|
|
results[field.DBName] = false
|
2021-10-08 12:51:27 +03:00
|
|
|
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
|
|
|
|
results[matches[1]] = false
|
2020-06-30 11:53:54 +03:00
|
|
|
} else {
|
|
|
|
results[omit] = false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if stmt.Schema != nil {
|
2020-12-06 06:06:52 +03:00
|
|
|
for _, field := range stmt.Schema.FieldsByName {
|
2020-06-30 11:53:54 +03:00
|
|
|
name := field.DBName
|
|
|
|
if name == "" {
|
|
|
|
name = field.Name
|
|
|
|
}
|
|
|
|
|
|
|
|
if requireCreate && !field.Creatable {
|
|
|
|
results[name] = false
|
|
|
|
} else if requireUpdate && !field.Updatable {
|
|
|
|
results[name] = false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return results, !notRestricted && len(stmt.Selects) > 0
|
|
|
|
}
|