Work on create callbacks

This commit is contained in:
Jinzhu 2020-02-03 10:40:03 +08:00
parent 728c0d4470
commit d52ee0aa44
11 changed files with 224 additions and 64 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
) )
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
@ -17,8 +18,14 @@ func SaveBeforeAssociations(db *gorm.DB) {
} }
func Create(db *gorm.DB) { func Create(db *gorm.DB) {
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") db.Statement.AddClauseIfNotExists(clause.Insert{
db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) Table: clause.Table{Table: db.Statement.Table},
})
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err)
fmt.Println(result)
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
} }

View File

@ -55,7 +55,9 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)}) tx.Statement.AddClause(clause.Where{
AndConditions: tx.Statement.BuildCondtion(query, args...),
})
return return
} }
@ -63,7 +65,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{ tx.Statement.AddClause(clause.Where{
AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))}, AndConditions: []clause.Expression{
clause.NotConditions(tx.Statement.BuildCondtion(query, args...)),
},
}) })
return return
} }
@ -72,7 +76,9 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{ tx.Statement.AddClause(clause.Where{
ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)}, ORConditions: []clause.ORConditions{
tx.Statement.BuildCondtion(query, args...),
},
}) })
return return
} }

34
clause/insert.go Normal file
View File

@ -0,0 +1,34 @@
package clause
type Insert struct {
Table Table
Priority string
}
// Name insert clause name
func (insert Insert) Name() string {
return "INSERT"
}
// Build build insert clause
func (insert Insert) Build(builder Builder) {
if insert.Priority != "" {
builder.Write(insert.Priority)
builder.WriteByte(' ')
}
builder.Write("INTO ")
builder.WriteQuoted(insert.Table)
}
// MergeExpression merge insert clauses
func (insert Insert) MergeExpression(expr Expression) {
if v, ok := expr.(Insert); ok {
if insert.Priority == "" {
insert.Priority = v.Priority
}
if insert.Table.Table == "" {
insert.Table = v.Table
}
}
}

39
clause/value.go Normal file
View File

@ -0,0 +1,39 @@
package clause
type Values struct {
Columns []Column
Values [][]interface{}
}
// Name from clause name
func (Values) Name() string {
return ""
}
// Build build from clause
func (values Values) Build(builder Builder) {
if len(values.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range values.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteByte(')')
builder.Write(" VALUES ")
for idx, value := range values.Values {
builder.WriteByte('(')
if idx > 0 {
builder.WriteByte(',')
}
builder.Write(builder.AddVar(value...))
builder.WriteByte(')')
}
} else {
builder.Write("DEFAULT VALUES")
}
}

View File

@ -0,0 +1,33 @@
package postgres
import (
"database/sql"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
_ "github.com/lib/pq"
)
type Dialector struct {
DSN string
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)
db.DB, err = sql.Open("postgres", dialector.DSN)
return
}
func (Dialector) Migrator() gorm.Migrator {
return nil
}
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}

View File

@ -1,29 +1,33 @@
package sqlite package sqlite
import ( import (
"database/sql"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
type Dialector struct { type Dialector struct {
DSN string
} }
func Open(dsn string) gorm.Dialector { func Open(dsn string) gorm.Dialector {
return &Dialector{} return &Dialector{DSN: dsn}
} }
func (Dialector) Initialize(db *gorm.DB) error { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db)
return nil db.DB, err = sql.Open("sqlite3", dialector.DSN)
return
} }
func (Dialector) Migrator() gorm.Migrator { func (Dialector) Migrator() gorm.Migrator {
return nil return nil
} }
func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?" return "?"
} }

View File

@ -4,7 +4,16 @@ import (
"database/sql" "database/sql"
) )
func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { // Create insert the value into database
func (db *DB) Create(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx)
return
}
// Save update value in database, if the value doesn't have primary key, will insert it
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
@ -36,32 +45,12 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
return return
} }
func (db *DB) Row() *sql.Row { func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
return nil
}
func (db *DB) Rows() (*sql.Rows, error) {
return nil, nil
}
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
return nil
}
// Create insert the value into database
func (db *DB) Create(value interface{}) (tx *DB) {
tx = db.getInstance()
return
}
// Save update value in database, if the value doesn't have primary key, will insert it
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
@ -78,7 +67,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) {
return return
} }
func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
@ -88,16 +77,6 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return return
} }
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -119,6 +98,29 @@ func (db *DB) Association(column string) *Association {
return nil return nil
} }
func (db *DB) Count(value interface{}) (tx *DB) {
tx = db.getInstance()
return
}
func (db *DB) Row() *sql.Row {
return nil
}
func (db *DB) Rows() (*sql.Rows, error) {
return nil, nil
}
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance()
return
}
func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
return nil
}
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
tx := db.Begin(opts...) tx := db.Begin(opts...)

6
go.mod
View File

@ -2,4 +2,8 @@ module github.com/jinzhu/gorm
go 1.13 go 1.13
require github.com/jinzhu/inflection v1.0.0 require (
github.com/jinzhu/inflection v1.0.0
github.com/lib/pq v1.3.0
github.com/mattn/go-sqlite3 v2.0.3+incompatible
)

20
gorm.go
View File

