Use *gorm.DB to replace gorm.DB

This commit is contained in:
Jinzhu 2020-03-09 20:37:01 +08:00
parent 2a0c3e39f2
commit 9e8a4db36b
7 changed files with 84 additions and 88 deletions

View File

@ -90,7 +90,6 @@ func (p *processor) Execute(db *DB) {
} }
if stmt := db.Statement; stmt != nil { if stmt := db.Statement; stmt != nil {
db.Error = stmt.Error
db.RowsAffected = stmt.RowsAffected db.RowsAffected = stmt.RowsAffected
db.Logger.Trace(curTime, func() (string, int64) { db.Logger.Trace(curTime, func() (string, int64) {

View File

@ -13,14 +13,14 @@ import (
// db.Model(&User{}).Update("name", "hello") // db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello") // db.Model(&user).Update("name", "hello")
func (db DB) Model(value interface{}) (tx DB) { func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Model = value tx.Statement.Model = value
return return
} }
// Clauses Add clauses // Clauses Add clauses
func (db DB) Clauses(conds ...clause.Expression) (tx DB) { func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
var whereConds []interface{} var whereConds []interface{}
@ -39,14 +39,14 @@ func (db DB) Clauses(conds ...clause.Expression) (tx DB) {
} }
// Table specify the table you would like to run db operations // Table specify the table you would like to run db operations
func (db DB) Table(name string) (tx DB) { func (db *DB) Table(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Table = name tx.Statement.Table = name
return return
} }
// Select specify fields that you want when querying, creating, updating // Select specify fields that you want when querying, creating, updating
func (db DB) Select(query interface{}, args ...interface{}) (tx DB) { func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
switch v := query.(type) { switch v := query.(type) {
@ -97,7 +97,7 @@ func (db DB) Select(query interface{}, args ...interface{}) (tx DB) {
} }
// Omit specify fields that you want to ignore when creating, updating and querying // Omit specify fields that you want to ignore when creating, updating and querying
func (db DB) Omit(columns ...string) (tx DB) { func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
@ -108,21 +108,21 @@ func (db DB) Omit(columns ...string) (tx DB) {
return return
} }
func (db DB) Where(query interface{}, args ...interface{}) (tx DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)})
return return
} }
// Not add NOT condition // Not add NOT condition
func (db DB) Not(query interface{}, args ...interface{}) (tx DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}})
return return
} }
// Or add OR conditions // Or add OR conditions
func (db DB) Or(query interface{}, args ...interface{}) (tx DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}}) tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}})
return return
@ -131,13 +131,13 @@ func (db DB) Or(query interface{}, args ...interface{}) (tx DB) {
// Joins specify Joins conditions // Joins specify Joins conditions
// db.Joins("Account").Find(&user) // db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (db DB) Joins(query string, args ...interface{}) (tx DB) { func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
// Group specify the group method on the find // Group specify the group method on the find
func (db DB) Group(name string) (tx DB) { func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{ tx.Statement.AddClause(clause.GroupBy{
Columns: []clause.Column{{Name: name}}, Columns: []clause.Column{{Name: name}},
@ -146,7 +146,7 @@ func (db DB) Group(name string) (tx DB) {
} }
// Having specify HAVING conditions for GROUP BY // Having specify HAVING conditions for GROUP BY
func (db DB) Having(query interface{}, args ...interface{}) (tx DB) { func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{ tx.Statement.AddClause(clause.GroupBy{
Having: tx.Statement.BuildCondtion(query, args...), Having: tx.Statement.BuildCondtion(query, args...),
@ -157,7 +157,7 @@ func (db DB) Having(query interface{}, args ...interface{}) (tx DB) {
// Order specify order when retrieve records from database // Order specify order when retrieve records from database
// db.Order("name DESC") // db.Order("name DESC")
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
func (db DB) Order(value interface{}) (tx DB) { func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
switch v := value.(type) { switch v := value.(type) {
@ -176,14 +176,14 @@ func (db DB) Order(value interface{}) (tx DB) {
} }
// Limit specify the number of records to be retrieved // Limit specify the number of records to be retrieved
func (db DB) Limit(limit int) (tx DB) { func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit}) tx.Statement.AddClause(clause.Limit{Limit: limit})
return return
} }
// Offset specify the number of records to skip before starting to return the records // Offset specify the number of records to skip before starting to return the records
func (db DB) Offset(offset int) (tx DB) { func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset}) tx.Statement.AddClause(clause.Limit{Offset: offset})
return return
@ -201,7 +201,7 @@ func (db DB) Offset(offset int) (tx DB) {
// } // }
// //
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db DB) Scopes(funcs ...func(DB) DB) DB { func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB {
for _, f := range funcs { for _, f := range funcs {
db = f(db) db = f(db)
} }
@ -210,27 +210,27 @@ func (db DB) Scopes(funcs ...func(DB) DB) DB {
// Preload preload associations with given conditions // Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db DB) Preload(column string, conditions ...interface{}) (tx DB) { func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db DB) Assign(attrs ...interface{}) (tx DB) { func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db DB) Attrs(attrs ...interface{}) (tx DB) { func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db DB) Unscoped() (tx DB) { func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db DB) Raw(sql string, values ...interface{}) (tx DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)

View File

@ -12,7 +12,7 @@ import (
) )
var ( var (
DB gorm.DB DB *gorm.DB
err error err error
) )
@ -23,9 +23,9 @@ func init() {
} }
func TestCURD(t *testing.T) { func TestCURD(t *testing.T) {
tests.RunTestsSuit(t, &DB) tests.RunTestsSuit(t, DB)
} }
func TestMigrate(t *testing.T) { func TestMigrate(t *testing.T) {
tests.TestMigrate(t, &DB) tests.TestMigrate(t, DB)
} }

View File

@ -9,15 +9,15 @@ import (
) )
// Create insert the value into database // Create insert the value into database
func (db DB) Create(value interface{}) (tx DB) { func (db *DB) Create(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
tx.callbacks.Create().Execute(&tx) tx.callbacks.Create().Execute(tx)
return return
} }
// Save update value in database, if the value doesn't have primary key, will insert it // Save update value in database, if the value doesn't have primary key, will insert it
func (db DB) Save(value interface{}) (tx DB) { func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
@ -26,7 +26,7 @@ func (db DB) Save(value interface{}) (tx DB) {
reflectValue := reflect.ValueOf(value) reflectValue := reflect.ValueOf(value)
for idx, pf := range tx.Statement.Schema.PrimaryFields { for idx, pf := range tx.Statement.Schema.PrimaryFields {
if pv, isZero := pf.ValueOf(reflectValue); isZero { if pv, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(&tx) tx.callbacks.Create().Execute(tx)
where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv}
return return
} }
@ -38,12 +38,12 @@ func (db DB) Save(value interface{}) (tx DB) {
if len(tx.Statement.Selects) == 0 { if len(tx.Statement.Selects) == 0 {
tx.Statement.Selects = []string{"*"} tx.Statement.Selects = []string{"*"}
} }
tx.callbacks.Update().Execute(&tx) tx.callbacks.Update().Execute(tx)
return return
} }
// First find first record that match given conditions, order by primary key // First find first record that match given conditions, order by primary key
func (db DB) First(out interface{}, conds ...interface{}) (tx DB) { func (db *DB) First(out interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
@ -52,24 +52,24 @@ func (db DB) First(out interface{}, conds ...interface{}) (tx DB) {
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = out tx.Statement.Dest = out
tx.callbacks.Query().Execute(&tx) tx.callbacks.Query().Execute(tx)
return return
} }
// Take return a record that match given conditions, the order will depend on the database implementation // Take return a record that match given conditions, the order will depend on the database implementation
func (db DB) Take(out interface{}, conds ...interface{}) (tx DB) { func (db *DB) Take(out interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1) tx = db.getInstance().Limit(1)
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = out tx.Statement.Dest = out
tx.callbacks.Query().Execute(&tx) tx.callbacks.Query().Execute(tx)
return return
} }
// Last find last record that match given conditions, order by primary key // Last find last record that match given conditions, order by primary key
func (db DB) Last(out interface{}, conds ...interface{}) (tx DB) { func (db *DB) Last(out interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true, Desc: true,
@ -79,101 +79,101 @@ func (db DB) Last(out interface{}, conds ...interface{}) (tx DB) {
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = out tx.Statement.Dest = out
tx.callbacks.Query().Execute(&tx) tx.callbacks.Query().Execute(tx)
return return
} }
// Find find records that match given conditions // Find find records that match given conditions
func (db DB) Find(out interface{}, conds ...interface{}) (tx DB) { func (db *DB) Find(out interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
} }
tx.Statement.Dest = out tx.Statement.Dest = out
tx.callbacks.Query().Execute(&tx) tx.callbacks.Query().Execute(tx)
return return
} }
func (db DB) FirstOrInit(out interface{}, where ...interface{}) (tx DB) { func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db DB) FirstOrCreate(out interface{}, where ...interface{}) (tx DB) { func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (db DB) Update(column string, value interface{}) (tx DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
tx.callbacks.Update().Execute(&tx) tx.callbacks.Update().Execute(tx)
return return
} }
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (db DB) Updates(values interface{}) (tx DB) { func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
tx.callbacks.Update().Execute(&tx) tx.callbacks.Update().Execute(tx)
return return
} }
func (db DB) UpdateColumn(column string, value interface{}) (tx DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
tx.callbacks.Update().Execute(&tx) tx.callbacks.Update().Execute(tx)
return return
} }
func (db DB) UpdateColumns(values interface{}) (tx DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
tx.callbacks.Update().Execute(&tx) tx.callbacks.Update().Execute(tx)
return return
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (db DB) Delete(value interface{}, conds ...interface{}) (tx DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)}) tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
} }
tx.Statement.Dest = value tx.Statement.Dest = value
tx.callbacks.Delete().Execute(&tx) tx.callbacks.Delete().Execute(tx)
return return
} }
func (db DB) Count(value interface{}) (tx DB) { func (db *DB) Count(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
tx := db.getInstance() tx := db.getInstance()
tx.callbacks.Row().Execute(&tx) tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Row) return tx.Statement.Dest.(*sql.Row)
} }
func (db DB) Rows() (*sql.Rows, error) { func (db *DB) Rows() (*sql.Rows, error) {
tx := db.Set("rows", true) tx := db.Set("rows", true)
tx.callbacks.Row().Execute(&tx) tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Rows), tx.Error return tx.Statement.Dest.(*sql.Rows), tx.Error
} }
// Scan scan value to a struct // Scan scan value to a struct
func (db DB) Scan(dest interface{}) (tx DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
return return
} }
func (db DB) ScanRows(rows *sql.Rows, result interface{}) error { func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
return nil return nil
} }
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
func (db DB) Transaction(fc func(tx DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
tx := db.Begin(opts...) tx := db.Begin(opts...)
defer func() { defer func() {
@ -194,7 +194,7 @@ func (db DB) Transaction(fc func(tx DB) error, opts ...*sql.TxOptions) (err erro
} }
// Begin begins a transaction // Begin begins a transaction
func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) { func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
var opt *sql.TxOptions var opt *sql.TxOptions
@ -213,7 +213,7 @@ func (db DB) Begin(opts ...*sql.TxOptions) (tx DB) {
} }
// Commit commit a transaction // Commit commit a transaction
func (db DB) Commit() DB { func (db *DB) Commit() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Commit()) db.AddError(comminter.Commit())
} else { } else {
@ -223,7 +223,7 @@ func (db DB) Commit() DB {
} }
// Rollback rollback a transaction // Rollback rollback a transaction
func (db DB) Rollback() DB { func (db *DB) Rollback() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Rollback()) db.AddError(comminter.Rollback())
} else { } else {
@ -233,10 +233,10 @@ func (db DB) Rollback() DB {
} }
// Exec execute raw sql // Exec execute raw sql
func (db DB) Exec(sql string, values ...interface{}) (tx DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
tx.callbacks.Raw().Execute(&tx) tx.callbacks.Raw().Execute(tx)
return return
} }

