Refactor builder

This commit is contained in:
Jinzhu 2020-01-30 15:14:48 +08:00
parent 85bfd175c6
commit 9d5b9834d9
11 changed files with 412 additions and 156 deletions

View File

@ -1,5 +1,7 @@
package gorm package gorm
import "github.com/jinzhu/gorm/clause"
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
// // update all users's name to `hello` // // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello") // db.Model(&User{}).Update("name", "hello")
@ -11,6 +13,27 @@ func (db *DB) Model(value interface{}) (tx *DB) {
return return
} }
// Clauses Add clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance()
var whereConds []interface{}
for _, cond := range conds {
if c, ok := cond.(clause.Interface); ok {
tx.Statement.AddClause(c)
} else {
whereConds = append(whereConds, cond)
}
}
if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{
AndConditions: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...),
})
}
return
}
// Table specify the table you would like to run db operations // Table specify the table you would like to run db operations
func (db *DB) Table(name string) (tx *DB) { func (db *DB) Table(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -32,18 +55,25 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)})
return return
} }
// Not add NOT condition // Not add NOT condition
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))},
})
return return
} }
// Or add OR conditions // Or add OR conditions
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)},
})
return return
} }
@ -98,20 +128,13 @@ func (db *DB) Offset(offset int64) (tx *DB) {
// } // }
// //
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Refer https://jinzhu.github.io/gorm/crud.html#scopes func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB {
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
for _, f := range funcs { for _, f := range funcs {
db = f(db) db = f(db)
} }
return db return db
} }
//Preloads only preloads relations, don`t touch out
func (db *DB) Preloads(out interface{}) (tx *DB) {
tx = db.getInstance()
return
}
// Preload preload associations with given conditions // Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) {

View File

@ -1,31 +1,131 @@
package clause package clause
// Builder builder interface // Clause
type BuilderInterface interface { type Clause struct {
Write(sql ...string) error Name string // WHERE
WriteQuoted(field interface{}) error Priority float64
AddVar(vars ...interface{}) string BeforeExpressions []Expression
Quote(field interface{}) string AfterNameExpressions []Expression
AfterExpressions []Expression
Expression Expression
Builder ClauseBuilder
}
// ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder interface {
Build(Clause, Builder)
}
// Build build clause
func (c Clause) Build(builder Builder) {
if c.Builder != nil {
c.Builder.Build(c, builder)
} else {
builders := c.BeforeExpressions
if c.Name != "" {
builders = append(builders, Expr{c.Name})
}
builders = append(builders, c.AfterNameExpressions...)
if c.Expression != nil {
builders = append(builders, c.Expression)
}
for idx, expr := range append(builders, c.AfterExpressions...) {
if idx != 0 {
builder.WriteByte(' ')
}
expr.Build(builder)
}
}
} }
// Interface clause interface // Interface clause interface
type Interface interface { type Interface interface {
Name() string Name() string
Builder Build(Builder)
MergeExpression(Expression)
} }
// Builder condition builder type OverrideNameInterface interface {
type Builder interface { OverrideName() string
Build(builder BuilderInterface)
} }
// NegationBuilder negation condition builder ////////////////////////////////////////////////////////////////////////////////
type NegationBuilder interface { // Predefined Clauses
NegationBuild(builder BuilderInterface) ////////////////////////////////////////////////////////////////////////////////
}
// Where where clause // Where where clause
type Where struct { type Where struct {
AndConditions AddConditions
ORConditions []ORConditions
Builders []Expression
}
func (where Where) Name() string {
return "WHERE"
}
func (where Where) Build(builder Builder) {
var withConditions bool
if len(where.AndConditions) > 0 {
withConditions = true
where.AndConditions.Build(builder)
}
if len(where.Builders) > 0 {
for _, b := range where.Builders {
if withConditions {
builder.Write(" AND ")
}
withConditions = true
b.Build(builder)
}
}
var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
or.Build(builder)
} else {
singleOrConditions = append(singleOrConditions, or)
}
} else {
withConditions = true
builder.Write(" AND (")
or.Build(builder)
builder.WriteByte(')')
}
}
for _, or := range singleOrConditions {
if withConditions {
builder.Write(" AND ")
or.Build(builder)
} else {
withConditions = true
or.Build(builder)
}
}
if !withConditions {
builder.Write(" FALSE")
}
return
}
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.Builders = append(where.Builders, w.Builders...)
} else {
where.Builders = append(where.Builders, expr)
}
} }
// Select select attrs when querying, updating, creating // Select select attrs when querying, updating, creating

