gorm/finisher_api.go

624 lines
17 KiB
Go
Raw Normal View History

2020-01-29 14:22:44 +03:00
package gorm
import (
"database/sql"
2020-05-28 08:12:56 +03:00
"errors"
"fmt"
2020-03-08 08:24:08 +03:00
"reflect"
2020-02-22 15:57:29 +03:00
"strings"
2020-02-04 03:56:15 +03:00
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
2020-01-29 14:22:44 +03:00
)
2020-02-03 05:40:03 +03:00
// Create insert the value into database
2020-03-09 15:37:01 +03:00
func (db *DB) Create(value interface{}) (tx *DB) {
2020-12-02 09:59:50 +03:00
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}
2020-02-03 05:40:03 +03:00
tx = db.getInstance()
tx.Statement.Dest = value
2020-03-09 15:37:01 +03:00
tx.callbacks.Create().Execute(tx)
2020-02-03 05:40:03 +03:00
return
}
2020-11-16 16:42:30 +03:00
// CreateInBatches insert the value in batches into database
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
2020-12-02 09:59:50 +03:00
var rowsAffected int64
2020-11-16 16:42:30 +03:00
tx = db.getInstance()
2020-12-02 09:59:50 +03:00
tx.AddError(tx.Transaction(func(tx *DB) error {
for i := 0; i < reflectValue.Len(); i += batchSize {
2020-11-16 16:42:30 +03:00
ends := i + batchSize
if ends > reflectValue.Len() {
ends = reflectValue.Len()
}
2020-12-02 09:59:50 +03:00
subtx := tx.getInstance()
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
subtx.callbacks.Create().Execute(subtx)
if subtx.Error != nil {
return subtx.Error
}
rowsAffected += subtx.RowsAffected
}
return nil
}))
tx.RowsAffected = rowsAffected
2020-11-16 16:42:30 +03:00
default:
2020-12-02 09:59:50 +03:00
tx = db.getInstance()
tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx)
2020-11-16 16:42:30 +03:00
}
return
}
2020-02-03 05:40:03 +03:00
// Save update value in database, if the value doesn't have primary key, will insert it
2020-03-09 15:37:01 +03:00
func (db *DB) Save(value interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-03-08 08:24:08 +03:00
tx.Statement.Dest = value
2020-06-09 19:02:14 +03:00
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
2020-11-16 15:22:08 +03:00
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
2020-06-09 19:02:14 +03:00
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero {
2020-05-24 15:44:37 +03:00
tx.callbacks.Create().Execute(tx)
return
}
2020-03-08 08:24:08 +03:00
}
2020-06-06 17:52:08 +03:00
}
2020-03-08 08:24:08 +03:00
2020-06-09 19:02:14 +03:00
fallthrough
default:
2020-08-30 05:12:49 +03:00
selectedUpdate := len(tx.Statement.Selects) != 0
// when updating, use all fields including those zero-value fields
if !selectedUpdate {
2020-06-09 19:02:14 +03:00
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
2020-08-30 05:12:49 +03:00
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
2020-08-30 05:12:49 +03:00
return tx.Create(value)
}
}
2020-05-23 18:50:48 +03:00
}
2020-06-09 19:02:14 +03:00
2020-01-29 14:22:44 +03:00
return
}
// First find first record that match given conditions, order by primary key
2020-05-26 16:30:17 +03:00
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
2020-06-08 17:32:35 +03:00
tx = db.Limit(1).Order(clause.OrderByColumn{
2020-02-04 03:56:15 +03:00
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
2020-03-07 08:43:20 +03:00
if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
2020-03-07 08:43:20 +03:00
}
2020-03-03 09:18:12 +03:00
tx.Statement.RaiseErrorOnNotFound = true
2020-05-26 16:30:17 +03:00
tx.Statement.Dest = dest
2020-03-09 15:37:01 +03:00
tx.callbacks.Query().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
// Take return a record that match given conditions, the order will depend on the database implementation
2020-05-26 16:30:17 +03:00
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
2020-06-08 17:32:35 +03:00
tx = db.Limit(1)
2020-03-07 08:43:20 +03:00
if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
2020-03-07 08:43:20 +03:00
}
2020-03-03 09:18:12 +03:00
tx.Statement.RaiseErrorOnNotFound = true
2020-05-26 16:30:17 +03:00
tx.Statement.Dest = dest
2020-03-09 15:37:01 +03:00
tx.callbacks.Query().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
// Last find last record that match given conditions, order by primary key
2020-05-26 16:30:17 +03:00
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
2020-06-08 17:32:35 +03:00
tx = db.Limit(1).Order(clause.OrderByColumn{
2020-02-23 18:28:35 +03:00
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
2020-03-04 06:32:36 +03:00
Desc: true,
2020-02-23 18:28:35 +03:00
})
2020-03-07 08:43:20 +03:00
if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
2020-03-07 08:43:20 +03:00
}
2020-03-03 09:18:12 +03:00
tx.Statement.RaiseErrorOnNotFound = true
2020-05-26 16:30:17 +03:00
tx.Statement.Dest = dest
2020-03-09 15:37:01 +03:00
tx.callbacks.Query().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
// Find find records that match given conditions
2020-05-26 16:30:17 +03:00
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-03-07 08:43:20 +03:00
if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
2020-03-07 08:43:20 +03:00
}
2020-05-26 16:30:17 +03:00
tx.Statement.Dest = dest
2020-03-09 15:37:01 +03:00
tx.callbacks.Query().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
2020-06-10 10:36:29 +03:00
// FindInBatches find records in batches
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}).Session(&Session{})
queryDB = tx
rowsAffected int64
batch int
)
2020-06-10 10:36:29 +03:00
for {
result := queryDB.Limit(batchSize).Find(dest)
2020-06-10 10:36:29 +03:00
rowsAffected += result.RowsAffected
batch++
if result.Error == nil && result.RowsAffected != 0 {
tx.AddError(fc(result, batch))
}
if tx.Error != nil || int(result.RowsAffected) < batchSize {
break
} else {
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
break
} else {
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
2020-06-10 10:36:29 +03:00
}
}
tx.RowsAffected = rowsAffected
return tx
2020-06-10 10:36:29 +03:00
}
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
tx.assignInterfacesToValue(andCond.Exprs)
2020-05-28 08:12:56 +03:00
}
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
exprs := tx.Statement.BuildCondition(value)
tx.assignInterfacesToValue(exprs)
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
}
}
}
}
2020-05-28 08:12:56 +03:00
}
} else if len(values) > 0 {
exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
tx.assignInterfacesToValue(exprs)
return
2020-05-28 08:12:56 +03:00
}
}
}
}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
2020-05-28 08:12:56 +03:00
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
2020-05-28 08:12:56 +03:00
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
2020-05-28 08:12:56 +03:00
}
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
2020-05-28 08:12:56 +03:00
}
2020-01-29 14:22:44 +03:00
return
}
2020-05-28 11:10:10 +03:00
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
2020-05-28 11:10:10 +03:00
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
2020-05-28 11:10:10 +03:00
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
2020-05-28 11:10:10 +03:00
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
2020-05-28 11:10:10 +03:00
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
2020-05-28 11:10:10 +03:00
}
return tx.Create(dest)
2020-06-08 17:32:35 +03:00
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
2020-05-28 11:10:10 +03:00
assigns := map[string]interface{}{}
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
default:
}
}
}
return tx.Model(dest).Updates(assigns)
}
2020-05-28 08:12:56 +03:00
2020-06-08 17:32:35 +03:00
return db
2020-01-29 14:22:44 +03:00
}
2020-06-15 07:28:35 +03:00
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
2020-03-09 15:37:01 +03:00
func (db *DB) Update(column string, value interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-03-07 08:43:20 +03:00
tx.Statement.Dest = map[string]interface{}{column: value}
2020-03-09 15:37:01 +03:00
tx.callbacks.Update().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
2020-06-15 07:28:35 +03:00
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
2020-03-09 15:37:01 +03:00
func (db *DB) Updates(values interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-03-07 08:43:20 +03:00
tx.Statement.Dest = values
2020-03-09 15:37:01 +03:00
tx.callbacks.Update().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
2020-03-09 15:37:01 +03:00
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-03-07 08:43:20 +03:00
tx.Statement.Dest = map[string]interface{}{column: value}
2020-11-17 12:49:43 +03:00
tx.Statement.SkipHooks = true
2020-03-09 15:37:01 +03:00
tx.callbacks.Update().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
2020-03-09 15:37:01 +03:00
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-03-07 08:43:20 +03:00
tx.Statement.Dest = values
2020-11-17 12:49:43 +03:00
tx.Statement.SkipHooks = true
2020-03-09 15:37:01 +03:00
tx.callbacks.Update().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
2020-02-03 05:40:03 +03:00
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
2020-03-09 15:37:01 +03:00
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-03-08 09:51:52 +03:00
if len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
2020-03-08 09:51:52 +03:00
}
tx.Statement.Dest = value
2020-03-09 15:37:01 +03:00
tx.callbacks.Delete().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}
2020-05-24 06:32:59 +03:00
func (db *DB) Count(count *int64) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-05-24 06:32:59 +03:00
if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest
2020-08-19 15:30:39 +03:00
defer func() {
tx.Statement.Model = nil
}()
2020-05-24 06:32:59 +03:00
}
2020-06-05 14:19:08 +03:00
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
db.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
2020-06-05 14:19:08 +03:00
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
2020-06-23 03:51:01 +03:00
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
expr := clause.Expr{SQL: "count(1)"}
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
}
2020-06-23 03:51:01 +03:00
}
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
} else {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
}
}
2020-06-05 14:19:08 +03:00
}
2020-06-23 03:51:01 +03:00
tx.Statement.AddClause(clause.Select{Expression: expr})
2020-06-05 14:19:08 +03:00
}
2020-10-22 06:28:43 +03:00
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(db.Statement.Clauses, "ORDER BY")
defer func() {
db.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
2020-05-24 06:32:59 +03:00
tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx)
2020-06-23 17:41:41 +03:00
if tx.RowsAffected != 1 {
*count = tx.RowsAffected
2020-05-24 06:32:59 +03:00
}
2020-01-29 14:22:44 +03:00
return
}
2020-03-09 15:37:01 +03:00
func (db *DB) Row() *sql.Row {
tx := db.getInstance().InstanceSet("rows", false)
2020-03-09 15:37:01 +03:00
tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
}
return row
2020-02-03 05:40:03 +03:00
}
2020-03-09 15:37:01 +03:00
func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance().InstanceSet("rows", true)
2020-03-09 15:37:01 +03:00
tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
}
return rows, tx.Error
2020-02-03 05:40:03 +03:00
}
// Scan scan value to a struct
2020-03-09 15:37:01 +03:00
func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
2020-01-30 10:14:48 +03:00
tx = db.getInstance()
tx.Config = &config
if rows, err := tx.Rows(); err != nil {
tx.AddError(err)
} else {
defer rows.Close()
if rows.Next() {
tx.ScanRows(rows, dest)
}
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
tx.Logger = currentLogger
2020-01-30 10:14:48 +03:00
return
}
2020-05-31 13:51:43 +03:00
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
2020-06-05 14:19:08 +03:00
if tx.Statement.Model != nil {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
}
2020-06-09 10:34:55 +03:00
} else if tx.Statement.Table == "" {
tx.AddError(ErrModelValueRequired)
2020-06-05 14:19:08 +03:00
}
2020-06-09 10:34:55 +03:00
if len(tx.Statement.Selects) != 1 {
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
})
}
2020-06-09 10:34:55 +03:00
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
2020-05-31 13:51:43 +03:00
return
}
2020-05-26 18:13:05 +03:00
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance()
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
2020-05-26 18:13:05 +03:00
tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
}
2020-05-26 18:13:05 +03:00
Scan(rows, tx, true)
return tx.Error
2020-01-30 10:14:48 +03:00
}
2020-02-23 18:28:35 +03:00
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
2020-03-09 15:37:01 +03:00
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
2020-01-29 14:22:44 +03:00
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
2020-12-16 14:33:35 +03:00
if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
}
2020-11-19 05:45:17 +03:00
if err == nil {
err = fc(db.Session(&Session{}))
}
} else {
tx := db.Begin(opts...)
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
2020-11-19 05:45:17 +03:00
if err = tx.Error; err == nil {
err = fc(tx)
}
2020-01-29 14:22:44 +03:00
if err == nil {
err = tx.Commit().Error
}
2020-01-29 14:22:44 +03:00
}
panicked = false
return
}
2020-02-23 18:28:35 +03:00
// Begin begins a transaction
2020-06-05 05:08:22 +03:00
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
tx = db.Session(&Session{Context: db.Statement.Context})
2020-06-05 05:08:22 +03:00
opt *sql.TxOptions
err error
)
if len(opts) > 0 {
opt = opts[0]
}
2020-02-23 18:28:35 +03:00
2020-06-05 05:08:22 +03:00
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
2020-02-23 18:28:35 +03:00
} else {
2020-06-05 05:08:22 +03:00
err = ErrInvalidTransaction
2020-02-23 18:28:35 +03:00
}
2020-06-05 05:08:22 +03:00
if err != nil {
tx.AddError(err)
}
return tx
2020-01-29 14:22:44 +03:00
}
2020-02-23 18:28:35 +03:00
// Commit commit a transaction
2020-03-09 15:37:01 +03:00
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
2020-06-05 05:08:22 +03:00
db.AddError(committer.Commit())
2020-02-23 18:28:35 +03:00
} else {
db.AddError(ErrInvalidTransaction)
}
return db
2020-01-29 14:22:44 +03:00
}
2020-02-23 18:28:35 +03:00
// Rollback rollback a transaction
2020-03-09 15:37:01 +03:00
func (db *DB) Rollback() *DB {
2020-06-05 05:08:22 +03:00
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback())
}
2020-02-23 18:28:35 +03:00
} else {
db.AddError(ErrInvalidTransaction)
}
return db
2020-01-29 14:22:44 +03:00
}
func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
2020-07-16 06:27:04 +03:00
db.AddError(savePointer.SavePoint(db, name))
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
2020-07-16 06:27:04 +03:00
db.AddError(savePointer.RollbackTo(db, name))
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
2020-02-23 18:28:35 +03:00
// Exec execute raw sql
2020-03-09 15:37:01 +03:00
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
2020-01-29 14:22:44 +03:00
tx = db.getInstance()
2020-02-22 15:57:29 +03:00
tx.Statement.SQL = strings.Builder{}
2020-07-10 07:28:24 +03:00
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
2020-03-09 15:37:01 +03:00
tx.callbacks.Raw().Execute(tx)
2020-01-29 14:22:44 +03:00
return
}