@ -28,10 +28,11 @@ type DB struct {
*Config *Config
Dialector Dialector
Instance Instance
DB CommonDB DB CommonDB
clone bool ClauseBuilders map[string]clause.ClauseBuilder
callbacks *callbacks clone bool
cacheStore *sync.Map callbacks *callbacks
cacheStore *sync.Map
} }
// Session session config when create session with Session() method // Session session config when create session with Session() method
@ -140,11 +141,12 @@ func (db *DB) getInstance() *DB {
Context: ctx, Context: ctx,
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
}, },
Config: db.Config, Config: db.Config,
Dialector: db.Dialector, Dialector: db.Dialector,
DB: db.DB, ClauseBuilders: db.ClauseBuilders,
callbacks: db.callbacks, DB: db.DB,
cacheStore: db.cacheStore, callbacks: db.callbacks,
cacheStore: db.cacheStore,
} }
} }

View File

@ -9,7 +9,7 @@ import (
type Dialector interface { type Dialector interface {
Initialize(*DB) error Initialize(*DB) error
Migrator() Migrator Migrator() Migrator
BindVar(stmt Statement, v interface{}) string BindVar(stmt *Statement, v interface{}) string
} }
// CommonDB common db interface // CommonDB common db interface

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"log"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -21,7 +22,7 @@ type Instance struct {
Statement *Statement Statement *Statement
} }
func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
if len(clauses) > 0 { if len(clauses) > 0 {
instance.Statement.Build(clauses...) instance.Statement.Build(clauses...)
} }
@ -29,7 +30,7 @@ func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) {
} }
// AddError add error to instance // AddError add error to instance
func (inst Instance) AddError(err error) { func (inst *Instance) AddError(err error) {
if inst.Error == nil { if inst.Error == nil {
inst.Error = err inst.Error = err
} else { } else {
@ -55,11 +56,11 @@ type Statement struct {
// StatementOptimizer statement optimizer interface // StatementOptimizer statement optimizer interface
type StatementOptimizer interface { type StatementOptimizer interface {
OptimizeStatement(Statement) OptimizeStatement(*Statement)
} }
// Write write string // Write write string
func (stmt Statement) Write(sql ...string) (err error) { func (stmt *Statement) Write(sql ...string) (err error) {
for _, s := range sql { for _, s := range sql {
_, err = stmt.SQL.WriteString(s) _, err = stmt.SQL.WriteString(s)
} }
@ -67,12 +68,12 @@ func (stmt Statement) Write(sql ...string) (err error) {
} }
// Write write string // Write write string
func (stmt Statement) WriteByte(c byte) (err error) { func (stmt *Statement) WriteByte(c byte) (err error) {
return stmt.SQL.WriteByte(c) return stmt.SQL.WriteByte(c)
} }
// WriteQuoted write quoted field // WriteQuoted write quoted field
func (stmt Statement) WriteQuoted(field interface{}) (err error) { func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
_, err = stmt.SQL.WriteString(stmt.Quote(field)) _, err = stmt.SQL.WriteString(stmt.Quote(field))
return return
} }
@ -107,7 +108,7 @@ func (stmt Statement) Quote(field interface{}) string {
} }
// Write write string // Write write string
func (stmt Statement) AddVar(vars ...interface{}) string { func (stmt *Statement) AddVar(vars ...interface{}) string {
var placeholders strings.Builder var placeholders strings.Builder
for idx, v := range vars { for idx, v := range vars {
if idx > 0 { if idx > 0 {
@ -134,7 +135,7 @@ func (stmt Statement) AddVar(vars ...interface{}) string {
} }
// AddClause add clause // AddClause add clause
func (stmt Statement) AddClause(v clause.Interface) { func (stmt *Statement) AddClause(v clause.Interface) {
if optimizer, ok := v.(StatementOptimizer); ok { if optimizer, ok := v.(StatementOptimizer); ok {
optimizer.OptimizeStatement(stmt) optimizer.OptimizeStatement(stmt)
} }
@ -154,6 +155,30 @@ func (stmt Statement) AddClause(v clause.Interface) {
stmt.Clauses[v.Name()] = c stmt.Clauses[v.Name()] = c
} }
// AddClauseIfNotExists add clause if not exists
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
if optimizer, ok := v.(StatementOptimizer); ok {
optimizer.OptimizeStatement(stmt)
}
log.Println(v.Name())
if c, ok := stmt.Clauses[v.Name()]; !ok {
if namer, ok := v.(clause.OverrideNameInterface); ok {
c.Name = namer.OverrideName()
} else {
c.Name = v.Name()
}
if c.Expression != nil {
v.MergeExpression(c.Expression)
}
c.Expression = v
stmt.Clauses[v.Name()] = c
log.Println(stmt.Clauses[v.Name()])
}
}
// BuildCondtion build condition // BuildCondtion build condition
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
if sql, ok := query.(string); ok { if sql, ok := query.(string); ok {
@ -211,7 +236,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
} }
// Build build sql with clauses names // Build build sql with clauses names
func (stmt Statement) Build(clauses ...string) { func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool var firstClauseWritten bool
for _, name := range clauses { for _, name := range clauses {
@ -221,7 +246,11 @@ func (stmt Statement) Build(clauses ...string) {
} }
firstClauseWritten = true firstClauseWritten = true
c.Build(stmt) if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b.Build(c, stmt)
} else {
c.Build(stmt)
}
} }
} }
// TODO handle named vars // TODO handle named vars