View File

@ -1,19 +0,0 @@
package clause
type ExprInterface interface {
}
type Expr struct {
}
type Average struct {
}
type Minimum struct {
}
type Maximum struct {
}
type Sum struct {
}

30
clause/expression.go Normal file
View File

@ -0,0 +1,30 @@
package clause
// Expression expression interface
type Expression interface {
Build(builder Builder)
}
// NegationExpressionBuilder negation expression builder
type NegationExpressionBuilder interface {
NegationBuild(builder Builder)
}
// Builder builder interface
type Builder interface {
WriteByte(byte) error
Write(sql ...string) error
WriteQuoted(field interface{}) error
AddVar(vars ...interface{}) string
Quote(field interface{}) string
}
// Expr raw expression
type Expr struct {
Value string
}
// Build build raw expression
func (expr Expr) Build(builder Builder) {
builder.Write(expr.Value)
}

View File

@ -2,10 +2,13 @@ package clause
import "strings" import "strings"
type Condition Builder ////////////////////////////////////////////////////////////////////////////////
type AddConditions []Condition // Query Expressions
////////////////////////////////////////////////////////////////////////////////
func (cs AddConditions) Build(builder BuilderInterface) { type AddConditions []Expression
func (cs AddConditions) Build(builder Builder) {
for idx, c := range cs { for idx, c := range cs {
if idx > 0 { if idx > 0 {
builder.Write(" AND ") builder.Write(" AND ")
@ -14,9 +17,9 @@ func (cs AddConditions) Build(builder BuilderInterface) {
} }
} }
type ORConditions []Condition type ORConditions []Expression
func (cs ORConditions) Build(builder BuilderInterface) { func (cs ORConditions) Build(builder Builder) {
for idx, c := range cs { for idx, c := range cs {
if idx > 0 { if idx > 0 {
builder.Write(" OR ") builder.Write(" OR ")
@ -25,15 +28,15 @@ func (cs ORConditions) Build(builder BuilderInterface) {
} }
} }
type NotConditions []Condition type NotConditions []Expression
func (cs NotConditions) Build(builder BuilderInterface) { func (cs NotConditions) Build(builder Builder) {
for idx, c := range cs { for idx, c := range cs {
if idx > 0 { if idx > 0 {
builder.Write(" AND ") builder.Write(" AND ")
} }
if negationBuilder, ok := c.(NegationBuilder); ok { if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder) negationBuilder.NegationBuild(builder)
} else { } else {
builder.Write(" NOT ") builder.Write(" NOT ")
@ -42,15 +45,15 @@ func (cs NotConditions) Build(builder BuilderInterface) {
} }
} }
// Raw raw sql for where // String raw sql for where
type Raw struct { type String struct {
SQL string SQL string
Values []interface{} Values []interface{}
} }
func (raw Raw) Build(builder BuilderInterface) { func (str String) Build(builder Builder) {
sql := raw.SQL sql := str.SQL
for _, v := range raw.Values { for _, v := range str.Values {
sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1)
} }
builder.Write(sql) builder.Write(sql)
@ -62,7 +65,7 @@ type IN struct {
Values []interface{} Values []interface{}
} }
func (in IN) Build(builder BuilderInterface) { func (in IN) Build(builder Builder) {
builder.WriteQuoted(in.Column) builder.WriteQuoted(in.Column)
switch len(in.Values) { switch len(in.Values) {
@ -75,7 +78,7 @@ func (in IN) Build(builder BuilderInterface) {
} }
} }
func (in IN) NegationBuild(builder BuilderInterface) { func (in IN) NegationBuild(builder Builder) {
switch len(in.Values) { switch len(in.Values) {
case 0: case 0:
case 1: case 1:
@ -91,7 +94,7 @@ type Eq struct {
Value interface{} Value interface{}
} }
func (eq Eq) Build(builder BuilderInterface) { func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column) builder.WriteQuoted(eq.Column)
if eq.Value == nil { if eq.Value == nil {
@ -101,7 +104,7 @@ func (eq Eq) Build(builder BuilderInterface) {
} }
} }
func (eq Eq) NegationBuild(builder BuilderInterface) { func (eq Eq) NegationBuild(builder Builder) {
Neq{eq.Column, eq.Value}.Build(builder) Neq{eq.Column, eq.Value}.Build(builder)
} }
@ -111,7 +114,7 @@ type Neq struct {
Value interface{} Value interface{}
} }
func (neq Neq) Build(builder BuilderInterface) { func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column) builder.WriteQuoted(neq.Column)
if neq.Value == nil { if neq.Value == nil {
@ -121,7 +124,7 @@ func (neq Neq) Build(builder BuilderInterface) {
} }
} }
func (neq Neq) NegationBuild(builder BuilderInterface) { func (neq Neq) NegationBuild(builder Builder) {
Eq{neq.Column, neq.Value}.Build(builder) Eq{neq.Column, neq.Value}.Build(builder)
} }
@ -131,12 +134,12 @@ type Gt struct {
Value interface{} Value interface{}
} }
func (gt Gt) Build(builder BuilderInterface) { func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column) builder.WriteQuoted(gt.Column)
builder.Write(" > ", builder.AddVar(gt.Value)) builder.Write(" > ", builder.AddVar(gt.Value))
} }
func (gt Gt) NegationBuild(builder BuilderInterface) { func (gt Gt) NegationBuild(builder Builder) {
Lte{gt.Column, gt.Value}.Build(builder) Lte{gt.Column, gt.Value}.Build(builder)
} }
@ -146,12 +149,12 @@ type Gte struct {
Value interface{} Value interface{}
} }
func (gte Gte) Build(builder BuilderInterface) { func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column) builder.WriteQuoted(gte.Column)
builder.Write(" >= ", builder.AddVar(gte.Value)) builder.Write(" >= ", builder.AddVar(gte.Value))
} }
func (gte Gte) NegationBuild(builder BuilderInterface) { func (gte Gte) NegationBuild(builder Builder) {
Lt{gte.Column, gte.Value}.Build(builder) Lt{gte.Column, gte.Value}.Build(builder)
} }
@ -161,12 +164,12 @@ type Lt struct {
Value interface{} Value interface{}
} }
func (lt Lt) Build(builder BuilderInterface) { func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column) builder.WriteQuoted(lt.Column)
builder.Write(" < ", builder.AddVar(lt.Value)) builder.Write(" < ", builder.AddVar(lt.Value))
} }
func (lt Lt) NegationBuild(builder BuilderInterface) { func (lt Lt) NegationBuild(builder Builder) {
Gte{lt.Column, lt.Value}.Build(builder) Gte{lt.Column, lt.Value}.Build(builder)
} }
@ -176,12 +179,12 @@ type Lte struct {
Value interface{} Value interface{}
} }
func (lte Lte) Build(builder BuilderInterface) { func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column) builder.WriteQuoted(lte.Column)
builder.Write(" <= ", builder.AddVar(lte.Value)) builder.Write(" <= ", builder.AddVar(lte.Value))
} }
func (lte Lte) NegationBuild(builder BuilderInterface) { func (lte Lte) NegationBuild(builder Builder) {
Gt{lte.Column, lte.Value}.Build(builder) Gt{lte.Column, lte.Value}.Build(builder)
} }
@ -191,12 +194,12 @@ type Like struct {
Value interface{} Value interface{}
} }
func (like Like) Build(builder BuilderInterface) { func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column) builder.WriteQuoted(like.Column)
builder.Write(" LIKE ", builder.AddVar(like.Value)) builder.Write(" LIKE ", builder.AddVar(like.Value))
} }
func (like Like) NegationBuild(builder BuilderInterface) { func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column) builder.WriteQuoted(like.Column)
builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
} }
@ -204,11 +207,11 @@ func (like Like) NegationBuild(builder BuilderInterface) {
// Map // Map
type Map map[interface{}]interface{} type Map map[interface{}]interface{}
func (m Map) Build(builder BuilderInterface) { func (m Map) Build(builder Builder) {
// TODO // TODO
} }
func (m Map) NegationBuild(builder BuilderInterface) { func (m Map) NegationBuild(builder Builder) {
// TODO // TODO
} }
@ -219,13 +222,13 @@ type Attrs struct {
Omit []string Omit []string
} }
func (attrs Attrs) Build(builder BuilderInterface) { func (attrs Attrs) Build(builder Builder) {
// TODO // TODO
// builder.WriteQuoted(like.Column) // builder.WriteQuoted(like.Column)
// builder.Write(" LIKE ", builder.AddVar(like.Value)) // builder.Write(" LIKE ", builder.AddVar(like.Value))
} }
func (attrs Attrs) NegationBuild(builder BuilderInterface) { func (attrs Attrs) NegationBuild(builder Builder) {
// TODO // TODO
} }
@ -234,7 +237,7 @@ type ID struct {
Value []interface{} Value []interface{}
} }
func (id ID) Build(builder BuilderInterface) { func (id ID) Build(builder Builder) {
if len(id.Value) == 1 { if len(id.Value) == 1 {
} }
// TODO // TODO
@ -242,6 +245,6 @@ func (id ID) Build(builder BuilderInterface) {
// builder.Write(" LIKE ", builder.AddVar(like.Value)) // builder.Write(" LIKE ", builder.AddVar(like.Value))
} }
func (id ID) NegationBuild(builder BuilderInterface) { func (id ID) NegationBuild(builder Builder) {
// TODO // TODO
} }

22
errors.go Normal file
View File

@ -0,0 +1,22 @@
package gorm
import "errors"
var (
// ErrRecordNotFound record not found error
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrUnaddressable unaddressable value
ErrUnaddressable = errors.New("using unaddressable value")
)
type Error struct {
Err error
}
func (e Error) Unwrap() error {
return e.Err
}

View File

@ -33,8 +33,6 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
return return
} }
// Scan scan value to a struct
func (db *DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
// TODO // TODO
return nil return nil
@ -45,6 +43,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return nil, nil return nil, nil
} }
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
@ -88,12 +87,12 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return return
} }
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
@ -109,6 +108,16 @@ func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) {
return return
} }
//Preloads only preloads relations, don`t touch out
func (db *DB) Preloads(out interface{}) (tx *DB) {
tx = db.getInstance()
return
}
func (db *DB) Association(column string) *Association {
return nil
}
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
tx := db.Begin(opts...) tx := db.Begin(opts...)
@ -148,7 +157,3 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db *DB) Association(column string) *Association {
return nil
}