35
gorm.go
View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"time" "time"
@ -51,7 +52,7 @@ type Session struct {
} }
// Open initialize db session based on dialector // Open initialize db session based on dialector
func Open(dialector Dialector, config *Config) (db DB, err error) { func Open(dialector Dialector, config *Config) (db *DB, err error) {
if config == nil { if config == nil {
config = &Config{} config = &Config{}
} }
@ -87,21 +88,21 @@ func Open(dialector Dialector, config *Config) (db DB, err error) {
}, },
} }
db = DB{ db = &DB{
Config: config, Config: config,
clone: true, clone: true,
} }
db.callbacks = initializeCallbacks(&db) db.callbacks = initializeCallbacks(db)
if dialector != nil { if dialector != nil {
err = dialector.Initialize(&db) err = dialector.Initialize(db)
} }
return return
} }
// Session create new db session // Session create new db session
func (db DB) Session(config *Session) DB { func (db *DB) Session(config *Session) *DB {
var ( var (
tx = db.getInstance() tx = db.getInstance()
txConfig = *tx.Config txConfig = *tx.Config
@ -125,24 +126,24 @@ func (db DB) Session(config *Session) DB {
} }
// WithContext change current instance db's context to ctx // WithContext change current instance db's context to ctx
func (db DB) WithContext(ctx context.Context) DB { func (db *DB) WithContext(ctx context.Context) *DB {
return db.Session(&Session{Context: ctx}) return db.Session(&Session{Context: ctx})
} }
// Debug start debug mode // Debug start debug mode
func (db DB) Debug() (tx DB) { func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
} }
// Set store value with key into current db instance's context // Set store value with key into current db instance's context
func (db DB) Set(key string, value interface{}) DB { func (db *DB) Set(key string, value interface{}) *DB {
tx := db.getInstance() tx := db.getInstance()
tx.Statement.Settings.Store(key, value) tx.Statement.Settings.Store(key, value)
return tx return tx
} }
// Get get value with key from current db instance's context // Get get value with key from current db instance's context
func (db DB) Get(key string) (interface{}, bool) { func (db *DB) Get(key string) (interface{}, bool) {
if db.Statement != nil { if db.Statement != nil {
return db.Statement.Settings.Load(key) return db.Statement.Settings.Load(key)
} }
@ -150,28 +151,32 @@ func (db DB) Get(key string) (interface{}, bool) {
} }
// Callback returns callback manager // Callback returns callback manager
func (db DB) Callback() *callbacks { func (db *DB) Callback() *callbacks {
return db.callbacks return db.callbacks
} }
// AutoMigrate run auto migration for given models // AutoMigrate run auto migration for given models
func (db DB) AutoMigrate(dst ...interface{}) error { func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...) return db.Migrator().AutoMigrate(dst...)
} }
// AddError add error to db // AddError add error to db
func (db DB) AddError(err error) { func (db *DB) AddError(err error) {
db.Statement.AddError(err) if db.Error == nil {
db.Error = err
} else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
} }
func (db DB) getInstance() DB { func (db *DB) getInstance() *DB {
if db.clone { if db.clone {
stmt := db.Config.statementPool.Get().(*Statement) stmt := db.Config.statementPool.Get().(*Statement)
if db.Statement != nil { if db.Statement != nil {
stmt.Context = db.Statement.Context stmt.Context = db.Statement.Context
} }
return DB{Config: db.Config, Statement: stmt} return &DB{Config: db.Config, Statement: stmt}
} }
return db return db

View File

@ -27,7 +27,7 @@ type Config struct {
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := m.DB.Statement stmt := m.DB.Statement
if stmt == nil { if stmt == nil {
stmt = &gorm.Statement{DB: *m.DB} stmt = &gorm.Statement{DB: m.DB}
} }
if err := stmt.Parse(value); err != nil { if err := stmt.Parse(value); err != nil {
@ -496,7 +496,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
parseDependence := func(value interface{}, addToList bool) { parseDependence := func(value interface{}, addToList bool) {
dep := Dependency{ dep := Dependency{
Statement: &gorm.Statement{DB: *m.DB, Dest: value}, Statement: &gorm.Statement{DB: m.DB, Dest: value},
} }
dep.Parse(value) dep.Parse(value)

View File

@ -16,6 +16,7 @@ import (
// Statement statement // Statement statement
type Statement struct { type Statement struct {
*DB
Table string Table string
Model interface{} Model interface{}
Dest interface{} Dest interface{}
@ -25,7 +26,6 @@ type Statement struct {
Omits []string // omit columns Omits []string // omit columns
Settings sync.Map Settings sync.Map
ConnPool ConnPool ConnPool ConnPool
DB DB
Schema *schema.Schema Schema *schema.Schema
Context context.Context Context context.Context
Error error Error error
@ -219,14 +219,6 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
return conditions return conditions
} }
func (stmt *Statement) AddError(err error) {
if stmt.Error == nil {
stmt.Error = err
} else if err != nil {
stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err)
}
}
// Build build sql with clauses names // Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) { func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool var firstClauseWritten bool