forked from mirror/gorm
Work on create callbacks
This commit is contained in:
parent
728c0d4470
commit
d52ee0aa44
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
)
|
||||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
|
@ -17,8 +18,14 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
|
||||
db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Table: db.Statement.Table},
|
||||
})
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
fmt.Println(err)
|
||||
fmt.Println(result)
|
||||
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
||||
}
|
||||
|
||||
|
|
|
@ -55,7 +55,9 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
|||
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)})
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
AndConditions: tx.Statement.BuildCondtion(query, args...),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -63,7 +65,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
|||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))},
|
||||
AndConditions: []clause.Expression{
|
||||
clause.NotConditions(tx.Statement.BuildCondtion(query, args...)),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -72,7 +76,9 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
|||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)},
|
||||
ORConditions: []clause.ORConditions{
|
||||
tx.Statement.BuildCondtion(query, args...),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
package clause
|
||||
|
||||
type Insert struct {
|
||||
Table Table
|
||||
Priority string
|
||||
}
|
||||
|
||||
// Name insert clause name
|
||||
func (insert Insert) Name() string {
|
||||
return "INSERT"
|
||||
}
|
||||
|
||||
// Build build insert clause
|
||||
func (insert Insert) Build(builder Builder) {
|
||||
if insert.Priority != "" {
|
||||
builder.Write(insert.Priority)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
builder.Write("INTO ")
|
||||
builder.WriteQuoted(insert.Table)
|
||||
}
|
||||
|
||||
// MergeExpression merge insert clauses
|
||||
func (insert Insert) MergeExpression(expr Expression) {
|
||||
if v, ok := expr.(Insert); ok {
|
||||
if insert.Priority == "" {
|
||||
insert.Priority = v.Priority
|
||||
}
|
||||
if insert.Table.Table == "" {
|
||||
insert.Table = v.Table
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package clause
|
||||
|
||||
type Values struct {
|
||||
Columns []Column
|
||||
Values [][]interface{}
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (Values) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Build build from clause
|
||||
func (values Values) Build(builder Builder) {
|
||||
if len(values.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range values.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
|
||||
builder.Write(" VALUES ")
|
||||
|
||||
for idx, value := range values.Values {
|
||||
builder.WriteByte('(')
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.Write(builder.AddVar(value...))
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
} else {
|
||||
builder.Write("DEFAULT VALUES")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
|
||||
db.DB, err = sql.Open("postgres", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
func (Dialector) Migrator() gorm.Migrator {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
|
||||
return "?"
|
||||
}
|
|
@ -1,29 +1,33 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{}
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (Dialector) Initialize(db *gorm.DB) error {
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
|
||||
return nil
|
||||
db.DB, err = sql.Open("sqlite3", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
func (Dialector) Migrator() gorm.Migrator {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
|
||||
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
|
||||
return "?"
|
||||
}
|
||||
|
|
|
@ -4,7 +4,16 @@ import (
|
|||
"database/sql"
|
||||
)
|
||||
|
||||
func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
|
||||
// Create insert the value into database
|
||||
func (db *DB) Create(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
@ -36,32 +45,12 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create insert the value into database
|
||||
func (db *DB) Create(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
@ -78,7 +67,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) {
|
||||
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
@ -88,16 +77,6 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||
func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
@ -119,6 +98,29 @@ func (db *DB) Association(column string) *Association {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Count(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
panicked := true
|
||||
tx := db.Begin(opts...)
|
||||
|
|
6
go.mod
6
go.mod
|
@ -2,4 +2,8 @@ module github.com/jinzhu/gorm
|
|||
|
||||
go 1.13
|
||||
|
||||
require github.com/jinzhu/inflection v1.0.0
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/lib/pq v1.3.0
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||
)
|
||||
|
|
20
gorm.go
20
gorm.go
|
@ -28,10 +28,11 @@ type DB struct {
|
|||
*Config
|
||||
Dialector
|
||||
Instance
|
||||
DB CommonDB
|
||||
clone bool
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
DB CommonDB
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
clone bool
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
}
|
||||
|
||||
// Session session config when create session with Session() method
|
||||
|
@ -140,11 +141,12 @@ func (db *DB) getInstance() *DB {
|
|||
Context: ctx,
|
||||
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
|
||||
},
|
||||
Config: db.Config,
|
||||
Dialector: db.Dialector,
|
||||
DB: db.DB,
|
||||
callbacks: db.callbacks,
|
||||
cacheStore: db.cacheStore,
|
||||
Config: db.Config,
|
||||
Dialector: db.Dialector,
|
||||
ClauseBuilders: db.ClauseBuilders,
|
||||
DB: db.DB,
|
||||
callbacks: db.callbacks,
|
||||
cacheStore: db.cacheStore,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
type Dialector interface {
|
||||
Initialize(*DB) error
|
||||
Migrator() Migrator
|
||||
BindVar(stmt Statement, v interface{}) string
|
||||
BindVar(stmt *Statement, v interface{}) string
|
||||
}
|
||||
|
||||
// CommonDB common db interface
|
||||
|
|
49
statement.go
49
statement.go
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -21,7 +22,7 @@ type Instance struct {
|
|||
Statement *Statement
|
||||
}
|
||||
|
||||
func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) {
|
||||
func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
|
||||
if len(clauses) > 0 {
|
||||
instance.Statement.Build(clauses...)
|
||||
}
|
||||
|
@ -29,7 +30,7 @@ func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) {
|
|||
}
|
||||
|
||||
// AddError add error to instance
|
||||
func (inst Instance) AddError(err error) {
|
||||
func (inst *Instance) AddError(err error) {
|
||||
if inst.Error == nil {
|
||||
inst.Error = err
|
||||
} else {
|
||||
|
@ -55,11 +56,11 @@ type Statement struct {
|
|||
|
||||
// StatementOptimizer statement optimizer interface
|
||||
type StatementOptimizer interface {
|
||||
OptimizeStatement(Statement)
|
||||
OptimizeStatement(*Statement)
|
||||
}
|
||||
|
||||
// Write write string
|
||||
func (stmt Statement) Write(sql ...string) (err error) {
|
||||
func (stmt *Statement) Write(sql ...string) (err error) {
|
||||
for _, s := range sql {
|
||||
_, err = stmt.SQL.WriteString(s)
|
||||
}
|
||||
|
@ -67,12 +68,12 @@ func (stmt Statement) Write(sql ...string) (err error) {
|
|||
}
|
||||
|
||||
// Write write string
|
||||
func (stmt Statement) WriteByte(c byte) (err error) {
|
||||
func (stmt *Statement) WriteByte(c byte) (err error) {
|
||||
return stmt.SQL.WriteByte(c)
|
||||
}
|
||||
|
||||
// WriteQuoted write quoted field
|
||||
func (stmt Statement) WriteQuoted(field interface{}) (err error) {
|
||||
func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
|
||||
_, err = stmt.SQL.WriteString(stmt.Quote(field))
|
||||
return
|
||||
}
|
||||
|
@ -107,7 +108,7 @@ func (stmt Statement) Quote(field interface{}) string {
|
|||
}
|
||||
|
||||
// Write write string
|
||||
func (stmt Statement) AddVar(vars ...interface{}) string {
|
||||
func (stmt *Statement) AddVar(vars ...interface{}) string {
|
||||
var placeholders strings.Builder
|
||||
for idx, v := range vars {
|
||||
if idx > 0 {
|
||||
|
@ -134,7 +135,7 @@ func (stmt Statement) AddVar(vars ...interface{}) string {
|
|||
}
|
||||
|
||||
// AddClause add clause
|
||||
func (stmt Statement) AddClause(v clause.Interface) {
|
||||
func (stmt *Statement) AddClause(v clause.Interface) {
|
||||
if optimizer, ok := v.(StatementOptimizer); ok {
|
||||
optimizer.OptimizeStatement(stmt)
|
||||
}
|
||||
|
@ -154,6 +155,30 @@ func (stmt Statement) AddClause(v clause.Interface) {
|
|||
stmt.Clauses[v.Name()] = c
|
||||
}
|
||||
|
||||
// AddClauseIfNotExists add clause if not exists
|
||||
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
||||
if optimizer, ok := v.(StatementOptimizer); ok {
|
||||
optimizer.OptimizeStatement(stmt)
|
||||
}
|
||||
|
||||
log.Println(v.Name())
|
||||
if c, ok := stmt.Clauses[v.Name()]; !ok {
|
||||
if namer, ok := v.(clause.OverrideNameInterface); ok {
|
||||
c.Name = namer.OverrideName()
|
||||
} else {
|
||||
c.Name = v.Name()
|
||||
}
|
||||
|
||||
if c.Expression != nil {
|
||||
v.MergeExpression(c.Expression)
|
||||
}
|
||||
|
||||
c.Expression = v
|
||||
stmt.Clauses[v.Name()] = c
|
||||
log.Println(stmt.Clauses[v.Name()])
|
||||
}
|
||||
}
|
||||
|
||||
// BuildCondtion build condition
|
||||
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
|
||||
if sql, ok := query.(string); ok {
|
||||
|
@ -211,7 +236,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
|
|||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
func (stmt Statement) Build(clauses ...string) {
|
||||
func (stmt *Statement) Build(clauses ...string) {
|
||||
var firstClauseWritten bool
|
||||
|
||||
for _, name := range clauses {
|
||||
|
@ -221,7 +246,11 @@ func (stmt Statement) Build(clauses ...string) {
|
|||
}
|
||||
|
||||
firstClauseWritten = true
|
||||
c.Build(stmt)
|
||||
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
|
||||
b.Build(c, stmt)
|
||||
} else {
|
||||
c.Build(stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO handle named vars
|
||||
|
|
Loading…
Reference in New Issue