gorm/statement.go

286 lines
6.7 KiB
Go
Raw Normal View History

2020-01-29 14:22:44 +03:00
package gorm
import (
"context"
"database/sql"
2020-01-29 22:03:06 +03:00
"database/sql/driver"
2020-01-29 14:22:44 +03:00
"fmt"
2020-02-23 14:41:29 +03:00
"reflect"
2020-01-29 22:03:06 +03:00
"strconv"
2020-01-29 14:22:44 +03:00
"strings"
"sync"
"github.com/jinzhu/gorm/clause"
2020-02-02 09:40:44 +03:00
"github.com/jinzhu/gorm/schema"
2020-01-29 14:22:44 +03:00
)
2020-01-30 10:14:48 +03:00
// Instance db instance: it pairs the Statement being built with its Context
// and collects the outcome (error, affected rows) recorded against it.
type Instance struct {
	// Error holds the first error passed to AddError; later non-nil errors
	// are chained onto it by AddError.
	Error error
	// RowsAffected — nothing in this file sets it; presumably filled in after
	// the statement executes. Confirm with the callers.
	RowsAffected int64
	Context context.Context
	// Statement is the statement whose SQL text and Vars ToSQL returns.
	Statement *Statement
}
2020-02-03 05:40:03 +03:00
func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
2020-02-02 14:32:27 +03:00
if len(clauses) > 0 {
instance.Statement.Build(clauses...)
}
2020-02-07 18:45:35 +03:00
return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
2020-02-02 14:32:27 +03:00
}
2020-01-30 10:14:48 +03:00
// AddError add error to instance
2020-02-03 05:40:03 +03:00
func (inst *Instance) AddError(err error) {
2020-01-30 10:14:48 +03:00
if inst.Error == nil {
inst.Error = err
2020-02-23 14:41:29 +03:00
} else if err != nil {
2020-01-30 10:14:48 +03:00
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
}
}
2020-01-29 14:22:44 +03:00
// Statement statement
type Statement struct {
2020-02-23 14:41:29 +03:00
Table string
Model interface{}
Dest interface{}
ReflectValue reflect.Value
Clauses map[string]clause.Clause
Selects []string // selected columns
Omits []string // omit columns
Settings sync.Map
DB *DB
Schema *schema.Schema
2020-01-29 14:22:44 +03:00
2020-01-30 10:14:48 +03:00
// SQL Builder
SQL strings.Builder
2020-01-29 14:22:44 +03:00
Vars []interface{}
NamedVars []sql.NamedArg
}
2020-01-30 10:14:48 +03:00
// StatementOptimizer statement optimizer interface
type StatementOptimizer interface {
2020-02-03 05:40:03 +03:00
OptimizeStatement(*Statement)
2020-01-30 10:14:48 +03:00
}
2020-01-29 14:22:44 +03:00
// Write write string
2020-02-03 05:40:03 +03:00
func (stmt *Statement) Write(sql ...string) (err error) {
2020-01-29 14:22:44 +03:00
for _, s := range sql {
_, err = stmt.SQL.WriteString(s)
}
return
}
2020-01-30 10:14:48 +03:00
// Write write string
2020-02-03 05:40:03 +03:00
func (stmt *Statement) WriteByte(c byte) (err error) {
2020-01-30 10:14:48 +03:00
return stmt.SQL.WriteByte(c)
}
2020-01-29 14:22:44 +03:00
// WriteQuoted write quoted field
2020-02-03 05:40:03 +03:00
func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
2020-01-29 14:22:44 +03:00
_, err = stmt.SQL.WriteString(stmt.Quote(field))
return
}
2020-01-30 10:14:48 +03:00
// Quote returns quoted value
2020-02-02 09:40:44 +03:00
func (stmt Statement) Quote(field interface{}) string {
var str strings.Builder
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[0])
2020-02-02 09:40:44 +03:00
switch v := field.(type) {
case clause.Table:
2020-02-07 18:45:35 +03:00
if v.Name == clause.CurrentTable {
2020-02-04 04:51:19 +03:00
str.WriteString(stmt.Table)
} else {
2020-02-07 18:45:35 +03:00
str.WriteString(v.Name)
2020-02-04 04:51:19 +03:00
}
2020-02-04 03:56:15 +03:00
2020-02-02 09:40:44 +03:00
if v.Alias != "" {
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[1])
2020-02-02 09:40:44 +03:00
str.WriteString(" AS ")
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[0])
2020-02-02 09:40:44 +03:00
str.WriteString(v.Alias)
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[1])
2020-02-02 09:40:44 +03:00
}
case clause.Column:
if v.Table != "" {
2020-02-04 03:56:15 +03:00
if v.Table == clause.CurrentTable {
str.WriteString(stmt.Table)
} else {
str.WriteString(v.Table)
}
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[1])
2020-02-02 09:40:44 +03:00
str.WriteByte('.')
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[0])
2020-02-02 09:40:44 +03:00
}
2020-02-04 03:56:15 +03:00
if v.Name == clause.PrimaryKey {
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName)
}
} else {
str.WriteString(v.Name)
}
2020-02-05 06:14:58 +03:00
2020-02-02 09:40:44 +03:00
if v.Alias != "" {
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[1])
2020-02-02 09:40:44 +03:00
str.WriteString(" AS ")
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[0])
2020-02-02 09:40:44 +03:00
str.WriteString(v.Alias)
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[1])
2020-02-02 09:40:44 +03:00
}
default:
2020-02-07 18:45:35 +03:00
str.WriteString(fmt.Sprint(field))
2020-02-02 09:40:44 +03:00
}
2020-02-05 06:14:58 +03:00
str.WriteByte(stmt.DB.quoteChars[1])
2020-02-02 09:40:44 +03:00
return str.String()
2020-01-30 10:14:48 +03:00
}
2020-01-29 14:22:44 +03:00
// Write write string
2020-02-03 05:40:03 +03:00
func (stmt *Statement) AddVar(vars ...interface{}) string {
2020-01-29 22:03:06 +03:00
var placeholders strings.Builder
for idx, v := range vars {
if idx > 0 {
placeholders.WriteByte(',')
}
2020-02-07 18:45:35 +03:00
switch v := v.(type) {
case sql.NamedArg:
if len(v.Name) > 0 {
stmt.NamedVars = append(stmt.NamedVars, v)
placeholders.WriteByte('@')
placeholders.WriteString(v.Name)
} else {
stmt.Vars = append(stmt.Vars, v.Value)
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
}
2020-02-22 15:57:29 +03:00
case clause.Column, clause.Table:
2020-02-07 18:45:35 +03:00
placeholders.WriteString(stmt.Quote(v))
2020-02-22 15:57:29 +03:00
case clause.Expr:
placeholders.WriteString(v.SQL)
stmt.Vars = append(stmt.Vars, v.Vars...)
2020-02-07 18:45:35 +03:00
case []interface{}:
if len(v) > 0 {
2020-02-13 19:09:44 +03:00
placeholders.WriteByte('(')
2020-02-07 18:45:35 +03:00
placeholders.WriteString(stmt.AddVar(v...))
2020-02-13 19:09:44 +03:00
placeholders.WriteByte(')')
2020-01-29 22:03:06 +03:00
} else {
2020-02-13 19:09:44 +03:00
placeholders.WriteString("(NULL)")
2020-01-29 22:03:06 +03:00
}
2020-02-07 18:45:35 +03:00
default:
stmt.Vars = append(stmt.Vars, v)
2020-01-29 22:03:06 +03:00
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
2020-01-29 14:22:44 +03:00
}
}
2020-01-29 22:03:06 +03:00
return placeholders.String()
2020-01-29 14:22:44 +03:00
}
// AddClause add clause
2020-02-03 05:40:03 +03:00
func (stmt *Statement) AddClause(v clause.Interface) {
2020-01-30 10:14:48 +03:00
if optimizer, ok := v.(StatementOptimizer); ok {
optimizer.OptimizeStatement(stmt)
}
2020-02-07 18:45:35 +03:00
c, ok := stmt.Clauses[v.Name()]
if !ok {
2020-01-30 10:14:48 +03:00
c.Name = v.Name()
}
2020-02-07 18:45:35 +03:00
v.MergeClause(&c)
2020-01-30 10:14:48 +03:00
stmt.Clauses[v.Name()] = c
2020-01-29 14:22:44 +03:00
}
2020-01-29 22:03:06 +03:00
2020-02-03 05:40:03 +03:00
// AddClauseIfNotExists add clause if not exists
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
2020-02-07 18:45:35 +03:00
if _, ok := stmt.Clauses[v.Name()]; !ok {
stmt.AddClause(v)
2020-02-03 05:40:03 +03:00
}
}
2020-01-30 10:14:48 +03:00
// BuildCondtion build condition
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
2020-01-29 22:03:06 +03:00
if sql, ok := query.(string); ok {
2020-02-23 14:41:29 +03:00
if i, err := strconv.Atoi(sql); err == nil {
2020-01-29 22:03:06 +03:00
query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
2020-02-07 18:45:35 +03:00
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
2020-01-29 22:03:06 +03:00
}
}
args = append([]interface{}{query}, args...)
for _, arg := range args {
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
switch v := arg.(type) {
2020-01-30 10:14:48 +03:00
case clause.Expression:
2020-01-29 22:03:06 +03:00
conditions = append(conditions, v)
case *DB:
if v.Statement == nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
2020-01-30 10:14:48 +03:00
conditions = append(conditions, cs.Expression)
2020-01-29 22:03:06 +03:00
}
}
case map[interface{}]interface{}:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
case map[string]string:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
case map[string]interface{}:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
default:
// TODO check is struct
// struct, slice -> ids
}
}
if len(conditions) == 0 {
2020-02-07 18:45:35 +03:00
conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args})
2020-01-29 22:03:06 +03:00
}
2020-01-30 10:14:48 +03:00
2020-01-29 22:03:06 +03:00
return conditions
}
2020-01-30 10:14:48 +03:00
// Build build sql with clauses names
2020-02-03 05:40:03 +03:00
func (stmt *Statement) Build(clauses ...string) {
2020-02-02 14:32:27 +03:00
var firstClauseWritten bool
2020-01-30 10:14:48 +03:00
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
2020-02-02 14:32:27 +03:00
if firstClauseWritten {
2020-01-30 10:14:48 +03:00
stmt.WriteByte(' ')
}
2020-02-02 14:32:27 +03:00
firstClauseWritten = true
2020-02-03 05:40:03 +03:00
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b.Build(c, stmt)
} else {
c.Build(stmt)
}
2020-01-30 10:14:48 +03:00
}
}
2020-02-02 14:32:27 +03:00
// TODO handle named vars
2020-01-29 22:03:06 +03:00
}
2020-02-20 18:04:03 +03:00
func (stmt *Statement) Parse(value interface{}) (err error) {
2020-02-23 14:41:29 +03:00
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
if stmt.Table == "" {
stmt.Table = stmt.Schema.Table
}
2020-02-20 18:04:03 +03:00
}
return err
}