From 85bfd175c6bf18cecac0e9c7403b3956a6c4ed54 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jan 2020 03:03:06 +0800 Subject: [PATCH] Implement build conditions --- chainable_api.go | 2 ++ clause/clause.go | 5 +++ clause/operators.go | 66 ++++++++++++++++++++++++++++++---- gorm.go | 8 ++++- statement.go | 88 +++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 154 insertions(+), 15 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index d8f2116c..75e0fa2a 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -7,12 +7,14 @@ package gorm // db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.Model = value return } // Table specify the table you would like to run db operations func (db *DB) Table(name string) (tx *DB) { tx = db.getInstance() + tx.Statement.Table = name return } diff --git a/clause/clause.go b/clause/clause.go index 4495a9d5..1afb120e 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -11,6 +11,11 @@ type BuilderInterface interface { // Interface clause interface type Interface interface { Name() string + Builder +} + +// Builder condition builder +type Builder interface { Build(builder BuilderInterface) } diff --git a/clause/operators.go b/clause/operators.go index 331abea7..a6bdb4aa 100644 --- a/clause/operators.go +++ b/clause/operators.go @@ -2,7 +2,8 @@ package clause import "strings" -type AddConditions []Interface +type Condition Builder +type AddConditions []Condition func (cs AddConditions) Build(builder BuilderInterface) { for idx, c := range cs { @@ -13,7 +14,7 @@ func (cs AddConditions) Build(builder BuilderInterface) { } } -type ORConditions []Interface +type ORConditions []Condition func (cs ORConditions) Build(builder BuilderInterface) { for idx, c := range cs { @@ -24,7 +25,7 @@ func (cs ORConditions) Build(builder BuilderInterface) { } } -type NotConditions []Interface +type NotConditions []Condition func (cs NotConditions) Build(builder BuilderInterface) { for idx, c := range cs { @@ -64,16 +65,22 @@ type IN struct { func (in IN) Build(builder BuilderInterface) { builder.WriteQuoted(in.Column) - if len(in.Values) == 0 { + switch len(in.Values) { + case 0: builder.Write(" IN (NULL)") - } else { + case 1: + builder.Write(" = ", builder.AddVar(in.Values...)) + default: builder.Write(" IN (", builder.AddVar(in.Values...), ")") } } func (in IN) NegationBuild(builder BuilderInterface) { - if len(in.Values) != 0 { - builder.WriteQuoted(in.Column) + switch len(in.Values) { + case 0: + case 1: + builder.Write(" <> ", builder.AddVar(in.Values...)) + default: builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") } } @@ -193,3 +200,48 @@ func (like Like) NegationBuild(builder BuilderInterface) { builder.WriteQuoted(like.Column) builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) } + +// Map +type Map map[interface{}]interface{} + +func (m Map) Build(builder BuilderInterface) { + // TODO +} + +func (m Map) NegationBuild(builder BuilderInterface) { + // TODO +} + +// Attrs +type Attrs struct { + Value interface{} + Select []string + Omit []string +} + +func (attrs Attrs) Build(builder BuilderInterface) { + // TODO + // builder.WriteQuoted(like.Column) + // builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (attrs Attrs) NegationBuild(builder BuilderInterface) { + // TODO +} + +// ID +type ID struct { + Value []interface{} +} + +func (id ID) Build(builder BuilderInterface) { + if len(id.Value) == 1 { + } + // TODO + // builder.WriteQuoted(like.Column) + // builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (id ID) NegationBuild(builder BuilderInterface) { + // TODO +} diff --git a/gorm.go b/gorm.go index 1b6d88df..86d5af9a 100644 --- a/gorm.go +++ b/gorm.go @@ -93,7 +93,7 @@ func (db *DB) getInstance() *DB { Dialector: db.Dialector, Context: context.Background(), Result: Result{ - Statement: &Statement{DB: db, Clauses: map[string][]clause.Interface{}}, + Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}}, }, } } @@ -106,3 +106,9 @@ func (db *DB) Debug() (tx *DB) { tx = db.getInstance() return } + +// Session start session mode +func (db *DB) Session() (tx *DB) { + tx = db.getInstance() + return +} diff --git a/statement.go b/statement.go index 21e95e11..5dab59b3 100644 --- a/statement.go +++ b/statement.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" "fmt" + "strconv" "strings" "sync" @@ -13,9 +15,10 @@ import ( // Statement statement type Statement struct { + Model interface{} Dest interface{} - Table interface{} - Clauses map[string][]clause.Interface + Table string + Clauses map[string][]clause.Condition Settings sync.Map Context context.Context DB *DB @@ -45,16 +48,29 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) { // Write write string func (stmt Statement) AddVar(vars ...interface{}) string { - var placeholders []string - for _, v := range vars { + var placeholders strings.Builder + for idx, v := range vars { + if idx > 0 { + placeholders.WriteByte(',') + } + if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { stmt.NamedVars = append(stmt.NamedVars, namedArg) - placeholders = append(placeholders, "@"+namedArg.Name) + placeholders.WriteByte('@') + placeholders.WriteString(namedArg.Name) + } else if arrs, ok := v.([]interface{}); ok { + placeholders.WriteByte('(') + if len(arrs) > 0 { + placeholders.WriteString(stmt.AddVar(arrs...)) + } else { + placeholders.WriteString("NULL") + } + placeholders.WriteByte(')') } else { - placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v)) + placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) } } - return strings.Join(placeholders, ",") + return placeholders.String() } // Quote returns quoted value @@ -66,3 +82,61 @@ func (stmt Statement) Quote(field interface{}) (str string) { func (s Statement) AddClause(clause clause.Interface) { s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause) } + +// BuildCondtions build conditions +func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) { + if sql, ok := query.(string); ok { + if i, err := strconv.Atoi(sql); err != nil { + query = i + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + return []clause.Condition{clause.Raw{SQL: sql, Values: args}} + } + } + + args = append([]interface{}{query}, args...) + for _, arg := range args { + if valuer, ok := arg.(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + switch v := arg.(type) { + case clause.Builder: + conditions = append(conditions, v) + case *DB: + if v.Statement == nil { + if cs, ok := v.Statement.Clauses["WHERE"]; ok { + conditions = append(conditions, cs...) + } + } + case map[interface{}]interface{}: + var clauseMap = clause.Map{} + for i, j := range v { + clauseMap[i] = j + } + conditions = append(conditions, clauseMap) + case map[string]string: + var clauseMap = clause.Map{} + for i, j := range v { + clauseMap[i] = j + } + conditions = append(conditions, clauseMap) + case map[string]interface{}: + var clauseMap = clause.Map{} + for i, j := range v { + clauseMap[i] = j + } + conditions = append(conditions, clauseMap) + default: + // TODO check is struct + // struct, slice -> ids + } + } + + if len(conditions) == 0 { + conditions = append(conditions, clause.ID{Value: args}) + } + return conditions +} + +func (s Statement) AddError(err error) { +}