mirror of https://github.com/go-gorm/gorm.git
709 lines
20 KiB
Go
709 lines
20 KiB
Go
package gorm
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/logger"
|
|
"gorm.io/gorm/schema"
|
|
"gorm.io/gorm/utils"
|
|
)
|
|
|
|
// Create inserts value, returning the inserted data's primary key in value's id
|
|
func (db *DB) Create(value interface{}) (tx *DB) {
|
|
if db.CreateBatchSize > 0 {
|
|
return db.CreateInBatches(value, db.CreateBatchSize)
|
|
}
|
|
|
|
tx = db.getInstance()
|
|
tx.Statement.Dest = value
|
|
return tx.callbacks.Create().Execute(tx)
|
|
}
|
|
|
|
// CreateInBatches inserts value in batches of batchSize
|
|
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
|
|
|
switch reflectValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
var rowsAffected int64
|
|
tx = db.getInstance()
|
|
|
|
callFc := func(tx *DB) error {
|
|
// the reflection length judgment of the optimized value
|
|
reflectLen := reflectValue.Len()
|
|
for i := 0; i < reflectLen; i += batchSize {
|
|
ends := i + batchSize
|
|
if ends > reflectLen {
|
|
ends = reflectLen
|
|
}
|
|
|
|
subtx := tx.getInstance()
|
|
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
|
|
subtx.callbacks.Create().Execute(subtx)
|
|
if subtx.Error != nil {
|
|
return subtx.Error
|
|
}
|
|
rowsAffected += subtx.RowsAffected
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if tx.SkipDefaultTransaction {
|
|
tx.AddError(callFc(tx.Session(&Session{})))
|
|
} else {
|
|
tx.AddError(tx.Transaction(callFc))
|
|
}
|
|
|
|
tx.RowsAffected = rowsAffected
|
|
default:
|
|
tx = db.getInstance()
|
|
tx.Statement.Dest = value
|
|
tx = tx.callbacks.Create().Execute(tx)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
|
|
func (db *DB) Save(value interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
tx.Statement.Dest = value
|
|
|
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
|
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
|
|
reflectValue = reflect.Indirect(reflectValue)
|
|
}
|
|
|
|
switch reflectValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
|
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
|
|
}
|
|
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
|
|
case reflect.Struct:
|
|
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
|
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
|
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
|
|
return tx.callbacks.Create().Execute(tx)
|
|
}
|
|
}
|
|
}
|
|
|
|
fallthrough
|
|
default:
|
|
selectedUpdate := len(tx.Statement.Selects) != 0
|
|
// when updating, use all fields including those zero-value fields
|
|
if !selectedUpdate {
|
|
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
|
}
|
|
|
|
tx = tx.callbacks.Update().Execute(tx)
|
|
|
|
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
|
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
|
if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 {
|
|
return tx.Create(value)
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// First finds the first record ordered by primary key, matching given conditions conds
|
|
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
|
tx = db.Limit(1).Order(clause.OrderByColumn{
|
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
|
})
|
|
if len(conds) > 0 {
|
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
|
}
|
|
}
|
|
tx.Statement.RaiseErrorOnNotFound = true
|
|
tx.Statement.Dest = dest
|
|
return tx.callbacks.Query().Execute(tx)
|
|
}
|
|
|
|
// Take finds the first record returned by the database in no specified order, matching given conditions conds
|
|
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
|
tx = db.Limit(1)
|
|
if len(conds) > 0 {
|
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
|
}
|
|
}
|
|
tx.Statement.RaiseErrorOnNotFound = true
|
|
tx.Statement.Dest = dest
|
|
return tx.callbacks.Query().Execute(tx)
|
|
}
|
|
|
|
// Last finds the last record ordered by primary key, matching given conditions conds
|
|
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
|
tx = db.Limit(1).Order(clause.OrderByColumn{
|
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
|
Desc: true,
|
|
})
|
|
if len(conds) > 0 {
|
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
|
}
|
|
}
|
|
tx.Statement.RaiseErrorOnNotFound = true
|
|
tx.Statement.Dest = dest
|
|
return tx.callbacks.Query().Execute(tx)
|
|
}
|
|
|
|
// Find finds all records matching given conditions conds
|
|
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
if len(conds) > 0 {
|
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
|
}
|
|
}
|
|
tx.Statement.Dest = dest
|
|
return tx.callbacks.Query().Execute(tx)
|
|
}
|
|
|
|
// FindInBatches finds all records in batches of batchSize
|
|
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
|
var (
|
|
tx = db.Order(clause.OrderByColumn{
|
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
|
}).Session(&Session{})
|
|
queryDB = tx
|
|
rowsAffected int64
|
|
batch int
|
|
)
|
|
|
|
// user specified offset or limit
|
|
var totalSize int
|
|
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
|
if limit, ok := c.Expression.(clause.Limit); ok {
|
|
if limit.Limit != nil {
|
|
totalSize = *limit.Limit
|
|
}
|
|
|
|
if totalSize > 0 && batchSize > totalSize {
|
|
batchSize = totalSize
|
|
}
|
|
|
|
// reset to offset to 0 in next batch
|
|
tx = tx.Offset(-1).Session(&Session{})
|
|
}
|
|
}
|
|
|
|
for {
|
|
result := queryDB.Limit(batchSize).Find(dest)
|
|
rowsAffected += result.RowsAffected
|
|
batch++
|
|
|
|
if result.Error == nil && result.RowsAffected != 0 {
|
|
fcTx := result.Session(&Session{NewDB: true})
|
|
fcTx.RowsAffected = result.RowsAffected
|
|
tx.AddError(fc(fcTx, batch))
|
|
} else if result.Error != nil {
|
|
tx.AddError(result.Error)
|
|
}
|
|
|
|
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
|
break
|
|
}
|
|
|
|
if totalSize > 0 {
|
|
if totalSize <= int(rowsAffected) {
|
|
break
|
|
}
|
|
if totalSize/batchSize == batch {
|
|
batchSize = totalSize % batchSize
|
|
}
|
|
}
|
|
|
|
// Optimize for-break
|
|
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
|
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
|
tx.AddError(ErrPrimaryKeyRequired)
|
|
break
|
|
}
|
|
|
|
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
|
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
|
}
|
|
|
|
tx.RowsAffected = rowsAffected
|
|
return tx
|
|
}
|
|
|
|
func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
|
for _, value := range values {
|
|
switch v := value.(type) {
|
|
case []clause.Expression:
|
|
for _, expr := range v {
|
|
if eq, ok := expr.(clause.Eq); ok {
|
|
switch column := eq.Column.(type) {
|
|
case string:
|
|
if field := db.Statement.Schema.LookUpField(column); field != nil {
|
|
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
|
|
}
|
|
case clause.Column:
|
|
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
|
|
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
|
|
}
|
|
}
|
|
} else if andCond, ok := expr.(clause.AndConditions); ok {
|
|
db.assignInterfacesToValue(andCond.Exprs)
|
|
}
|
|
}
|
|
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
|
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
|
|
db.assignInterfacesToValue(exprs)
|
|
}
|
|
default:
|
|
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
|
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
|
switch reflectValue.Kind() {
|
|
case reflect.Struct:
|
|
for _, f := range s.Fields {
|
|
if f.Readable {
|
|
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
|
|
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
|
|
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else if len(values) > 0 {
|
|
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
|
|
db.assignInterfacesToValue(exprs)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
|
|
// Each conds must be a struct or map.
|
|
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
|
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
|
})
|
|
|
|
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
|
|
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
|
if where, ok := c.Expression.(clause.Where); ok {
|
|
tx.assignInterfacesToValue(where.Exprs)
|
|
}
|
|
}
|
|
|
|
// initialize with attrs, conds
|
|
if len(tx.Statement.attrs) > 0 {
|
|
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
|
}
|
|
}
|
|
|
|
// initialize with attrs, conds
|
|
if len(tx.Statement.assigns) > 0 {
|
|
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
|
}
|
|
return
|
|
}
|
|
|
|
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
|
|
// Each conds must be a struct or map.
|
|
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
|
})
|
|
if result := queryTx.Find(dest, conds...); result.Error == nil {
|
|
if result.RowsAffected == 0 {
|
|
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
|
if where, ok := c.Expression.(clause.Where); ok {
|
|
result.assignInterfacesToValue(where.Exprs)
|
|
}
|
|
}
|
|
|
|
// initialize with attrs, conds
|
|
if len(db.Statement.attrs) > 0 {
|
|
result.assignInterfacesToValue(db.Statement.attrs...)
|
|
}
|
|
|
|
// initialize with attrs, conds
|
|
if len(db.Statement.assigns) > 0 {
|
|
result.assignInterfacesToValue(db.Statement.assigns...)
|
|
}
|
|
|
|
return tx.Create(dest)
|
|
} else if len(db.Statement.assigns) > 0 {
|
|
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
|
assigns := map[string]interface{}{}
|
|
for _, expr := range exprs {
|
|
if eq, ok := expr.(clause.Eq); ok {
|
|
switch column := eq.Column.(type) {
|
|
case string:
|
|
assigns[column] = eq.Value
|
|
case clause.Column:
|
|
assigns[column.Name] = eq.Value
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
return tx.Model(dest).Updates(assigns)
|
|
}
|
|
} else {
|
|
tx.Error = result.Error
|
|
}
|
|
return tx
|
|
}
|
|
|
|
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
|
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
tx.Statement.Dest = map[string]interface{}{column: value}
|
|
return tx.callbacks.Update().Execute(tx)
|
|
}
|
|
|
|
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
|
func (db *DB) Updates(values interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
tx.Statement.Dest = values
|
|
return tx.callbacks.Update().Execute(tx)
|
|
}
|
|
|
|
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
tx.Statement.Dest = map[string]interface{}{column: value}
|
|
tx.Statement.SkipHooks = true
|
|
return tx.callbacks.Update().Execute(tx)
|
|
}
|
|
|
|
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
tx.Statement.Dest = values
|
|
tx.Statement.SkipHooks = true
|
|
return tx.callbacks.Update().Execute(tx)
|
|
}
|
|
|
|
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
|
|
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
|
|
// time if null.
|
|
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
if len(conds) > 0 {
|
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
|
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
|
}
|
|
}
|
|
tx.Statement.Dest = value
|
|
return tx.callbacks.Delete().Execute(tx)
|
|
}
|
|
|
|
func (db *DB) Count(count *int64) (tx *DB) {
|
|
tx = db.getInstance()
|
|
if tx.Statement.Model == nil {
|
|
tx.Statement.Model = tx.Statement.Dest
|
|
defer func() {
|
|
tx.Statement.Model = nil
|
|
}()
|
|
}
|
|
|
|
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
|
|
defer func() {
|
|
tx.Statement.Clauses["SELECT"] = selectClause
|
|
}()
|
|
} else {
|
|
defer delete(tx.Statement.Clauses, "SELECT")
|
|
}
|
|
|
|
if len(tx.Statement.Selects) == 0 {
|
|
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
|
|
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
|
|
expr := clause.Expr{SQL: "count(*)"}
|
|
|
|
if len(tx.Statement.Selects) == 1 {
|
|
dbName := tx.Statement.Selects[0]
|
|
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
|
|
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
|
|
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
|
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
|
dbName = f.DBName
|
|
}
|
|
}
|
|
|
|
if tx.Statement.Distinct {
|
|
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
|
|
} else if dbName != "*" {
|
|
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
|
}
|
|
}
|
|
}
|
|
|
|
tx.Statement.AddClause(clause.Select{Expression: expr})
|
|
}
|
|
|
|
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
|
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
|
|
delete(tx.Statement.Clauses, "ORDER BY")
|
|
defer func() {
|
|
tx.Statement.Clauses["ORDER BY"] = orderByClause
|
|
}()
|
|
}
|
|
}
|
|
|
|
tx.Statement.Dest = count
|
|
tx = tx.callbacks.Query().Execute(tx)
|
|
|
|
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
|
|
*count = tx.RowsAffected
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (db *DB) Row() *sql.Row {
|
|
tx := db.getInstance().Set("rows", false)
|
|
tx = tx.callbacks.Row().Execute(tx)
|
|
row, ok := tx.Statement.Dest.(*sql.Row)
|
|
if !ok && tx.DryRun {
|
|
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
|
|
}
|
|
return row
|
|
}
|
|
|
|
func (db *DB) Rows() (*sql.Rows, error) {
|
|
tx := db.getInstance().Set("rows", true)
|
|
tx = tx.callbacks.Row().Execute(tx)
|
|
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
|
if !ok && tx.DryRun && tx.Error == nil {
|
|
tx.Error = ErrDryRunModeUnsupported
|
|
}
|
|
return rows, tx.Error
|
|
}
|
|
|
|
// Scan scans selected value to the struct dest
|
|
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
|
config := *db.Config
|
|
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
|
config.Logger = newLogger
|
|
|
|
tx = db.getInstance()
|
|
tx.Config = &config
|
|
|
|
if rows, err := tx.Rows(); err == nil {
|
|
if rows.Next() {
|
|
tx.ScanRows(rows, dest)
|
|
} else {
|
|
tx.RowsAffected = 0
|
|
}
|
|
tx.AddError(rows.Close())
|
|
}
|
|
|
|
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
|
|
return newLogger.SQL, tx.RowsAffected
|
|
}, tx.Error)
|
|
tx.Logger = currentLogger
|
|
return
|
|
}
|
|
|
|
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
|
|
// var ages []int64
|
|
// db.Model(&users).Pluck("age", &ages)
|
|
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
if tx.Statement.Model != nil {
|
|
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
|
if f := tx.Statement.Schema.LookUpField(column); f != nil {
|
|
column = f.DBName
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(tx.Statement.Selects) != 1 {
|
|
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
|
tx.Statement.AddClauseIfNotExists(clause.Select{
|
|
Distinct: tx.Statement.Distinct,
|
|
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
|
})
|
|
}
|
|
tx.Statement.Dest = dest
|
|
return tx.callbacks.Query().Execute(tx)
|
|
}
|
|
|
|
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
|
tx := db.getInstance()
|
|
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
|
|
tx.AddError(err)
|
|
}
|
|
tx.Statement.Dest = dest
|
|
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
|
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
|
elem := tx.Statement.ReflectValue.Elem()
|
|
if !elem.IsValid() {
|
|
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
|
|
tx.Statement.ReflectValue.Set(elem)
|
|
}
|
|
tx.Statement.ReflectValue = elem
|
|
}
|
|
Scan(rows, tx, ScanInitialized)
|
|
return tx.Error
|
|
}
|
|
|
|
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
|
|
// returned to the connection pool.
|
|
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
|
if db.Error != nil {
|
|
return db.Error
|
|
}
|
|
|
|
tx := db.getInstance()
|
|
sqlDB, err := tx.DB()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
conn, err := sqlDB.Conn(tx.Statement.Context)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
defer conn.Close()
|
|
tx.Statement.ConnPool = conn
|
|
return fc(tx)
|
|
}
|
|
|
|
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
|
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
|
// they are rolled back.
|
|
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
|
panicked := true
|
|
|
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
|
// nested transaction
|
|
if !db.DisableNestedTransaction {
|
|
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
// Make sure to rollback when panic, Block error or Commit error
|
|
if panicked || err != nil {
|
|
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
|
}
|
|
}()
|
|
}
|
|
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
|
|
} else {
|
|
tx := db.Begin(opts...)
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
|
|
defer func() {
|
|
// Make sure to rollback when panic, Block error or Commit error
|
|
if panicked || err != nil {
|
|
tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
if err = fc(tx); err == nil {
|
|
panicked = false
|
|
return tx.Commit().Error
|
|
}
|
|
}
|
|
|
|
panicked = false
|
|
return
|
|
}
|
|
|
|
// Begin begins a transaction with any transaction options opts
|
|
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
|
var (
|
|
// clone statement
|
|
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
|
|
opt *sql.TxOptions
|
|
err error
|
|
)
|
|
|
|
if len(opts) > 0 {
|
|
opt = opts[0]
|
|
}
|
|
|
|
switch beginner := tx.Statement.ConnPool.(type) {
|
|
case TxBeginner:
|
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
|
case ConnPoolBeginner:
|
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
|
default:
|
|
err = ErrInvalidTransaction
|
|
}
|
|
|
|
if err != nil {
|
|
tx.AddError(err)
|
|
}
|
|
|
|
return tx
|
|
}
|
|
|
|
// Commit commits the changes in a transaction
|
|
func (db *DB) Commit() *DB {
|
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
|
db.AddError(committer.Commit())
|
|
} else {
|
|
db.AddError(ErrInvalidTransaction)
|
|
}
|
|
return db
|
|
}
|
|
|
|
// Rollback rollbacks the changes in a transaction
|
|
func (db *DB) Rollback() *DB {
|
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
|
if !reflect.ValueOf(committer).IsNil() {
|
|
db.AddError(committer.Rollback())
|
|
}
|
|
} else {
|
|
db.AddError(ErrInvalidTransaction)
|
|
}
|
|
return db
|
|
}
|
|
|
|
func (db *DB) SavePoint(name string) *DB {
|
|
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
|
db.AddError(savePointer.SavePoint(db, name))
|
|
} else {
|
|
db.AddError(ErrUnsupportedDriver)
|
|
}
|
|
return db
|
|
}
|
|
|
|
func (db *DB) RollbackTo(name string) *DB {
|
|
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
|
db.AddError(savePointer.RollbackTo(db, name))
|
|
} else {
|
|
db.AddError(ErrUnsupportedDriver)
|
|
}
|
|
return db
|
|
}
|
|
|
|
// Exec executes raw sql
|
|
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
|
tx = db.getInstance()
|
|
tx.Statement.SQL = strings.Builder{}
|
|
|
|
if strings.Contains(sql, "@") {
|
|
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
|
} else {
|
|
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
|
}
|
|
|
|
return tx.callbacks.Raw().Execute(tx)
|
|
}
|