mirror of https://github.com/go-gorm/gorm.git
Refactor builder
This commit is contained in:
parent
85bfd175c6
commit
9d5b9834d9
|
@ -1,5 +1,7 @@
|
|||
package gorm
|
||||
|
||||
import "github.com/jinzhu/gorm/clause"
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
|
@ -11,6 +13,27 @@ func (db *DB) Model(value interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// Clauses Add clauses
|
||||
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
var whereConds []interface{}
|
||||
|
||||
for _, cond := range conds {
|
||||
if c, ok := cond.(clause.Interface); ok {
|
||||
tx.Statement.AddClause(c)
|
||||
} else {
|
||||
whereConds = append(whereConds, cond)
|
||||
}
|
||||
}
|
||||
|
||||
if len(whereConds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
AndConditions: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
func (db *DB) Table(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
@ -32,18 +55,25 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
|||
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)})
|
||||
return
|
||||
}
|
||||
|
||||
// Not add NOT condition
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Or add OR conditions
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -98,20 +128,13 @@ func (db *DB) Offset(offset int64) (tx *DB) {
|
|||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
// Refer https://jinzhu.github.io/gorm/crud.html#scopes
|
||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||
for _, f := range funcs {
|
||||
db = f(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
//Preloads only preloads relations, don`t touch out
|
||||
func (db *DB) Preloads(out interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) {
|
||||
|
|
128
clause/clause.go
128
clause/clause.go
|
@ -1,31 +1,131 @@
|
|||
package clause
|
||||
|
||||
// Builder builder interface
|
||||
type BuilderInterface interface {
|
||||
Write(sql ...string) error
|
||||
WriteQuoted(field interface{}) error
|
||||
AddVar(vars ...interface{}) string
|
||||
Quote(field interface{}) string
|
||||
// Clause
|
||||
type Clause struct {
|
||||
Name string // WHERE
|
||||
Priority float64
|
||||
BeforeExpressions []Expression
|
||||
AfterNameExpressions []Expression
|
||||
AfterExpressions []Expression
|
||||
Expression Expression
|
||||
Builder ClauseBuilder
|
||||
}
|
||||
|
||||
// ClauseBuilder clause builder, allows to custmize how to build clause
|
||||
type ClauseBuilder interface {
|
||||
Build(Clause, Builder)
|
||||
}
|
||||
|
||||
// Build build clause
|
||||
func (c Clause) Build(builder Builder) {
|
||||
if c.Builder != nil {
|
||||
c.Builder.Build(c, builder)
|
||||
} else {
|
||||
builders := c.BeforeExpressions
|
||||
if c.Name != "" {
|
||||
builders = append(builders, Expr{c.Name})
|
||||
}
|
||||
|
||||
builders = append(builders, c.AfterNameExpressions...)
|
||||
if c.Expression != nil {
|
||||
builders = append(builders, c.Expression)
|
||||
}
|
||||
|
||||
for idx, expr := range append(builders, c.AfterExpressions...) {
|
||||
if idx != 0 {
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
expr.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Interface clause interface
|
||||
type Interface interface {
|
||||
Name() string
|
||||
Builder
|
||||
Build(Builder)
|
||||
MergeExpression(Expression)
|
||||
}
|
||||
|
||||
// Builder condition builder
|
||||
type Builder interface {
|
||||
Build(builder BuilderInterface)
|
||||
type OverrideNameInterface interface {
|
||||
OverrideName() string
|
||||
}
|
||||
|
||||
// NegationBuilder negation condition builder
|
||||
type NegationBuilder interface {
|
||||
NegationBuild(builder BuilderInterface)
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Predefined Clauses
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Where where clause
|
||||
type Where struct {
|
||||
AndConditions AddConditions
|
||||
ORConditions []ORConditions
|
||||
Builders []Expression
|
||||
}
|
||||
|
||||
func (where Where) Name() string {
|
||||
return "WHERE"
|
||||
}
|
||||
|
||||
func (where Where) Build(builder Builder) {
|
||||
var withConditions bool
|
||||
|
||||
if len(where.AndConditions) > 0 {
|
||||
withConditions = true
|
||||
where.AndConditions.Build(builder)
|
||||
}
|
||||
|
||||
if len(where.Builders) > 0 {
|
||||
for _, b := range where.Builders {
|
||||
if withConditions {
|
||||
builder.Write(" AND ")
|
||||
}
|
||||
withConditions = true
|
||||
b.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
var singleOrConditions []ORConditions
|
||||
for _, or := range where.ORConditions {
|
||||
if len(or) == 1 {
|
||||
if withConditions {
|
||||
builder.Write(" OR ")
|
||||
or.Build(builder)
|
||||
} else {
|
||||
singleOrConditions = append(singleOrConditions, or)
|
||||
}
|
||||
} else {
|
||||
withConditions = true
|
||||
builder.Write(" AND (")
|
||||
or.Build(builder)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
for _, or := range singleOrConditions {
|
||||
if withConditions {
|
||||
builder.Write(" AND ")
|
||||
or.Build(builder)
|
||||
} else {
|
||||
withConditions = true
|
||||
or.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
if !withConditions {
|
||||
builder.Write(" FALSE")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (where Where) MergeExpression(expr Expression) {
|
||||
if w, ok := expr.(Where); ok {
|
||||
where.AndConditions = append(where.AndConditions, w.AndConditions...)
|
||||
where.ORConditions = append(where.ORConditions, w.ORConditions...)
|
||||
where.Builders = append(where.Builders, w.Builders...)
|
||||
} else {
|
||||
where.Builders = append(where.Builders, expr)
|
||||
}
|
||||
}
|
||||
|
||||
// Select select attrs when querying, updating, creating
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
package clause
|
||||
|
||||
type ExprInterface interface {
|
||||
}
|
||||
|
||||
type Expr struct {
|
||||
}
|
||||
|
||||
type Average struct {
|
||||
}
|
||||
|
||||
type Minimum struct {
|
||||
}
|
||||
|
||||
type Maximum struct {
|
||||
}
|
||||
|
||||
type Sum struct {
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
package clause
|
||||
|
||||
// Expression expression interface
|
||||
type Expression interface {
|
||||
Build(builder Builder)
|
||||
}
|
||||
|
||||
// NegationExpressionBuilder negation expression builder
|
||||
type NegationExpressionBuilder interface {
|
||||
NegationBuild(builder Builder)
|
||||
}
|
||||
|
||||
// Builder builder interface
|
||||
type Builder interface {
|
||||
WriteByte(byte) error
|
||||
Write(sql ...string) error
|
||||
WriteQuoted(field interface{}) error
|
||||
AddVar(vars ...interface{}) string
|
||||
Quote(field interface{}) string
|
||||
}
|
||||
|
||||
// Expr raw expression
|
||||
type Expr struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
// Build build raw expression
|
||||
func (expr Expr) Build(builder Builder) {
|
||||
builder.Write(expr.Value)
|
||||
}
|
|
@ -2,10 +2,13 @@ package clause
|
|||
|
||||
import "strings"
|
||||
|
||||
type Condition Builder
|
||||
type AddConditions []Condition
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Query Expressions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (cs AddConditions) Build(builder BuilderInterface) {
|
||||
type AddConditions []Expression
|
||||
|
||||
func (cs AddConditions) Build(builder Builder) {
|
||||
for idx, c := range cs {
|
||||
if idx > 0 {
|
||||
builder.Write(" AND ")
|
||||
|
@ -14,9 +17,9 @@ func (cs AddConditions) Build(builder BuilderInterface) {
|
|||
}
|
||||
}
|
||||
|
||||
type ORConditions []Condition
|
||||
type ORConditions []Expression
|
||||
|
||||
func (cs ORConditions) Build(builder BuilderInterface) {
|
||||
func (cs ORConditions) Build(builder Builder) {
|
||||
for idx, c := range cs {
|
||||
if idx > 0 {
|
||||
builder.Write(" OR ")
|
||||
|
@ -25,15 +28,15 @@ func (cs ORConditions) Build(builder BuilderInterface) {
|
|||
}
|
||||
}
|
||||
|
||||
type NotConditions []Condition
|
||||
type NotConditions []Expression
|
||||
|
||||
func (cs NotConditions) Build(builder BuilderInterface) {
|
||||
func (cs NotConditions) Build(builder Builder) {
|
||||
for idx, c := range cs {
|
||||
if idx > 0 {
|
||||
builder.Write(" AND ")
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationBuilder); ok {
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.Write(" NOT ")
|
||||
|
@ -42,15 +45,15 @@ func (cs NotConditions) Build(builder BuilderInterface) {
|
|||
}
|
||||
}
|
||||
|
||||
// Raw raw sql for where
|
||||
type Raw struct {
|
||||
// String raw sql for where
|
||||
type String struct {
|
||||
SQL string
|
||||
Values []interface{}
|
||||
}
|
||||
|
||||
func (raw Raw) Build(builder BuilderInterface) {
|
||||
sql := raw.SQL
|
||||
for _, v := range raw.Values {
|
||||
func (str String) Build(builder Builder) {
|
||||
sql := str.SQL
|
||||
for _, v := range str.Values {
|
||||
sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1)
|
||||
}
|
||||
builder.Write(sql)
|
||||
|
@ -62,7 +65,7 @@ type IN struct {
|
|||
Values []interface{}
|
||||
}
|
||||
|
||||
func (in IN) Build(builder BuilderInterface) {
|
||||
func (in IN) Build(builder Builder) {
|
||||
builder.WriteQuoted(in.Column)
|
||||
|
||||
switch len(in.Values) {
|
||||
|
@ -75,7 +78,7 @@ func (in IN) Build(builder BuilderInterface) {
|
|||
}
|
||||
}
|
||||
|
||||
func (in IN) NegationBuild(builder BuilderInterface) {
|
||||
func (in IN) NegationBuild(builder Builder) {
|
||||
switch len(in.Values) {
|
||||
case 0:
|
||||
case 1:
|
||||
|
@ -91,7 +94,7 @@ type Eq struct {
|
|||
Value interface{}
|
||||
}
|
||||
|
||||
func (eq Eq) Build(builder BuilderInterface) {
|
||||
func (eq Eq) Build(builder Builder) {
|
||||
builder.WriteQuoted(eq.Column)
|
||||
|
||||
if eq.Value == nil {
|
||||
|
@ -101,7 +104,7 @@ func (eq Eq) Build(builder BuilderInterface) {
|
|||
}
|
||||
}
|
||||
|
||||
func (eq Eq) NegationBuild(builder BuilderInterface) {
|
||||
func (eq Eq) NegationBuild(builder Builder) {
|
||||
Neq{eq.Column, eq.Value}.Build(builder)
|
||||
}
|
||||
|
||||
|
@ -111,7 +114,7 @@ type Neq struct {
|
|||
Value interface{}
|
||||
}
|
||||
|
||||
func (neq Neq) Build(builder BuilderInterface) {
|
||||
func (neq Neq) Build(builder Builder) {
|
||||
builder.WriteQuoted(neq.Column)
|
||||
|
||||
if neq.Value == nil {
|
||||
|
@ -121,7 +124,7 @@ func (neq Neq) Build(builder BuilderInterface) {
|
|||
}
|
||||
}
|
||||
|
||||
func (neq Neq) NegationBuild(builder BuilderInterface) {
|
||||
func (neq Neq) NegationBuild(builder Builder) {
|
||||
Eq{neq.Column, neq.Value}.Build(builder)
|
||||
}
|
||||
|
||||
|
@ -131,12 +134,12 @@ type Gt struct {
|
|||
Value interface{}
|
||||
}
|
||||
|
||||
func (gt Gt) Build(builder BuilderInterface) {
|
||||
func (gt Gt) Build(builder Builder) {
|
||||
builder.WriteQuoted(gt.Column)
|
||||
builder.Write(" > ", builder.AddVar(gt.Value))
|
||||
}
|
||||
|
||||
func (gt Gt) NegationBuild(builder BuilderInterface) {
|
||||
func (gt Gt) NegationBuild(builder Builder) {
|
||||
Lte{gt.Column, gt.Value}.Build(builder)
|
||||
}
|
||||
|
||||
|
@ -146,12 +149,12 @@ type Gte struct {
|
|||
Value interface{}
|
||||
}
|
||||
|
||||
func (gte Gte) Build(builder BuilderInterface) {
|
||||
func (gte Gte) Build(builder Builder) {
|
||||
builder.WriteQuoted(gte.Column)
|
||||
builder.Write(" >= ", builder.AddVar(gte.Value))
|
||||
}
|
||||
|
||||
func (gte Gte) NegationBuild(builder BuilderInterface) {
|
||||
func (gte Gte) NegationBuild(builder Builder) {
|
||||
Lt{gte.Column, gte.Value}.Build(builder)
|
||||
}
|
||||
|
||||
|
@ -161,12 +164,12 @@ type Lt struct {
|
|||
Value interface{}
|
||||
}
|
||||
|
||||
func (lt Lt) Build(builder BuilderInterface) {
|
||||
func (lt Lt) Build(builder Builder) {
|
||||
builder.WriteQuoted(lt.Column)
|
||||
builder.Write(" < ", builder.AddVar(lt.Value))
|
||||
}
|
||||
|
||||
func (lt Lt) NegationBuild(builder BuilderInterface) {
|
||||
func (lt Lt) NegationBuild(builder Builder) {
|
||||
Gte{lt.Column, lt.Value}.Build(builder)
|
||||
}
|
||||
|
||||
|
@ -176,12 +179,12 @@ type Lte struct {
|
|||
Value interface{}
|
||||
}
|
||||
|
||||
func (lte Lte) Build(builder BuilderInterface) {
|
||||
func (lte Lte) Build(builder Builder) {
|
||||
builder.WriteQuoted(lte.Column)
|
||||
builder.Write(" <= ", builder.AddVar(lte.Value))
|
||||
}
|
||||
|
||||
func (lte Lte) NegationBuild(builder BuilderInterface) {
|
||||
func (lte Lte) NegationBuild(builder Builder) {
|
||||
Gt{lte.Column, lte.Value}.Build(builder)
|
||||
}
|
||||
|
||||
|
@ -191,12 +194,12 @@ type Like struct {
|
|||
Value interface{}
|
||||
}
|
||||
|
||||
func (like Like) Build(builder BuilderInterface) {
|
||||
func (like Like) Build(builder Builder) {
|
||||
builder.WriteQuoted(like.Column)
|
||||
builder.Write(" LIKE ", builder.AddVar(like.Value))
|
||||
}
|
||||
|
||||
func (like Like) NegationBuild(builder BuilderInterface) {
|
||||
func (like Like) NegationBuild(builder Builder) {
|
||||
builder.WriteQuoted(like.Column)
|
||||
builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
|
||||
}
|
||||
|
@ -204,11 +207,11 @@ func (like Like) NegationBuild(builder BuilderInterface) {
|
|||
// Map
|
||||
type Map map[interface{}]interface{}
|
||||
|
||||
func (m Map) Build(builder BuilderInterface) {
|
||||
func (m Map) Build(builder Builder) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
func (m Map) NegationBuild(builder BuilderInterface) {
|
||||
func (m Map) NegationBuild(builder Builder) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
|
@ -219,13 +222,13 @@ type Attrs struct {
|
|||
Omit []string
|
||||
}
|
||||
|
||||
func (attrs Attrs) Build(builder BuilderInterface) {
|
||||
func (attrs Attrs) Build(builder Builder) {
|
||||
// TODO
|
||||
// builder.WriteQuoted(like.Column)
|
||||
// builder.Write(" LIKE ", builder.AddVar(like.Value))
|
||||
}
|
||||
|
||||
func (attrs Attrs) NegationBuild(builder BuilderInterface) {
|
||||
func (attrs Attrs) NegationBuild(builder Builder) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
|
@ -234,7 +237,7 @@ type ID struct {
|
|||
Value []interface{}
|
||||
}
|
||||
|
||||
func (id ID) Build(builder BuilderInterface) {
|
||||
func (id ID) Build(builder Builder) {
|
||||
if len(id.Value) == 1 {
|
||||
}
|
||||
// TODO
|
||||
|
@ -242,6 +245,6 @@ func (id ID) Build(builder BuilderInterface) {
|
|||
// builder.Write(" LIKE ", builder.AddVar(like.Value))
|
||||
}
|
||||
|
||||
func (id ID) NegationBuild(builder BuilderInterface) {
|
||||
func (id ID) NegationBuild(builder Builder) {
|
||||
// TODO
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package gorm
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrRecordNotFound record not found error
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
||||
ErrInvalidSQL = errors.New("invalid SQL")
|
||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||
// ErrUnaddressable unaddressable value
|
||||
ErrUnaddressable = errors.New("using unaddressable value")
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e Error) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
|
@ -33,8 +33,6 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
// TODO
|
||||
return nil
|
||||
|
@ -45,6 +43,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
|
@ -88,12 +87,12 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
|
||||
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
|
||||
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
@ -109,6 +108,16 @@ func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
//Preloads only preloads relations, don`t touch out
|
||||
func (db *DB) Preloads(out interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Association(column string) *Association {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
panicked := true
|
||||
tx := db.Begin(opts...)
|
||||
|
@ -148,7 +157,3 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
|||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Association(column string) *Association {
|
||||
return nil
|
||||
}
|
||||
|
|
105
gorm.go
105
gorm.go
|
@ -25,44 +25,72 @@ type Config struct {
|
|||
NowFunc func() time.Time
|
||||
}
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embeded into your model or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `gorm:"index"`
|
||||
}
|
||||
|
||||
// Dialector GORM database dialector
|
||||
type Dialector interface {
|
||||
Migrator() Migrator
|
||||
BindVar(stmt Statement, v interface{}) string
|
||||
}
|
||||
|
||||
// Result
|
||||
type Result struct {
|
||||
Error error
|
||||
RowsAffected int64
|
||||
Statement *Statement
|
||||
}
|
||||
|
||||
// DB GORM DB definition
|
||||
type DB struct {
|
||||
*Config
|
||||
Dialector
|
||||
Result
|
||||
Instance
|
||||
clone bool
|
||||
}
|
||||
|
||||
// Session session config when create new session
|
||||
type Session struct {
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
}
|
||||
|
||||
// Open initialize db session based on dialector
|
||||
func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
||||
return &DB{
|
||||
Config: config,
|
||||
Dialector: dialector,
|
||||
clone: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Session create new db session
|
||||
func (db *DB) Session(config *Session) *DB {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
txConfig = *tx.Config
|
||||
)
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Context = config.Context
|
||||
}
|
||||
|
||||
if config.Logger != nil {
|
||||
txConfig.Logger = config.Logger
|
||||
}
|
||||
|
||||
if config.NowFunc != nil {
|
||||
txConfig.NowFunc = config.NowFunc
|
||||
}
|
||||
|
||||
tx.Config = &txConfig
|
||||
tx.clone = true
|
||||
return tx
|
||||
}
|
||||
|
||||
// WithContext change current instance db's context to ctx
|
||||
func (db *DB) WithContext(ctx context.Context) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Context = ctx
|
||||
return tx
|
||||
return db.Session(&Session{Context: ctx})
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (db *DB) Debug() (tx *DB) {
|
||||
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
|
||||
}
|
||||
|
||||
func (db *DB) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set store value with key into current db instance's context
|
||||
|
@ -80,35 +108,22 @@ func (db *DB) Get(key string) (interface{}, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (db *DB) Close() *DB {
|
||||
// TODO
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) getInstance() *DB {
|
||||
// db.Result.Statement == nil means root DB
|
||||
if db.Result.Statement == nil {
|
||||
if db.clone {
|
||||
ctx := db.Instance.Context
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
return &DB{
|
||||
Config: db.Config,
|
||||
Dialector: db.Dialector,
|
||||
Context: context.Background(),
|
||||
Result: Result{
|
||||
Statement: &Statement{DB: db, Clauses: map[string][]clause.Condition{}},
|
||||
Instance: Instance{
|
||||
Context: ctx,
|
||||
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (db *DB) Debug() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
// Session start session mode
|
||||
func (db *DB) Session() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,5 +1,14 @@
|
|||
package logger
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
Info LogLevel = iota + 1
|
||||
Warn
|
||||
Error
|
||||
)
|
||||
|
||||
// Interface logger interface
|
||||
type Interface interface {
|
||||
LogMode(LogLevel) Interface
|
||||
}
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embeded into your model or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `gorm:"index"`
|
||||
}
|
107
statement.go
107
statement.go
|
@ -1,7 +1,6 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
|
@ -13,25 +12,43 @@ import (
|
|||
"github.com/jinzhu/gorm/clause"
|
||||
)
|
||||
|
||||
// Statement statement
|
||||
type Statement struct {
|
||||
Model interface{}
|
||||
Dest interface{}
|
||||
Table string
|
||||
Clauses map[string][]clause.Condition
|
||||
Settings sync.Map
|
||||
Context context.Context
|
||||
DB *DB
|
||||
StatementBuilder
|
||||
// Instance db instance
|
||||
type Instance struct {
|
||||
Error error
|
||||
RowsAffected int64
|
||||
Context context.Context
|
||||
Statement *Statement
|
||||
}
|
||||
|
||||
// StatementBuilder statement builder
|
||||
type StatementBuilder struct {
|
||||
SQL bytes.Buffer
|
||||
// AddError add error to instance
|
||||
func (inst Instance) AddError(err error) {
|
||||
if inst.Error == nil {
|
||||
inst.Error = err
|
||||
} else {
|
||||
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Statement statement
|
||||
type Statement struct {
|
||||
Table string
|
||||
Model interface{}
|
||||
Dest interface{}
|
||||
Clauses map[string]clause.Clause
|
||||
Settings sync.Map
|
||||
DB *DB
|
||||
|
||||
// SQL Builder
|
||||
SQL strings.Builder
|
||||
Vars []interface{}
|
||||
NamedVars []sql.NamedArg
|
||||
}
|
||||
|
||||
// StatementOptimizer statement optimizer interface
|
||||
type StatementOptimizer interface {
|
||||
OptimizeStatement(Statement)
|
||||
}
|
||||
|
||||
// Write write string
|
||||
func (stmt Statement) Write(sql ...string) (err error) {
|
||||
for _, s := range sql {
|
||||
|
@ -40,12 +57,23 @@ func (stmt Statement) Write(sql ...string) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// Write write string
|
||||
func (stmt Statement) WriteByte(c byte) (err error) {
|
||||
return stmt.SQL.WriteByte(c)
|
||||
}
|
||||
|
||||
// WriteQuoted write quoted field
|
||||
func (stmt Statement) WriteQuoted(field interface{}) (err error) {
|
||||
_, err = stmt.SQL.WriteString(stmt.Quote(field))
|
||||
return
|
||||
}
|
||||
|
||||
// Quote returns quoted value
|
||||
func (stmt Statement) Quote(field interface{}) (str string) {
|
||||
// FIXME
|
||||
return fmt.Sprint(field)
|
||||
}
|
||||
|
||||
// Write write string
|
||||
func (stmt Statement) AddVar(vars ...interface{}) string {
|
||||
var placeholders strings.Builder
|
||||
|
@ -73,23 +101,34 @@ func (stmt Statement) AddVar(vars ...interface{}) string {
|
|||
return placeholders.String()
|
||||
}
|
||||
|
||||
// Quote returns quoted value
|
||||
func (stmt Statement) Quote(field interface{}) (str string) {
|
||||
return fmt.Sprint(field)
|
||||
}
|
||||
|
||||
// AddClause add clause
|
||||
func (s Statement) AddClause(clause clause.Interface) {
|
||||
s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause)
|
||||
func (stmt Statement) AddClause(v clause.Interface) {
|
||||
if optimizer, ok := v.(StatementOptimizer); ok {
|
||||
optimizer.OptimizeStatement(stmt)
|
||||
}
|
||||
|
||||
c, _ := stmt.Clauses[v.Name()]
|
||||
if namer, ok := v.(clause.OverrideNameInterface); ok {
|
||||
c.Name = namer.OverrideName()
|
||||
} else {
|
||||
c.Name = v.Name()
|
||||
}
|
||||
|
||||
if c.Expression != nil {
|
||||
v.MergeExpression(c.Expression)
|
||||
}
|
||||
|
||||
c.Expression = v
|
||||
stmt.Clauses[v.Name()] = c
|
||||
}
|
||||
|
||||
// BuildCondtions build conditions
|
||||
func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (conditions []clause.Condition) {
|
||||
// BuildCondtion build condition
|
||||
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
|
||||
if sql, ok := query.(string); ok {
|
||||
if i, err := strconv.Atoi(sql); err != nil {
|
||||
query = i
|
||||
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
||||
return []clause.Condition{clause.Raw{SQL: sql, Values: args}}
|
||||
return []clause.Expression{clause.String{SQL: sql, Values: args}}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -100,12 +139,12 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi
|
|||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case clause.Builder:
|
||||
case clause.Expression:
|
||||
conditions = append(conditions, v)
|
||||
case *DB:
|
||||
if v.Statement == nil {
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
conditions = append(conditions, cs...)
|
||||
conditions = append(conditions, cs.Expression)
|
||||
}
|
||||
}
|
||||
case map[interface{}]interface{}:
|
||||
|
@ -135,8 +174,22 @@ func (s Statement) BuildCondtions(query interface{}, args ...interface{}) (condi
|
|||
if len(conditions) == 0 {
|
||||
conditions = append(conditions, clause.ID{Value: args})
|
||||
}
|
||||
|
||||
return conditions
|
||||
}
|
||||
|
||||
func (s Statement) AddError(err error) {
|
||||
// Build build sql with clauses names
|
||||
func (stmt Statement) Build(clauses ...string) {
|
||||
var includeSpace bool
|
||||
|
||||
for _, name := range clauses {
|
||||
if c, ok := stmt.Clauses[name]; ok {
|
||||
if includeSpace {
|
||||
stmt.WriteByte(' ')
|
||||
}
|
||||
|
||||
includeSpace = true
|
||||
c.Build(stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue