gorm/clause/expression.go

376 lines
8.0 KiB
Go
Raw Normal View History

2020-01-30 10:14:48 +03:00
package clause
2020-05-30 16:05:27 +03:00
import (
2020-07-10 07:28:24 +03:00
"database/sql"
2020-05-30 16:05:27 +03:00
"database/sql/driver"
2020-09-17 16:52:41 +03:00
"go/ast"
2020-05-30 16:05:27 +03:00
"reflect"
)
2020-05-30 13:50:20 +03:00
2020-01-30 10:14:48 +03:00
type (
	// Expression is a SQL fragment that knows how to render itself into a
	// Builder.
	Expression interface {
		Build(builder Builder)
	}

	// NegationExpressionBuilder is implemented by expressions that can render
	// their own negated form directly (for example "<>" instead of "=").
	NegationExpressionBuilder interface {
		NegationBuild(builder Builder)
	}
)
// Expr raw expression
type Expr struct {
2020-11-03 05:30:05 +03:00
SQL string
Vars []interface{}
WithoutParentheses bool
2020-01-30 10:14:48 +03:00
}
// Build build raw expression
func (expr Expr) Build(builder Builder) {
2020-05-30 13:50:20 +03:00
var (
afterParenthesis bool
idx int
)
2020-03-09 12:07:00 +03:00
for _, v := range []byte(expr.SQL) {
if v == '?' && len(expr.Vars) > idx {
2020-11-03 05:30:05 +03:00
if afterParenthesis || expr.WithoutParentheses {
2020-05-30 16:05:27 +03:00
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
2020-09-08 16:28:04 +03:00
if rv.Len() == 0 {
builder.AddVar(builder, nil)
2020-09-09 05:31:48 +03:00
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
2020-05-30 16:05:27 +03:00
}
2020-05-30 13:50:20 +03:00
}
2020-05-30 16:05:27 +03:00
default:
builder.AddVar(builder, expr.Vars[idx])
2020-05-30 13:50:20 +03:00
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
2020-03-09 12:07:00 +03:00
idx++
} else {
2020-05-30 13:50:20 +03:00
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
2020-03-09 12:07:00 +03:00
builder.WriteByte(v)
}
2020-02-07 18:45:35 +03:00
}
}
2020-07-10 07:28:24 +03:00
// NamedExpr raw expression for named expr
type NamedExpr struct {
SQL string
Vars []interface{}
}
// Build build raw expression
func (expr NamedExpr) Build(builder Builder) {
var (
2021-03-04 13:40:47 +03:00
idx int
inName bool
afterParenthesis bool
namedMap = make(map[string]interface{}, len(expr.Vars))
2020-07-10 07:28:24 +03:00
)
for _, v := range expr.Vars {
switch value := v.(type) {
case sql.NamedArg:
namedMap[value.Name] = value.Value
case map[string]interface{}:
for k, v := range value {
namedMap[k] = v
}
2020-09-17 16:52:41 +03:00
default:
var appendFieldsToMap func(reflect.Value)
appendFieldsToMap = func(reflectValue reflect.Value) {
reflectValue = reflect.Indirect(reflectValue)
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
if fieldStruct.Anonymous {
appendFieldsToMap(reflectValue.Field(i))
}
}
2020-09-17 16:52:41 +03:00
}
}
}
appendFieldsToMap(reflect.ValueOf(value))
2020-07-10 07:28:24 +03:00
}
}
name := make([]byte, 0, 10)
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
2020-07-10 07:28:24 +03:00
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
inName = false
}
2021-03-04 13:40:47 +03:00
afterParenthesis = false
2020-07-10 07:28:24 +03:00
builder.WriteByte(v)
} else if v == '?' && len(expr.Vars) > idx {
2021-03-04 13:40:47 +03:00
if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
2020-07-10 07:28:24 +03:00
idx++
} else if inName {
name = append(name, v)
} else {
2021-03-04 13:40:47 +03:00
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
2020-07-10 07:28:24 +03:00
builder.WriteByte(v)
}
}
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
2020-07-10 07:28:24 +03:00
}
}
2020-02-07 18:45:35 +03:00
// IN Whether a value is within a set of values
type IN struct {
Column interface{}
Values []interface{}
}
func (in IN) Build(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
2020-03-09 12:07:00 +03:00
builder.WriteString(" IN (NULL)")
2020-02-07 18:45:35 +03:00
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" = ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
2020-02-07 18:45:35 +03:00
default:
2020-03-09 12:07:00 +03:00
builder.WriteString(" IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
2020-02-07 18:45:35 +03:00
}
}
func (in IN) NegationBuild(builder Builder) {
switch len(in.Values) {
case 0:
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteQuoted(in.Column)
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
2020-02-07 18:45:35 +03:00
default:
2020-05-24 17:52:16 +03:00
builder.WriteQuoted(in.Column)
2020-03-09 12:07:00 +03:00
builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
2020-02-07 18:45:35 +03:00
}
}
// Eq equal to for where
type Eq struct {
Column interface{}
Value interface{}
}
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
2021-05-31 10:25:38 +03:00
switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" IN (")
rv := reflect.ValueOf(eq.Value)
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
2021-05-31 12:21:27 +03:00
builder.AddVar(builder, rv.Index(i).Interface())
2021-05-31 10:25:38 +03:00
}
builder.WriteByte(')')
default:
if eqNil(eq.Value) {
builder.WriteString(" IS NULL")
} else {
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
}
2020-02-07 18:45:35 +03:00
}
}
func (eq Eq) NegationBuild(builder Builder) {
Neq(eq).Build(builder)
2020-02-07 18:45:35 +03:00
}
// Neq not equal to for where
type Neq Eq
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
2021-05-31 10:25:38 +03:00
switch neq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" NOT IN (")
rv := reflect.ValueOf(neq.Value)
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
2021-05-31 12:21:27 +03:00
builder.AddVar(builder, rv.Index(i).Interface())
2021-05-31 10:25:38 +03:00
}
builder.WriteByte(')')
default:
if eqNil(neq.Value) {
builder.WriteString(" IS NOT NULL")
} else {
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
}
2020-02-07 18:45:35 +03:00
}
}
func (neq Neq) NegationBuild(builder Builder) {
Eq(neq).Build(builder)
2020-02-07 18:45:35 +03:00
}
// Gt greater than for where
type Gt Eq
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
2020-03-09 12:07:00 +03:00
builder.WriteString(" > ")
builder.AddVar(builder, gt.Value)
2020-02-07 18:45:35 +03:00
}
func (gt Gt) NegationBuild(builder Builder) {
Lte(gt).Build(builder)
2020-02-07 18:45:35 +03:00
}
// Gte greater than or equal to for where
type Gte Eq
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
2020-03-09 12:07:00 +03:00
builder.WriteString(" >= ")
builder.AddVar(builder, gte.Value)
2020-02-07 18:45:35 +03:00
}
func (gte Gte) NegationBuild(builder Builder) {
Lt(gte).Build(builder)
2020-02-07 18:45:35 +03:00
}
// Lt less than for where
type Lt Eq
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
2020-03-09 12:07:00 +03:00
builder.WriteString(" < ")
builder.AddVar(builder, lt.Value)
2020-02-07 18:45:35 +03:00
}
func (lt Lt) NegationBuild(builder Builder) {
Gte(lt).Build(builder)
2020-02-07 18:45:35 +03:00
}
// Lte less than or equal to for where
type Lte Eq
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
2020-03-09 12:07:00 +03:00
builder.WriteString(" <= ")
builder.AddVar(builder, lte.Value)
2020-02-07 18:45:35 +03:00
}
func (lte Lte) NegationBuild(builder Builder) {
Gt(lte).Build(builder)
2020-02-07 18:45:35 +03:00
}
// Like whether string matches regular expression
type Like Eq
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
2020-03-09 12:07:00 +03:00
builder.WriteString(" LIKE ")
builder.AddVar(builder, like.Value)
2020-02-07 18:45:35 +03:00
}
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
2020-03-09 12:07:00 +03:00
builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value)
2020-02-07 18:45:35 +03:00
}
// eqNil reports whether value should be rendered as SQL NULL: an untyped
// nil, a nil pointer, or a driver.Valuer whose Value method yields nil.
//
// A nil pointer whose type satisfies driver.Valuer (e.g. a nil
// *sql.NullString) is reported as NULL without calling Value. Calling it
// would dereference the nil receiver and panic. If Value returns an error,
// the value is reported as non-nil, so callers such as Eq.Build bind it as
// an ordinary value instead of silently emitting IS NULL.
func eqNil(value interface{}) bool {
	if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) {
		v, err := valuer.Value()
		if err != nil {
			return false
		}
		value = v
	}

	return value == nil || eqNilReflect(value)
}

// eqNilReflect reports whether value holds a nil pointer.
func eqNilReflect(value interface{}) bool {
	reflectValue := reflect.ValueOf(value)
	return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
}