gorm/statement.go

374 lines
9.5 KiB
Go
Raw Normal View History

2020-01-29 14:22:44 +03:00
package gorm
import (
"context"
"database/sql"
2020-01-29 22:03:06 +03:00
"database/sql/driver"
2020-01-29 14:22:44 +03:00
"fmt"
2020-02-23 14:41:29 +03:00
"reflect"
2020-01-29 22:03:06 +03:00
"strconv"
2020-01-29 14:22:44 +03:00
"strings"
"sync"
"github.com/jinzhu/gorm/clause"
2020-02-02 09:40:44 +03:00
"github.com/jinzhu/gorm/schema"
2020-01-29 14:22:44 +03:00
)
// Statement statement
type Statement struct {
2020-03-09 15:37:01 +03:00
*DB
2020-03-03 09:18:12 +03:00
Table string
Model interface{}
2020-05-29 02:35:45 +03:00
Unscoped bool
2020-03-03 09:18:12 +03:00
Dest interface{}
ReflectValue reflect.Value
Clauses map[string]clause.Clause
Selects []string // selected columns
Omits []string // omit columns
2020-04-15 04:14:24 +03:00
Joins map[string][]interface{}
Preloads map[string][]interface{}
2020-03-03 09:18:12 +03:00
Settings sync.Map
2020-03-09 08:10:48 +03:00
ConnPool ConnPool
2020-03-03 09:18:12 +03:00
Schema *schema.Schema
2020-03-09 08:10:48 +03:00
Context context.Context
2020-03-03 09:18:12 +03:00
RaiseErrorOnNotFound bool
2020-05-30 12:34:22 +03:00
DisableUpdateTime bool
2020-03-09 10:32:55 +03:00
SQL strings.Builder
Vars []interface{}
NamedVars []sql.NamedArg
2020-05-28 08:12:56 +03:00
attrs []interface{}
assigns []interface{}
2020-01-29 14:22:44 +03:00
}
2020-03-12 03:39:42 +03:00
// StatementModifier statement modifier interface
type StatementModifier interface {
ModifyStatement(*Statement)
2020-01-30 10:14:48 +03:00
}
2020-01-29 14:22:44 +03:00
// Write write string
2020-03-09 12:07:00 +03:00
func (stmt *Statement) WriteString(str string) (int, error) {
return stmt.SQL.WriteString(str)
2020-01-29 14:22:44 +03:00
}
2020-01-30 10:14:48 +03:00
// Write write string
2020-03-09 12:07:00 +03:00
func (stmt *Statement) WriteByte(c byte) error {
2020-01-30 10:14:48 +03:00
return stmt.SQL.WriteByte(c)
}
2020-03-08 18:30:16 +03:00
// WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(value interface{}) error {
stmt.QuoteTo(&stmt.SQL, value)
return nil
2020-01-29 14:22:44 +03:00
}
2020-03-08 18:30:16 +03:00
// QuoteTo write quoted value to writer
2020-03-09 12:07:00 +03:00
func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
2020-02-02 09:40:44 +03:00
switch v := field.(type) {
case clause.Table:
2020-02-07 18:45:35 +03:00
if v.Name == clause.CurrentTable {
2020-03-08 18:30:16 +03:00
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
2020-05-24 06:32:59 +03:00
} else if v.Raw {
writer.WriteString(v.Name)
2020-02-04 04:51:19 +03:00
} else {
2020-03-08 18:30:16 +03:00
stmt.DB.Dialector.QuoteTo(writer, v.Name)
2020-02-04 04:51:19 +03:00
}
2020-02-04 03:56:15 +03:00
2020-02-02 09:40:44 +03:00
if v.Alias != "" {
2020-03-08 18:30:16 +03:00
writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
2020-02-02 09:40:44 +03:00
}
case clause.Column:
if v.Table != "" {
2020-02-04 03:56:15 +03:00
if v.Table == clause.CurrentTable {
2020-03-08 18:30:16 +03:00
stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
2020-02-04 03:56:15 +03:00
} else {
2020-03-08 18:30:16 +03:00
stmt.DB.Dialector.QuoteTo(writer, v.Table)
2020-02-04 03:56:15 +03:00
}
2020-03-08 18:30:16 +03:00
writer.WriteByte('.')
2020-02-02 09:40:44 +03:00
}
2020-02-04 03:56:15 +03:00
if v.Name == clause.PrimaryKey {
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
2020-03-08 18:30:16 +03:00
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
2020-02-04 03:56:15 +03:00
}
2020-05-24 06:32:59 +03:00
} else if v.Raw {
writer.WriteString(v.Name)
2020-02-04 03:56:15 +03:00
} else {
2020-03-08 18:30:16 +03:00
stmt.DB.Dialector.QuoteTo(writer, v.Name)
2020-02-04 03:56:15 +03:00
}
2020-02-05 06:14:58 +03:00
2020-02-02 09:40:44 +03:00
if v.Alias != "" {
2020-03-08 18:30:16 +03:00
writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
2020-02-02 09:40:44 +03:00
}
2020-03-12 08:05:22 +03:00
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
2020-05-14 07:19:12 +03:00
case []string:
writer.WriteByte('(')
for idx, d := range v {
if idx != 0 {
writer.WriteString(",")
}
stmt.DB.Dialector.QuoteTo(writer, d)
}
writer.WriteByte(')')
2020-02-02 09:40:44 +03:00
default:
2020-03-08 18:30:16 +03:00
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
2020-02-02 09:40:44 +03:00
}
2020-03-08 18:30:16 +03:00
}
2020-02-02 09:40:44 +03:00
2020-03-08 18:30:16 +03:00
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string {
var builder strings.Builder
stmt.QuoteTo(&builder, field)
return builder.String()
2020-01-30 10:14:48 +03:00
}
2020-01-29 14:22:44 +03:00
// Write write string
2020-03-09 12:07:00 +03:00
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
2020-01-29 22:03:06 +03:00
for idx, v := range vars {
if idx > 0 {
2020-03-09 12:07:00 +03:00
writer.WriteByte(',')
2020-01-29 22:03:06 +03:00
}
2020-02-07 18:45:35 +03:00
switch v := v.(type) {
case sql.NamedArg:
if len(v.Name) > 0 {
stmt.NamedVars = append(stmt.NamedVars, v)
2020-03-09 12:07:00 +03:00
writer.WriteByte('@')
writer.WriteString(v.Name)
2020-02-07 18:45:35 +03:00
} else {
stmt.Vars = append(stmt.Vars, v.Value)
2020-03-09 12:59:54 +03:00
stmt.DB.Dialector.BindVarTo(writer, stmt, v.Value)
2020-02-07 18:45:35 +03:00
}
2020-02-22 15:57:29 +03:00
case clause.Column, clause.Table:
2020-03-09 12:07:00 +03:00
stmt.QuoteTo(writer, v)
2020-02-22 15:57:29 +03:00
case clause.Expr:
2020-03-09 12:07:00 +03:00
writer.WriteString(v.SQL)
2020-02-22 15:57:29 +03:00
stmt.Vars = append(stmt.Vars, v.Vars...)
2020-05-30 16:05:27 +03:00
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
2020-02-07 18:45:35 +03:00
case []interface{}:
if len(v) > 0 {
2020-03-09 12:07:00 +03:00
writer.WriteByte('(')
stmt.AddVar(writer, v...)
writer.WriteByte(')')
2020-01-29 22:03:06 +03:00
} else {
2020-03-09 12:07:00 +03:00
writer.WriteString("(NULL)")
2020-01-29 22:03:06 +03:00
}
2020-02-07 18:45:35 +03:00
default:
2020-05-23 11:08:50 +03:00
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
writer.WriteString("(NULL)")
} else {
writer.WriteByte('(')
for i := 0; i < rv.Len(); i++ {
if i > 0 {
writer.WriteByte(',')
}
stmt.AddVar(writer, rv.Index(i).Interface())
}
writer.WriteByte(')')
}
default:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
}
2020-01-29 14:22:44 +03:00
}
}
}
// AddClause add clause
2020-02-03 05:40:03 +03:00
func (stmt *Statement) AddClause(v clause.Interface) {
2020-03-12 03:39:42 +03:00
if optimizer, ok := v.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
2020-01-30 10:14:48 +03:00
}
2020-02-07 18:45:35 +03:00
c, ok := stmt.Clauses[v.Name()]
if !ok {
2020-01-30 10:14:48 +03:00
c.Name = v.Name()
}
2020-02-07 18:45:35 +03:00
v.MergeClause(&c)
2020-01-30 10:14:48 +03:00
stmt.Clauses[v.Name()] = c
2020-01-29 14:22:44 +03:00
}
2020-01-29 22:03:06 +03:00
2020-02-03 05:40:03 +03:00
// AddClauseIfNotExists add clause if not exists
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
2020-05-31 15:21:52 +03:00
if c, ok := stmt.Clauses[v.Name()]; !ok && c.Expression == nil {
2020-02-07 18:45:35 +03:00
stmt.AddClause(v)
2020-02-03 05:40:03 +03:00
}
}
2020-01-30 10:14:48 +03:00
// BuildCondtion build condition
2020-05-28 08:12:56 +03:00
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conds []clause.Expression) {
2020-01-29 22:03:06 +03:00
if sql, ok := query.(string); ok {
2020-06-01 05:02:20 +03:00
// if it is a number, then treats it as primary key
if _, err := strconv.Atoi(sql); err != nil {
if sql == "" && len(args) == 0 {
return
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
} else if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}}
}
2020-01-29 22:03:06 +03:00
}
}
args = append([]interface{}{query}, args...)
for _, arg := range args {
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
switch v := arg.(type) {
2020-01-30 10:14:48 +03:00
case clause.Expression:
2020-05-28 08:12:56 +03:00
conds = append(conds, v)
2020-01-29 22:03:06 +03:00
case *DB:
if v.Statement == nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
2020-05-28 08:12:56 +03:00
conds = append(conds, cs.Expression)
2020-01-29 22:03:06 +03:00
}
}
case map[interface{}]interface{}:
for i, j := range v {
2020-05-28 08:12:56 +03:00
conds = append(conds, clause.Eq{Column: i, Value: j})
2020-01-29 22:03:06 +03:00
}
case map[string]string:
for i, j := range v {
2020-05-28 08:12:56 +03:00
conds = append(conds, clause.Eq{Column: i, Value: j})
2020-01-29 22:03:06 +03:00
}
case map[string]interface{}:
for i, j := range v {
2020-05-28 08:12:56 +03:00
conds = append(conds, clause.Eq{Column: i, Value: j})
2020-01-29 22:03:06 +03:00
}
default:
2020-05-28 08:12:56 +03:00
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
switch reflectValue.Kind() {
case reflect.Struct:
2020-05-28 11:10:10 +03:00
for _, field := range s.Fields {
2020-05-28 08:12:56 +03:00
if v, isZero := field.ValueOf(reflectValue); !isZero {
2020-05-28 11:10:10 +03:00
if field.DBName == "" {
2020-05-31 15:21:52 +03:00
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
2020-05-28 11:10:10 +03:00
} else {
2020-05-31 15:21:52 +03:00
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
2020-05-28 11:10:10 +03:00
}
2020-05-28 08:12:56 +03:00
}
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
2020-05-28 11:10:10 +03:00
for _, field := range s.Fields {
2020-05-28 08:12:56 +03:00
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
2020-05-28 11:10:10 +03:00
if field.DBName == "" {
2020-05-31 15:21:52 +03:00
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v})
2020-05-28 11:10:10 +03:00
} else {
2020-05-31 15:21:52 +03:00
conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v})
2020-05-28 11:10:10 +03:00
}
2020-05-28 08:12:56 +03:00
}
}
}
}
2020-06-01 05:02:20 +03:00
} else if len(conds) == 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
2020-05-28 08:12:56 +03:00
}
2020-01-29 22:03:06 +03:00
}
}
2020-05-28 08:12:56 +03:00
return
2020-01-29 22:03:06 +03:00
}
2020-01-30 10:14:48 +03:00
// Build build sql with clauses names
2020-02-03 05:40:03 +03:00
func (stmt *Statement) Build(clauses ...string) {
2020-02-02 14:32:27 +03:00
var firstClauseWritten bool
2020-01-30 10:14:48 +03:00
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
2020-02-02 14:32:27 +03:00
if firstClauseWritten {
2020-01-30 10:14:48 +03:00
stmt.WriteByte(' ')
}
2020-02-02 14:32:27 +03:00
firstClauseWritten = true
2020-02-03 05:40:03 +03:00
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
2020-05-29 17:34:35 +03:00
b(c, stmt)
2020-02-03 05:40:03 +03:00
} else {
c.Build(stmt)
}
2020-01-30 10:14:48 +03:00
}
}
2020-02-02 14:32:27 +03:00
// TODO handle named vars
2020-01-29 22:03:06 +03:00
}
2020-02-20 18:04:03 +03:00
func (stmt *Statement) Parse(value interface{}) (err error) {
2020-02-24 03:51:35 +03:00
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
stmt.Table = stmt.Schema.Table
2020-02-20 18:04:03 +03:00
}
return err
}
2020-03-09 10:32:55 +03:00
2020-05-24 12:24:23 +03:00
// clone returns a new Statement that carries over this statement's
// configuration: table, model, destination, flags, selected and omitted
// columns, connection pool, schema and context. The Clauses, Joins and
// Preloads maps are copied entry by entry and Settings is copied key by key,
// so adding entries to the clone does not affect the original. Selects and
// Omits share their backing arrays with the original.
//
// The generated SQL, Vars, NamedVars, attrs and assigns are not copied.
func (stmt *Statement) clone() *Statement {
	newStmt := &Statement{
		DB:                   stmt.DB,
		Table:                stmt.Table,
		Model:                stmt.Model,
		Unscoped:             stmt.Unscoped,
		Dest:                 stmt.Dest,
		ReflectValue:         stmt.ReflectValue,
		Clauses:              make(map[string]clause.Clause, len(stmt.Clauses)),
		Selects:              stmt.Selects,
		Omits:                stmt.Omits,
		Joins:                make(map[string][]interface{}, len(stmt.Joins)),
		Preloads:             make(map[string][]interface{}, len(stmt.Preloads)),
		ConnPool:             stmt.ConnPool,
		Schema:               stmt.Schema,
		Context:              stmt.Context,
		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
		DisableUpdateTime:    stmt.DisableUpdateTime,
	}

	for k, c := range stmt.Clauses {
		newStmt.Clauses[k] = c
	}

	for k, p := range stmt.Preloads {
		newStmt.Preloads[k] = p
	}

	for k, j := range stmt.Joins {
		newStmt.Joins[k] = j
	}

	// sync.Map must not be copied by value, so transfer the entries instead.
	stmt.Settings.Range(func(k, v interface{}) bool {
		newStmt.Settings.Store(k, v)
		return true
	})

	return newStmt
}
2020-03-09 10:32:55 +03:00
func (stmt *Statement) reinit() {
2020-05-24 06:32:59 +03:00
// stmt.Table = ""
// stmt.Model = nil
// stmt.Selects = nil
// stmt.Omits = nil
// stmt.ConnPool = stmt.DB.Config.ConnPool
// stmt.Context = context.Background()
// stmt.RaiseErrorOnNotFound = false
// for k := range stmt.Clauses {
// delete(stmt.Clauses, k)
// }
// for k := range stmt.Joins {
// delete(stmt.Joins, k)
// }
// for k := range stmt.Preloads {
// delete(stmt.Preloads, k)
// }
// stmt.Settings.Range(func(k, _ interface{}) bool {
// stmt.Settings.Delete(k)
// return true
// })
2020-03-09 10:32:55 +03:00
2020-05-28 08:12:56 +03:00
// stmt.Schema = nil
2020-03-09 10:32:55 +03:00
stmt.SQL.Reset()
stmt.Vars = nil
stmt.NamedVars = nil
}