Implement build conditions

This commit is contained in:
Jinzhu 2020-01-30 03:03:06 +08:00
parent b9cce2be6a
commit 85bfd175c6
5 changed files with 154 additions and 15 deletions

View File

@ -7,12 +7,14 @@ package gorm
// db.Model(&user).Update("name", "hello") // db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) { func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Model = value
return return
} }
// Table specify the table you would like to run db operations // Table specify the table you would like to run db operations
func (db *DB) Table(name string) (tx *DB) { func (db *DB) Table(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Table = name
return return
} }

View File

@ -11,6 +11,11 @@ type BuilderInterface interface {
// Interface clause interface // Interface clause interface
type Interface interface { type Interface interface {
Name() string Name() string
Builder
}
// Builder condition builder
type Builder interface {
Build(builder BuilderInterface) Build(builder BuilderInterface)
} }

View File

@ -2,7 +2,8 @@ package clause
import "strings" import "strings"
type AddConditions []Interface type Condition Builder
type AddConditions []Condition
func (cs AddConditions) Build(builder BuilderInterface) { func (cs AddConditions) Build(builder BuilderInterface) {
for idx, c := range cs { for idx, c := range cs {
@ -13,7 +14,7 @@ func (cs AddConditions) Build(builder BuilderInterface) {
} }
} }
type ORConditions []Interface type ORConditions []Condition
func (cs ORConditions) Build(builder BuilderInterface) { func (cs ORConditions) Build(builder BuilderInterface) {
for idx, c := range cs { for idx, c := range cs {
@ -24,7 +25,7 @@ func (cs ORConditions) Build(builder BuilderInterface) {
} }
} }
type NotConditions []Interface type NotConditions []Condition
func (cs NotConditions) Build(builder BuilderInterface) { func (cs NotConditions) Build(builder BuilderInterface) {
for idx, c := range cs { for idx, c := range cs {
@ -64,16 +65,22 @@ type IN struct {
func (in IN) Build(builder BuilderInterface) { func (in IN) Build(builder BuilderInterface) {
builder.WriteQuoted(in.Column) builder.WriteQuoted(in.Column)
if len(in.Values) == 0 { switch len(in.Values) {
case 0:
builder.Write(" IN (NULL)") builder.Write(" IN (NULL)")
} else { case 1:
builder.Write(" = ", builder.AddVar(in.Values...))
default:
builder.Write(" IN (", builder.AddVar(in.Values...), ")") builder.Write(" IN (", builder.AddVar(in.Values...), ")")
} }
} }
func (in IN) NegationBuild(builder BuilderInterface) { func (in IN) NegationBuild(builder BuilderInterface) {
if len(in.Values) != 0 { switch len(in.Values) {
builder.WriteQuoted(in.Column) case 0:
case 1:
builder.Write(" <> ", builder.AddVar(in.Values...))
default:
builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")")
} }
} }
@ -193,3 +200,48 @@ func (like Like) NegationBuild(builder BuilderInterface) {
builder.WriteQuoted(like.Column) builder.WriteQuoted(like.Column)
builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
} }
// Map
type Map map[interface{}]interface{}
func (m Map) Build(builder BuilderInterface) {
// TODO
}
func (m Map) NegationBuild(builder BuilderInterface) {
// TODO
}
// Attrs
type Attrs struct {
Value interface{}
Select []string
Omit []string
}
func (attrs Attrs) Build(builder BuilderInterface) {
// TODO
// builder.WriteQuoted(like.Column)
// builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (attrs Attrs) NegationBuild(builder BuilderInterface) {
// TODO
}
// ID
type ID struct {
Value []interface{}
}
func (id ID) Build(builder BuilderInterface) {
if len(id.Value) == 1 {
}
// TODO
// builder.WriteQuoted(like.Column)
// builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (id ID) NegationBuild(builder BuilderInterface) {
// TODO
}

View File

@ -93,7 +93,7 @@ func (db *DB) getInstance() *DB {
Dialector: db.Dialector, Dialector: db.Dialector,
Context: context.Background(), Context: context.Background(),
Result: Result{ Result: Result{
Statement: &Statement{DB: db, Clauses: map[string][]clause.Interface{}}, Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}},
}, },
} }
} }
@ -106,3 +106,9 @@ func (db *DB) Debug() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
// Session start session mode
func (db *DB) Session() (tx *DB) {
tx = db.getInstance()
return
}

View File

@ -4,7 +4,9 @@ import (
"bytes" "bytes"
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"fmt" "fmt"
"strconv"
"strings" "strings"
"sync" "sync"
@ -13,9 +15,10 @@ import (
// Statement statement // Statement statement
type Statement struct { type Statement struct {
Model interface{}
Dest interface{} Dest interface{}
Table interface{} Table string
Clauses map[string][]clause.Interface Clauses map[string][]clause.Condition
Settings sync.Map Settings sync.Map
Context context.Context Context context.Context
DB *DB DB *DB
@ -45,16 +48,29 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) {
// Write write string // Write write string
func (stmt Statement) AddVar(vars ...interface{}) string { func (stmt Statement) AddVar(vars ...interface{}) string {
var placeholders []string var placeholders strings.Builder
for _, v := range vars { for idx, v := range vars {
if idx > 0 {
placeholders.WriteByte(',')
}
if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 {
stmt.NamedVars = append(stmt.NamedVars, namedArg) stmt.NamedVars = append(stmt.NamedVars, namedArg)
placeholders = append(placeholders, "@"+namedArg.Name) placeholders.WriteByte('@')
placeholders.WriteString(namedArg.Name)
} else if arrs, ok := v.([]interface{}); ok {
placeholders.WriteByte('(')
if len(arrs) > 0 {
placeholders.WriteString(stmt.AddVar(arrs...))
} else { } else {
placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v)) placeholders.WriteString("NULL")
}
placeholders.WriteByte(')')
} else {
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
} }
} }
return strings.Join(placeholders, ",") return placeholders.String()
} }
// Quote returns quoted value // Quote returns quoted value
@ -66,3 +82,61 @@ func (stmt Statement) Quote(field interface{}) (str string) {
func (s Statement) AddClause(clause clause.Interface) { func (s Statement) AddClause(clause clause.Interface) {
s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause) s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause)
} }
// BuildCondtions build conditions
func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) {
if sql, ok := query.(string); ok {
if i, err := strconv.Atoi(sql); err != nil {
query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Condition{clause.Raw{SQL: sql, Values: args}}
}
}
args = append([]interface{}{query}, args...)
for _, arg := range args {
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
switch v := arg.(type) {
case clause.Builder:
conditions = append(conditions, v)
case *DB:
if v.Statement == nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
conditions = append(conditions, cs...)
}
}
case map[interface{}]interface{}:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
case map[string]string:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
case map[string]interface{}:
var clauseMap = clause.Map{}
for i, j := range v {
clauseMap[i] = j
}
conditions = append(conditions, clauseMap)
default:
// TODO check is struct
// struct, slice -> ids
}
}
if len(conditions) == 0 {
conditions = append(conditions, clause.ID{Value: args})
}
return conditions
}
func (s Statement) AddError(err error) {
}