gorm/statement.go

143 lines
3.3 KiB
Go
Raw Normal View History

2020-01-29 14:22:44 +03:00
package gorm
import (
"bytes"
"context"
"database/sql"
2020-01-29 22:03:06 +03:00
"database/sql/driver"
2020-01-29 14:22:44 +03:00
"fmt"
2020-01-29 22:03:06 +03:00
"strconv"
2020-01-29 14:22:44 +03:00
"strings"
"sync"
"github.com/jinzhu/gorm/clause"
)
// Statement statement
type Statement struct {
2020-01-29 22:03:06 +03:00
Model interface{}
2020-01-29 14:22:44 +03:00
Dest interface{}
2020-01-29 22:03:06 +03:00
Table string
Clauses map[string][]clause.Condition
2020-01-29 14:22:44 +03:00
Settings sync.Map
Context context.Context
DB *DB
StatementBuilder
}
// StatementBuilder accumulates the SQL text of a statement together with
// the bind values that go with it.
type StatementBuilder struct {
	// SQL is the statement text written so far.
	SQL bytes.Buffer
	// Vars holds unnamed bind values (not written by the code in this
	// file — NOTE(review): confirm who populates it).
	Vars []interface{}
	// NamedVars holds named bind values, e.g. those registered by AddVar.
	NamedVars []sql.NamedArg
}
// Write appends each string in strs, in order, to the statement's SQL
// buffer. It stops at and returns the first write error.
//
// The receiver must be a pointer: the SQL buffer is stored by value inside
// Statement, so writing through a value receiver would modify a copy that
// is discarded when the method returns.
func (stmt *Statement) Write(strs ...string) error {
	for _, s := range strs {
		if _, err := stmt.SQL.WriteString(s); err != nil {
			return err
		}
	}
	return nil
}
// WriteQuoted appends the quoted form of field, as produced by Quote, to the
// statement's SQL buffer and returns any write error.
//
// The receiver must be a pointer: the SQL buffer is stored by value inside
// Statement, so a value receiver would write into a discarded copy.
func (stmt *Statement) WriteQuoted(field interface{}) error {
	_, err := stmt.SQL.WriteString(stmt.Quote(field))
	return err
}
// Write write string
func (stmt Statement) AddVar(vars ...interface{}) string {
2020-01-29 22:03:06 +03:00
var placeholders strings.Builder
for idx, v := range vars {
if idx > 0 {
placeholders.WriteByte(',')
}
2020-01-29 14:22:44 +03:00
if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 {
stmt.NamedVars = append(stmt.NamedVars, namedArg)
2020-01-29 22:03:06 +03:00
placeholders.WriteByte('@')
placeholders.WriteString(namedArg.Name)
} else if arrs, ok := v.([]interface{}); ok {
placeholders.WriteByte('(')
if len(arrs) > 0 {
placeholders.WriteString(stmt.AddVar(arrs...))
} else {
placeholders.WriteString("NULL")
}
placeholders.WriteByte(')')
2020-01-29 14:22:44 +03:00
} else {
2020-01-29 22:03:06 +03:00
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
2020-01-29 14:22:44 +03:00
}
}
2020-01-29 22:03:06 +03:00
return placeholders.String()
2020-01-29 14:22:44 +03:00
}
// Quote returns the textual form of field for embedding in SQL, as produced
// by fmt.Sprint. No dialect-specific quoting or escaping is applied yet.
//
// NOTE(review): the result is not escaped; do not pass untrusted input here
// until real identifier quoting is implemented.
//
// A pointer receiver is used because Statement contains a sync.Map and must
// not be copied.
func (stmt *Statement) Quote(field interface{}) string {
	return fmt.Sprint(field)
}
// AddClause appends v to the conditions stored in Clauses under v.Name().
// Clauses is created on first use, so a zero-value Statement no longer
// panics with an assignment to a nil map.
func (stmt *Statement) AddClause(v clause.Interface) {
	if stmt.Clauses == nil {
		stmt.Clauses = map[string][]clause.Condition{}
	}
	name := v.Name()
	stmt.Clauses[name] = append(stmt.Clauses[name], v)
}
2020-01-29 22:03:06 +03:00
// BuildCondtions build conditions
func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) {
if sql, ok := query.(string); ok {
if i, err := strconv.Atoi(sql); err != nil {
query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Condition{clause.Raw{SQL: sql, Values: args}}
}
}
args = append([]interface{}{query}, args...)
for _, arg := range args {
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
switch v := arg.(type) {
case clause.Builder:
conditions = append(conditions, v)
case *DB:
if v.Statement == nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
conditions = append(conditions, cs...)
}
}
case map[interface{}]interface{}:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
case map[string]string:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
case map[string]interface{}:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
default:
// TODO check is struct
// struct, slice -> ids
}
}
if len(conditions) == 0 {
conditions = append(conditions, clause.ID{Value: args})
}
return conditions
}
// AddError is intended to record an error raised while building the
// statement. It currently does nothing, so any err passed here is discarded.
//
// NOTE(review): store err (e.g. on the statement or its DB) before callers
// rely on it being reported.
//
// A pointer receiver is used because Statement contains a sync.Map and must
// not be copied.
func (stmt *Statement) AddError(err error) {
}