Refactor builder

This commit is contained in:
Jinzhu 2020-01-30 15:14:48 +08:00
parent 85bfd175c6
commit 9d5b9834d9
11 changed files with 412 additions and 156 deletions

View File

@ -1,5 +1,7 @@
package gorm
import "github.com/jinzhu/gorm/clause"
// Model specify the model you would like to run db operations
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
@ -11,6 +13,27 @@ func (db *DB) Model(value interface{}) (tx *DB) {
return
}
// Clauses Add clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance()
var whereConds []interface{}
for _, cond := range conds {
if c, ok := cond.(clause.Interface); ok {
tx.Statement.AddClause(c)
} else {
whereConds = append(whereConds, cond)
}
}
if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{
AndConditions: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...),
})
}
return
}
// Table specify the table you would like to run db operations
func (db *DB) Table(name string) (tx *DB) {
tx = db.getInstance()
@ -32,18 +55,25 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)})
return
}
// Not add NOT condition
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))},
})
return
}
// Or add OR conditions
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)},
})
return
}
@ -98,20 +128,13 @@ func (db *DB) Offset(offset int64) (tx *DB) {
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Refer https://jinzhu.github.io/gorm/crud.html#scopes
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB {
for _, f := range funcs {
db = f(db)
}
return db
}
//Preloads only preloads relations, don`t touch out
func (db *DB) Preloads(out interface{}) (tx *DB) {
tx = db.getInstance()
return
}
// Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) {

View File

@ -1,31 +1,131 @@
package clause
// Builder builder interface
type BuilderInterface interface {
Write(sql ...string) error
WriteQuoted(field interface{}) error
AddVar(vars ...interface{}) string
Quote(field interface{}) string
// Clause
type Clause struct {
Name string // WHERE
Priority float64
BeforeExpressions []Expression
AfterNameExpressions []Expression
AfterExpressions []Expression
Expression Expression
Builder ClauseBuilder
}
// ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder interface {
Build(Clause, Builder)
}
// Build build clause
func (c Clause) Build(builder Builder) {
if c.Builder != nil {
c.Builder.Build(c, builder)
} else {
builders := c.BeforeExpressions
if c.Name != "" {
builders = append(builders, Expr{c.Name})
}
builders = append(builders, c.AfterNameExpressions...)
if c.Expression != nil {
builders = append(builders, c.Expression)
}
for idx, expr := range append(builders, c.AfterExpressions...) {
if idx != 0 {
builder.WriteByte(' ')
}
expr.Build(builder)
}
}
}
// Interface clause interface
type Interface interface {
Name() string
Builder
Build(Builder)
MergeExpression(Expression)
}
// Builder condition builder
type Builder interface {
Build(builder BuilderInterface)
type OverrideNameInterface interface {
OverrideName() string
}
// NegationBuilder negation condition builder
type NegationBuilder interface {
NegationBuild(builder BuilderInterface)
}
////////////////////////////////////////////////////////////////////////////////
// Predefined Clauses
////////////////////////////////////////////////////////////////////////////////
// Where where clause
type Where struct {
AndConditions AddConditions
ORConditions []ORConditions
Builders []Expression
}
func (where Where) Name() string {
return "WHERE"
}
func (where Where) Build(builder Builder) {
var withConditions bool
if len(where.AndConditions) > 0 {
withConditions = true
where.AndConditions.Build(builder)
}
if len(where.Builders) > 0 {
for _, b := range where.Builders {
if withConditions {
builder.Write(" AND ")
}
withConditions = true
b.Build(builder)
}
}
var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
or.Build(builder)
} else {
singleOrConditions = append(singleOrConditions, or)
}
} else {
withConditions = true
builder.Write(" AND (")
or.Build(builder)
builder.WriteByte(')')
}
}
for _, or := range singleOrConditions {
if withConditions {
builder.Write(" AND ")
or.Build(builder)
} else {
withConditions = true
or.Build(builder)
}
}
if !withConditions {
builder.Write(" FALSE")
}
return
}
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.Builders = append(where.Builders, w.Builders...)
} else {
where.Builders = append(where.Builders, expr)
}
}
// Select select attrs when querying, updating, creating

View File

@ -1,19 +0,0 @@
package clause
type ExprInterface interface {
}
type Expr struct {
}
type Average struct {
}
type Minimum struct {
}
type Maximum struct {
}
type Sum struct {
}

30
clause/expression.go Normal file
View File