105
gorm.go
View File

@ -25,44 +25,72 @@ type Config struct {
NowFunc func() time.Time NowFunc func() time.Time
} }
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}
// Dialector GORM database dialector // Dialector GORM database dialector
type Dialector interface { type Dialector interface {
Migrator() Migrator Migrator() Migrator
BindVar(stmt Statement, v interface{}) string BindVar(stmt Statement, v interface{}) string
} }
// Result
type Result struct {
Error error
RowsAffected int64
Statement *Statement
}
// DB GORM DB definition // DB GORM DB definition
type DB struct { type DB struct {
*Config *Config
Dialector Dialector
Result Instance
clone bool
}
// Session session config when create new session
type Session struct {
Context context.Context Context context.Context
Logger logger.Interface
NowFunc func() time.Time
}
// Open initialize db session based on dialector
func Open(dialector Dialector, config *Config) (db *DB, err error) {
return &DB{
Config: config,
Dialector: dialector,
clone: true,
}, nil
}
// Session create new db session
func (db *DB) Session(config *Session) *DB {
var (
tx = db.getInstance()
txConfig = *tx.Config
)
if config.Context != nil {
tx.Context = config.Context
}
if config.Logger != nil {
txConfig.Logger = config.Logger
}
if config.NowFunc != nil {
txConfig.NowFunc = config.NowFunc
}
tx.Config = &txConfig
tx.clone = true
return tx
} }
// WithContext change current instance db's context to ctx // WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB { func (db *DB) WithContext(ctx context.Context) *DB {
tx := db.getInstance() return db.Session(&Session{Context: ctx})
tx.Context = ctx }
return tx
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
}
func (db *DB) Close() error {
return nil
} }
// Set store value with key into current db instance's context // Set store value with key into current db instance's context
@ -80,35 +108,22 @@ func (db *DB) Get(key string) (interface{}, bool) {
return nil, false return nil, false
} }
func (db *DB) Close() *DB {
// TODO
return db
}
func (db *DB) getInstance() *DB { func (db *DB) getInstance() *DB {
// db.Result.Statement == nil means root DB if db.clone {
if db.Result.Statement == nil { ctx := db.Instance.Context
if ctx == nil {
ctx = context.Background()
}
return &DB{ return &DB{
Config: db.Config, Config: db.Config,
Dialector: db.Dialector, Dialector: db.Dialector,
Context: context.Background(), Instance: Instance{
Result: Result{ Context: ctx,
Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}}, Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
}, },
} }
} }
return db return db
} }
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
tx = db.getInstance()
return
}
// Session start session mode
func (db *DB) Session() (tx *DB) {
tx = db.getInstance()
return
}