@ -0,0 +1,30 @@
package clause
// Expression expression interface
type Expression interface {
Build(builder Builder)
}
// NegationExpressionBuilder negation expression builder
type NegationExpressionBuilder interface {
NegationBuild(builder Builder)
}
// Builder builder interface
type Builder interface {
WriteByte(byte) error
Write(sql ...string) error
WriteQuoted(field interface{}) error
AddVar(vars ...interface{}) string
Quote(field interface{}) string
}
// Expr raw expression
type Expr struct {
Value string
}
// Build build raw expression
func (expr Expr) Build(builder Builder) {
builder.Write(expr.Value)
}

View File

@ -2,10 +2,13 @@ package clause
import "strings"
type Condition Builder
type AddConditions []Condition
////////////////////////////////////////////////////////////////////////////////
// Query Expressions
////////////////////////////////////////////////////////////////////////////////
func (cs AddConditions) Build(builder BuilderInterface) {
type AddConditions []Expression
func (cs AddConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" AND ")
@ -14,9 +17,9 @@ func (cs AddConditions) Build(builder BuilderInterface) {
}
}
type ORConditions []Condition
type ORConditions []Expression
func (cs ORConditions) Build(builder BuilderInterface) {
func (cs ORConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" OR ")
@ -25,15 +28,15 @@ func (cs ORConditions) Build(builder BuilderInterface) {
}
}
type NotConditions []Condition
type NotConditions []Expression
func (cs NotConditions) Build(builder BuilderInterface) {
func (cs NotConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" AND ")
}
if negationBuilder, ok := c.(NegationBuilder); ok {
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.Write(" NOT ")
@ -42,15 +45,15 @@ func (cs NotConditions) Build(builder BuilderInterface) {
}
}
// Raw raw sql for where
type Raw struct {
// String raw sql for where
type String struct {
SQL string
Values []interface{}
}
func (raw Raw) Build(builder BuilderInterface) {
sql := raw.SQL
for _, v := range raw.Values {
func (str String) Build(builder Builder) {
sql := str.SQL
for _, v := range str.Values {
sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1)
}
builder.Write(sql)
@ -62,7 +65,7 @@ type IN struct {
Values []interface{}
}
func (in IN) Build(builder BuilderInterface) {
func (in IN) Build(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
@ -75,7 +78,7 @@ func (in IN) Build(builder BuilderInterface) {
}
}
func (in IN) NegationBuild(builder BuilderInterface) {
func (in IN) NegationBuild(builder Builder) {
switch len(in.Values) {
case 0:
case 1:
@ -91,7 +94,7 @@ type Eq struct {
Value interface{}
}
func (eq Eq) Build(builder BuilderInterface) {
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
if eq.Value == nil {
@ -101,7 +104,7 @@ func (eq Eq) Build(builder BuilderInterface) {
}
}
func (eq Eq) NegationBuild(builder BuilderInterface) {
func (eq Eq) NegationBuild(builder Builder) {
Neq{eq.Column, eq.Value}.Build(builder)
}
@ -111,7 +114,7 @@ type Neq struct {
Value interface{}
}
func (neq Neq) Build(builder BuilderInterface) {
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
if neq.Value == nil {
@ -121,7 +124,7 @@ func (neq Neq) Build(builder BuilderInterface) {
}
}
func (neq Neq) NegationBuild(builder BuilderInterface) {
func (neq Neq) NegationBuild(builder Builder) {
Eq{neq.Column, neq.Value}.Build(builder)
}
@ -131,12 +134,12 @@ type Gt struct {
Value interface{}
}
func (gt Gt) Build(builder BuilderInterface) {
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
builder.Write(" > ", builder.AddVar(gt.Value))
}
func (gt Gt) NegationBuild(builder BuilderInterface) {
func (gt Gt) NegationBuild(builder Builder) {
Lte{gt.Column, gt.Value}.Build(builder)
}
@ -146,12 +149,12 @@ type Gte struct {
Value interface{}
}
func (gte Gte) Build(builder BuilderInterface) {
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
builder.Write(" >= ", builder.AddVar(gte.Value))
}
func (gte Gte) NegationBuild(builder BuilderInterface) {
func (gte Gte) NegationBuild(builder Builder) {
Lt{gte.Column, gte.Value}.Build(builder)
}
@ -161,12 +164,12 @@ type Lt struct {
Value interface{}
}
func (lt Lt) Build(builder BuilderInterface) {
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
builder.Write(" < ", builder.AddVar(lt.Value))
}
func (lt Lt) NegationBuild(builder BuilderInterface) {
func (lt Lt) NegationBuild(builder Builder) {
Gte{lt.Column, lt.Value}.Build(builder)
}
@ -176,12 +179,12 @@ type Lte struct {
Value interface{}
}
func (lte Lte) Build(builder BuilderInterface) {
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
builder.Write(" <= ", builder.AddVar(lte.Value))
}
func (lte Lte) NegationBuild(builder BuilderInterface) {
func (lte Lte) NegationBuild(builder Builder) {
Gt{lte.Column, lte.Value}.Build(builder)
}
@ -191,12 +194,12 @@ type Like struct {
Value interface{}
}
func (like Like) Build(builder BuilderInterface) {
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (like Like) NegationBuild(builder BuilderInterface) {
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
}
@ -204,11 +207,11 @@ func (like Like) NegationBuild(builder BuilderInterface) {
// Map
type Map map[interface{}]interface{}
func (m Map) Build(builder BuilderInterface) {
func (m Map) Build(builder Builder) {
// TODO
}
func (m Map) NegationBuild(builder BuilderInterface) {
func (m Map) NegationBuild(builder Builder) {
// TODO
}
@ -219,13 +222,13 @@ type Attrs struct {
Omit []string
}
func (attrs Attrs) Build(builder BuilderInterface) {
func (attrs Attrs) Build(builder Builder) {
// TODO
// builder.WriteQuoted(like.Column)
// builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (attrs Attrs) NegationBuild(builder BuilderInterface) {
func (attrs Attrs) NegationBuild(builder Builder) {
// TODO
}
@ -234,7 +237,7 @@ type ID struct {
Value []interface{}
}
func (id ID) Build(builder BuilderInterface) {
func (id ID) Build(builder Builder) {
if len(id.Value) == 1 {
}
// TODO
@ -242,6 +245,6 @@ func (id ID) Build(builder BuilderInterface) {
// builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (id ID) NegationBuild(builder BuilderInterface) {
func (id ID) NegationBuild(builder Builder) {
// TODO
}

22
errors.go Normal file
View File

@ -0,0 +1,22 @@
package gorm
import "errors"
var (
// ErrRecordNotFound record not found error
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrUnaddressable unaddressable value
ErrUnaddressable = errors.New("using unaddressable value")
)
type Error struct {
Err error
}
func (e Error) Unwrap() error {
return e.Err
}

View File

@ -33,8 +33,6 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
return
}
// Scan scan value to a struct
func (db *DB) Row() *sql.Row {
// TODO
return nil
@ -45,6 +43,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return nil, nil
}
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance()
return
@ -88,12 +87,12 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return
}
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}
@ -109,6 +108,16 @@ func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) {
return
}
//Preloads only preloads relations, don`t touch out
func (db *DB) Preloads(out interface{}) (tx *DB) {
tx = db.getInstance()
return
}
func (db *DB) Association(column string) *Association {
return nil
}
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
tx := db.Begin(opts...)
@ -148,7 +157,3 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}
func (db *DB) Association(column string) *Association {
return nil
}

105
gorm.go
View File

@ -25,44 +25,72 @@ type Config struct {
NowFunc func() time.Time
}
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}
// Dialector GORM database dialector
type Dialector interface {
Migrator() Migrator
BindVar(stmt Statement, v interface{}) string
}
// Result
type Result struct {
Error error
RowsAffected int64
Statement *Statement
}
// DB GORM DB definition
type DB struct {
*Config
Dialector
Result
Instance
clone bool
}
// Session session config when create new session
type Session struct {
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
}
// Open initialize db session based on dialector
func Open(dialector Dialector, config *Config) (db *DB, err error) {
return &DB{
Config: config,
Dialector: dialector,
clone: true,
}, nil
}
// Session create new db session
func (db *DB) Session(config *Session) *DB {
var (
tx = db.getInstance()
txConfig = *tx.Config
)
if config.Context != nil {
tx.Context = config.Context
}
if config.Logger != nil {
txConfig.Logger = config.Logger
}
if config.NowFunc != nil {
txConfig.NowFunc = config.NowFunc
}
tx.Config = &txConfig
tx.clone = true
return tx
}
// WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB {
tx := db.getInstance()
tx.Context = ctx
return tx
return db.Session(&Session{Context: ctx})
}
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
}
func (db *DB) Close() error {
return nil
}
// Set store value with key into current db instance's context
@ -80,35 +108,22 @@ func (db *DB) Get(key string) (interface{}, bool) {
return nil, false
}
func (db *DB) Close() *DB {
// TODO
return db
}
func (db *DB) getInstance() *DB {
// db.Result.Statement == nil means root DB
if db.Result.Statement == nil {
if db.clone {
ctx := db.Instance.Context
if ctx == nil {
ctx = context.Background()
}
return &DB{
Config: db.Config,
Dialector: db.Dialector,
Context: context.Background(),
Result: Result{
Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}},
Instance: Instance{
Context: ctx,
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
},
}
}
return db
}
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
tx = db.getInstance()
return
}
// Session start session mode
func (db *DB) Session() (tx *DB) {
tx = db.getInstance()
return
}

View File

@ -1,5 +1,14 @@
package logger
type LogLevel int
const (
Info LogLevel = iota + 1
Warn
Error
)
// Interface logger interface
type Interface interface {
LogMode(LogLevel) Interface
}

15
model.go Normal file
View File

@ -0,0 +1,15 @@
package gorm
import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}

View File

@ -1,7 +1,6 @@
package gorm
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
@ -13,25 +12,43 @@ import (
"github.com/jinzhu/gorm/clause"
)
// Statement statement
type Statement struct {
Model interface{}
Dest interface{}
Table string
Clauses map[string][]clause.Condition
Settings sync.Map
// Instance db instance
type Instance struct {
Error error
RowsAffected int64
Context context.Context
DB *DB
StatementBuilder
Statement *Statement
}
// StatementBuilder statement builder
type StatementBuilder struct {
SQL bytes.Buffer
// AddError add error to instance
func (inst Instance) AddError(err error) {
if inst.Error == nil {
inst.Error = err
} else {
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
}
}
// Statement statement
type Statement struct {
Table string
Model interface{}
Dest interface{}
Clauses map[string]clause.Clause
Settings sync.Map
DB *DB
// SQL Builder
SQL strings.Builder
Vars []interface{}
NamedVars []sql.NamedArg
}
// StatementOptimizer statement optimizer interface
type StatementOptimizer interface {
OptimizeStatement(Statement)
}
// Write write string
func (stmt Statement) Write(sql ...string) (err error) {
for _, s := range sql {
@ -40,12 +57,23 @@ func (stmt Statement) Write(sql ...string) (err error) {
return
}
// Write write string
func (stmt Statement) WriteByte(c byte) (err error) {
return stmt.SQL.WriteByte(c)
}
// WriteQuoted write quoted field
func (stmt Statement) WriteQuoted(field interface{}) (err error) {
_, err = stmt.SQL.WriteString(stmt.Quote(field))
return
}
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) (str string) {
// FIXME
return fmt.Sprint(field)
}
// Write write string
func (stmt Statement) AddVar(vars ...interface{}) string {
var placeholders strings.Builder
@ -73,23 +101,34 @@ func (stmt Statement) AddVar(vars ...interface{}) string {
return placeholders.String()
}
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) (str string) {
return fmt.Sprint(field)
}
// AddClause add clause
func (s Statement) AddClause(clause clause.Interface) {
s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause)
func (stmt Statement) AddClause(v clause.Interface) {
if optimizer, ok := v.(StatementOptimizer); ok {
optimizer.OptimizeStatement(stmt)
}
c, _ := stmt.Clauses[v.Name()]
if namer, ok := v.(clause.OverrideNameInterface); ok {
c.Name = namer.OverrideName()
} else {
c.Name = v.Name()
}
if c.Expression != nil {
v.MergeExpression(c.Expression)
}
c.Expression = v
stmt.Clauses[v.Name()] = c
}
// BuildCondtions build conditions
func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) {
// BuildCondtion build condition
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
if sql, ok := query.(string); ok {
if i, err := strconv.Atoi(sql); err != nil {
query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Condition{clause.Raw{SQL: sql, Values: args}}
return []clause.Expression{clause.String{SQL: sql, Values: args}}
}
}
@ -100,12 +139,12 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi
}
switch v := arg.(type) {
case clause.Builder:
case clause.Expression:
conditions = append(conditions, v)
case *DB:
if v.Statement == nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
conditions = append(conditions, cs...)
conditions = append(conditions, cs.Expression)
}
}
case map[interface{}]interface{}:
@ -135,8 +174,22 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi
if len(conditions) == 0 {
conditions = append(conditions, clause.ID{Value: args})
}
return conditions
}
func (s Statement) AddError(err error) {
// Build build sql with clauses names
func (stmt Statement) Build(clauses ...string) {
var includeSpace bool
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
if includeSpace {
stmt.WriteByte(' ')
}
includeSpace = true
c.Build(stmt)
}
}
}