View File

@ -1,5 +1,14 @@
package logger package logger
type LogLevel int
const (
Info LogLevel = iota + 1
Warn
Error
)
// Interface logger interface // Interface logger interface
type Interface interface { type Interface interface {
LogMode(LogLevel) Interface
} }

15
model.go Normal file
View File

@ -0,0 +1,15 @@
package gorm
import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}

View File

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
@ -13,25 +12,43 @@ import (
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
) )
// Statement statement // Instance db instance
type Statement struct { type Instance struct {
Model interface{} Error error
Dest interface{} RowsAffected int64
Table string Context context.Context
Clauses map[string][]clause.Condition Statement *Statement
Settings sync.Map
Context context.Context
DB *DB
StatementBuilder
} }
// StatementBuilder statement builder // AddError add error to instance
type StatementBuilder struct { func (inst Instance) AddError(err error) {
SQL bytes.Buffer if inst.Error == nil {
inst.Error = err
} else {
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
}
}
// Statement statement
type Statement struct {
Table string
Model interface{}
Dest interface{}
Clauses map[string]clause.Clause
Settings sync.Map
DB *DB
// SQL Builder
SQL strings.Builder
Vars []interface{} Vars []interface{}
NamedVars []sql.NamedArg NamedVars []sql.NamedArg
} }
// StatementOptimizer statement optimizer interface
type StatementOptimizer interface {
OptimizeStatement(Statement)
}
// Write write string // Write write string
func (stmt Statement) Write(sql ...string) (err error) { func (stmt Statement) Write(sql ...string) (err error) {
for _, s := range sql { for _, s := range sql {
@ -40,12 +57,23 @@ func (stmt Statement) Write(sql ...string) (err error) {
return return
} }
// Write write string
func (stmt Statement) WriteByte(c byte) (err error) {
return stmt.SQL.WriteByte(c)
}
// WriteQuoted write quoted field // WriteQuoted write quoted field
func (stmt Statement) WriteQuoted(field interface{}) (err error) { func (stmt Statement) WriteQuoted(field interface{}) (err error) {
_, err = stmt.SQL.WriteString(stmt.Quote(field)) _, err = stmt.SQL.WriteString(stmt.Quote(field))
return return
} }
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) (str string) {
// FIXME
return fmt.Sprint(field)
}
// Write write string // Write write string
func (stmt Statement) AddVar(vars ...interface{}) string { func (stmt Statement) AddVar(vars ...interface{}) string {
var placeholders strings.Builder var placeholders strings.Builder
@ -73,23 +101,34 @@ func (stmt Statement) AddVar(vars ...interface{}) string {
return placeholders.String() return placeholders.String()
} }
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) (str string) {
return fmt.Sprint(field)
}
// AddClause add clause // AddClause add clause
func (s Statement) AddClause(clause clause.Interface) { func (stmt Statement) AddClause(v clause.Interface) {
s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause) if optimizer, ok := v.(StatementOptimizer); ok {
optimizer.OptimizeStatement(stmt)
}
c, _ := stmt.Clauses[v.Name()]
if namer, ok := v.(clause.OverrideNameInterface); ok {
c.Name = namer.OverrideName()
} else {
c.Name = v.Name()
}
if c.Expression != nil {
v.MergeExpression(c.Expression)
}
c.Expression = v
stmt.Clauses[v.Name()] = c
} }
// BuildCondtions build conditions // BuildCondtion build condition
func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) { func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
if sql, ok := query.(string); ok { if sql, ok := query.(string); ok {
if i, err := strconv.Atoi(sql); err != nil { if i, err := strconv.Atoi(sql); err != nil {
query = i query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Condition{clause.Raw{SQL: sql, Values: args}} return []clause.Expression{clause.String{SQL: sql, Values: args}}
} }
} }
@ -100,12 +139,12 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi
} }
switch v := arg.(type) { switch v := arg.(type) {
case clause.Builder: case clause.Expression:
conditions = append(conditions, v) conditions = append(conditions, v)
case *DB: case *DB:
if v.Statement == nil { if v.Statement == nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok { if cs, ok := v.Statement.Clauses["WHERE"]; ok {
conditions = append(conditions, cs...) conditions = append(conditions, cs.Expression)
} }
} }
case map[interface{}]interface{}: case map[interface{}]interface{}:
@ -135,8 +174,22 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi
if len(conditions) == 0 { if len(conditions) == 0 {
conditions = append(conditions, clause.ID{Value: args}) conditions = append(conditions, clause.ID{Value: args})
} }
return conditions return conditions
} }
func (s Statement) AddError(err error) { // Build build sql with clauses names
func (stmt Statement) Build(clauses ...string) {
var includeSpace bool
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
if includeSpace {
stmt.WriteByte(' ')
}
includeSpace = true
c.Build(stmt)
}
}
} }