2020-01-28 18:01:35 +03:00
package gorm
import (
2020-01-29 14:22:44 +03:00
"context"
2020-06-05 05:08:22 +03:00
"database/sql"
2020-03-09 15:37:01 +03:00
"fmt"
2023-08-19 16:33:31 +03:00
"reflect"
2021-04-09 06:43:24 +03:00
"sort"
2020-02-02 09:40:44 +03:00
"sync"
2020-01-28 18:01:35 +03:00
"time"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
2020-01-28 18:01:35 +03:00
)
2021-03-07 05:57:22 +03:00
// preparedStmtDBKey is the Config.cacheStore key under which Open and
// (*DB).Session keep the shared *PreparedStmtDB.
const (
	preparedStmtDBKey = "preparedStmt"
)
2020-01-28 18:01:35 +03:00
// Config GORM config
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
2020-02-02 09:40:44 +03:00
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
2020-01-31 09:17:02 +03:00
// NamingStrategy tables, columns naming strategy
NamingStrategy schema . Namer
2020-09-24 14:28:52 +03:00
// FullSaveAssociations full save associations
FullSaveAssociations bool
2020-01-28 18:01:35 +03:00
// Logger
Logger logger . Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func ( ) time . Time
2020-06-01 16:26:23 +03:00
// DryRun generate sql without execute
DryRun bool
2020-06-05 05:08:22 +03:00
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
2020-06-05 16:23:20 +03:00
// DisableAutomaticPing
DisableAutomaticPing bool
2020-06-22 06:04:44 +03:00
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
2023-01-01 14:54:28 +03:00
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
2020-12-16 14:33:35 +03:00
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
2020-08-23 15:08:23 +03:00
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
2020-11-20 10:38:25 +03:00
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
2020-12-02 09:59:50 +03:00
// CreateBatchSize default create batch size
CreateBatchSize int
2023-03-24 05:07:05 +03:00
// TranslateError enabling error translation
TranslateError bool
2024-06-17 06:59:06 +03:00
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
2020-06-05 05:08:22 +03:00
2020-03-09 08:10:48 +03:00
// ClauseBuilders clause builder
ClauseBuilders map [ string ] clause . ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
Dialector
2020-06-23 04:38:51 +03:00
// Plugins registered plugins
Plugins map [ string ] Plugin
2020-03-09 08:10:48 +03:00
2020-05-31 06:34:59 +03:00
callbacks * callbacks
cacheStore * sync . Map
2020-02-05 06:14:58 +03:00
}
2022-02-09 12:39:01 +03:00
// Apply update config to new config
2021-03-04 12:14:08 +03:00
func ( c * Config ) Apply ( config * Config ) error {
2021-03-04 15:37:39 +03:00
if config != c {
* config = * c
}
2021-03-04 12:14:08 +03:00
return nil
}
2022-02-09 12:39:01 +03:00
// AfterInitialize initialize plugins after db connected
2021-03-04 12:14:08 +03:00
func ( c * Config ) AfterInitialize ( db * DB ) error {
if db != nil {
for _ , plugin := range c . Plugins {
if err := plugin . Initialize ( db ) ; err != nil {
return err
}
}
}
return nil
}
2022-02-09 12:39:01 +03:00
// Option gorm option interface
2021-03-04 12:14:08 +03:00
type Option interface {
Apply ( * Config ) error
AfterInitialize ( * DB ) error
}
2020-01-28 18:01:35 +03:00
// DB GORM DB definition
type DB struct {
* Config
2020-03-09 08:10:48 +03:00
Error error
RowsAffected int64
Statement * Statement
2020-05-31 18:55:56 +03:00
clone int
2020-01-30 10:14:48 +03:00
}
2020-02-02 09:40:44 +03:00
// Session session config when create session with Session() method
2020-01-30 10:14:48 +03:00
type Session struct {
2020-12-16 14:33:35 +03:00
DryRun bool
PrepareStmt bool
NewDB bool
2022-01-28 14:26:10 +03:00
Initialized bool
2020-12-16 14:33:35 +03:00
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
2024-06-17 06:59:06 +03:00
PropagateUnscoped bool
2020-12-16 14:33:35 +03:00
QueryFields bool
Context context . Context
Logger logger . Interface
NowFunc func ( ) time . Time
CreateBatchSize int
2020-01-30 10:14:48 +03:00
}
// Open initialize db session based on dialector
2021-03-04 12:14:08 +03:00
func Open ( dialector Dialector , opts ... Option ) ( db * DB , err error ) {
config := & Config { }
2021-04-09 06:43:24 +03:00
sort . Slice ( opts , func ( i , j int ) bool {
_ , isConfig := opts [ i ] . ( * Config )
_ , isConfig2 := opts [ j ] . ( * Config )
return isConfig && ! isConfig2
} )
2021-03-04 12:14:08 +03:00
for _ , opt := range opts {
if opt != nil {
2022-03-31 15:57:20 +03:00
if applyErr := opt . Apply ( config ) ; applyErr != nil {
return nil , applyErr
2021-03-04 12:14:08 +03:00
}
2021-03-26 09:20:42 +03:00
defer func ( opt Option ) {
2021-03-29 13:36:01 +03:00
if errr := opt . AfterInitialize ( db ) ; errr != nil {
err = errr
}
2021-03-26 09:20:42 +03:00
} ( opt )
2021-03-04 12:14:08 +03:00
}
2020-02-02 03:35:01 +03:00
}
2021-03-11 05:29:52 +03:00
if d , ok := dialector . ( interface { Apply ( * Config ) error } ) ; ok {
if err = d . Apply ( config ) ; err != nil {
return
}
}
2020-01-31 09:17:02 +03:00
if config . NamingStrategy == nil {
2023-05-30 05:00:48 +03:00
config . NamingStrategy = schema . NamingStrategy { IdentifierMaxLength : 64 } // Default Identifier length is 64
2020-01-31 09:17:02 +03:00
}
2020-02-02 03:35:01 +03:00
if config . Logger == nil {
config . Logger = logger . Default
}
if config . NowFunc == nil {
2020-05-31 03:58:08 +03:00
config . NowFunc = func ( ) time . Time { return time . Now ( ) . Local ( ) }
2020-02-02 03:35:01 +03:00
}
2020-03-09 08:10:48 +03:00
if dialector != nil {
config . Dialector = dialector
}
2020-06-23 05:36:45 +03:00
if config . Plugins == nil {
config . Plugins = map [ string ] Plugin { }
}
2020-03-09 08:10:48 +03:00
if config . cacheStore == nil {
config . cacheStore = & sync . Map { }
}
2020-05-31 18:55:56 +03:00
db = & DB { Config : config , clone : 1 }
2020-02-02 03:35:01 +03:00
2020-03-09 15:37:01 +03:00
db . callbacks = initializeCallbacks ( db )
2020-02-02 14:32:27 +03:00
2020-05-29 17:34:35 +03:00
if config . ClauseBuilders == nil {
config . ClauseBuilders = map [ string ] clause . ClauseBuilder { }
}
2020-06-05 16:23:20 +03:00
if config . Dialector != nil {
err = config . Dialector . Initialize ( db )
2023-04-21 17:17:21 +03:00
if err != nil {
2023-08-20 14:46:56 +03:00
if db , _ := db . DB ( ) ; db != nil {
2023-04-21 17:17:21 +03:00
_ = db . Close ( )
}
}
2024-10-09 14:29:48 +03:00
if config . TranslateError {
if _ , ok := db . Dialector . ( ErrorTranslator ) ; ! ok {
config . Logger . Warn ( context . Background ( ) , "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator." , db . Dialector . Name ( ) )
}
}
2020-02-02 03:35:01 +03:00
}
2020-06-02 06:09:17 +03:00
2020-06-05 05:08:22 +03:00
if config . PrepareStmt {
2023-06-05 11:23:17 +03:00
preparedStmt := NewPreparedStmtDB ( db . ConnPool )
db . cacheStore . Store ( preparedStmtDBKey , preparedStmt )
2020-07-28 09:26:09 +03:00
db . ConnPool = preparedStmt
2020-06-05 05:08:22 +03:00
}
2020-06-05 16:23:20 +03:00
db . Statement = & Statement {
DB : db ,
ConnPool : db . ConnPool ,
Context : context . Background ( ) ,
Clauses : map [ string ] clause . Clause { } ,
2020-06-05 05:08:22 +03:00
}
2020-06-05 16:23:20 +03:00
if err == nil && ! config . DisableAutomaticPing {
2020-06-02 06:09:17 +03:00
if pinger , ok := db . ConnPool . ( interface { Ping ( ) error } ) ; ok {
err = pinger . Ping ( )
}
}
if err != nil {
config . Logger . Error ( context . Background ( ) , "failed to initialize database, got error %v" , err )
}
2020-02-02 03:35:01 +03:00
return
2020-01-30 10:14:48 +03:00
}
// Session create new db session
2020-03-09 15:37:01 +03:00
func ( db * DB ) Session ( config * Session ) * DB {
2020-01-30 10:14:48 +03:00
var (
2020-05-31 18:55:56 +03:00
txConfig = * db . Config
tx = & DB {
Config : & txConfig ,
Statement : db . Statement ,
2021-01-07 06:45:40 +03:00
Error : db . Error ,
2020-05-31 18:55:56 +03:00
clone : 1 ,
}
2020-01-30 10:14:48 +03:00
)
2020-12-02 09:59:50 +03:00
if config . CreateBatchSize > 0 {
tx . Config . CreateBatchSize = config . CreateBatchSize
}
2020-07-16 13:05:55 +03:00
if config . SkipDefaultTransaction {
tx . Config . SkipDefaultTransaction = true
}
2020-08-23 15:08:23 +03:00
if config . AllowGlobalUpdate {
txConfig . AllowGlobalUpdate = true
}
2020-09-24 14:28:52 +03:00
if config . FullSaveAssociations {
txConfig . FullSaveAssociations = true
}
2024-06-17 06:59:06 +03:00
if config . PropagateUnscoped {
txConfig . PropagateUnscoped = true
}
2020-11-17 10:19:58 +03:00
if config . Context != nil || config . PrepareStmt || config . SkipHooks {
2020-06-05 16:23:20 +03:00
tx . Statement = tx . Statement . clone ( )
tx . Statement . DB = tx
2020-11-17 10:19:58 +03:00
}
if config . Context != nil {
2020-05-31 18:55:56 +03:00
tx . Statement . Context = config . Context
}
2020-06-05 05:08:22 +03:00
if config . PrepareStmt {
2023-06-05 11:23:17 +03:00
var preparedStmt * PreparedStmtDB
2021-03-07 05:57:22 +03:00
if v , ok := db . cacheStore . Load ( preparedStmtDBKey ) ; ok {
2023-06-05 11:23:17 +03:00
preparedStmt = v . ( * PreparedStmtDB )
} else {
preparedStmt = NewPreparedStmtDB ( db . ConnPool )
db . cacheStore . Store ( preparedStmtDBKey , preparedStmt )
}
switch t := tx . Statement . ConnPool . ( type ) {
case Tx :
tx . Statement . ConnPool = & PreparedStmtTX {
Tx : t ,
PreparedStmtDB : preparedStmt ,
}
default :
tx . Statement . ConnPool = & PreparedStmtDB {
ConnPool : db . Config . ConnPool ,
Mux : preparedStmt . Mux ,
Stmts : preparedStmt . Stmts ,
2020-07-28 09:26:09 +03:00
}
2020-06-05 05:08:22 +03:00
}
2023-06-05 11:23:17 +03:00
txConfig . ConnPool = tx . Statement . ConnPool
txConfig . PrepareStmt = true
2020-06-05 05:08:22 +03:00
}
2020-11-17 10:19:58 +03:00
if config . SkipHooks {
2020-11-17 12:49:43 +03:00
tx . Statement . SkipHooks = true
2020-11-17 10:19:58 +03:00
}
2020-12-16 14:33:35 +03:00
if config . DisableNestedTransaction {
txConfig . DisableNestedTransaction = true
}
2020-11-17 10:41:17 +03:00
if ! config . NewDB {
2020-06-05 16:23:20 +03:00
tx . clone = 2
2020-01-30 10:14:48 +03:00
}
2020-06-01 16:26:23 +03:00
if config . DryRun {
tx . Config . DryRun = true
}
2020-11-20 10:38:25 +03:00
if config . QueryFields {
tx . Config . QueryFields = true
}
2020-01-30 10:14:48 +03:00
if config . Logger != nil {
2020-05-31 18:55:56 +03:00
tx . Config . Logger = config . Logger
2020-01-30 10:14:48 +03:00
}
if config . NowFunc != nil {
2020-05-31 18:55:56 +03:00
tx . Config . NowFunc = config . NowFunc
2020-01-30 10:14:48 +03:00
}
2022-01-28 14:26:10 +03:00
if config . Initialized {
tx = tx . getInstance ( )
}
2020-05-31 18:55:56 +03:00
return tx
2020-01-29 14:22:44 +03:00
}
// WithContext change current instance db's context to ctx
2020-03-09 15:37:01 +03:00
func ( db * DB ) WithContext ( ctx context . Context ) * DB {
2020-11-17 10:41:17 +03:00
return db . Session ( & Session { Context : ctx } )
2020-01-30 10:14:48 +03:00
}
// Debug start debug mode
2020-03-09 15:37:01 +03:00
func ( db * DB ) Debug ( ) ( tx * DB ) {
2022-07-18 13:06:45 +03:00
tx = db . getInstance ( )
return tx . Session ( & Session {
2020-11-17 10:41:17 +03:00
Logger : db . Logger . LogMode ( logger . Info ) ,
2020-05-31 18:55:56 +03:00
} )
2020-01-30 10:14:48 +03:00
}
2020-01-29 14:22:44 +03:00
// Set store value with key into current db instance's context
2020-03-09 15:37:01 +03:00
func ( db * DB ) Set ( key string , value interface { } ) * DB {
2020-01-29 14:22:44 +03:00
tx := db . getInstance ( )
tx . Statement . Settings . Store ( key , value )
return tx
}
// Get get value with key from current db instance's context
2020-03-09 15:37:01 +03:00
func ( db * DB ) Get ( key string ) ( interface { } , bool ) {
2020-06-05 16:23:20 +03:00
return db . Statement . Settings . Load ( key )
2020-01-29 14:22:44 +03:00
}
2020-05-31 18:55:56 +03:00
// InstanceSet stores value under key in the settings of the returned
// instance's Statement. The key is prefixed with that Statement's address, so
// the value is visible only through the same Statement via InstanceGet.
func (db *DB) InstanceSet(key string, value interface{}) *DB {
	instance := db.getInstance()
	scopedKey := fmt.Sprintf("%p", instance.Statement) + key
	instance.Statement.Settings.Store(scopedKey, value)
	return instance
}
// InstanceGet get value with key from current db instance's context
func ( db * DB ) InstanceGet ( key string ) ( interface { } , bool ) {
2020-06-05 16:23:20 +03:00
return db . Statement . Settings . Load ( fmt . Sprintf ( "%p" , db . Statement ) + key )
2020-05-31 18:55:56 +03:00
}
2020-02-02 03:35:01 +03:00
// Callback returns callback manager
2020-03-09 15:37:01 +03:00
func ( db * DB ) Callback ( ) * callbacks {
2020-02-02 03:35:01 +03:00
return db . callbacks
}
2020-03-09 08:10:48 +03:00
// AddError add error to db
2020-04-19 18:11:56 +03:00
func ( db * DB ) AddError ( err error ) error {
2023-03-10 11:51:27 +03:00
if err != nil {
2023-03-24 05:07:05 +03:00
if db . Config . TranslateError {
if errTranslator , ok := db . Dialector . ( ErrorTranslator ) ; ok {
err = errTranslator . Translate ( err )
}
2023-03-10 11:51:27 +03:00
}
2023-03-06 09:03:31 +03:00
2023-03-10 11:51:27 +03:00
if db . Error == nil {
db . Error = err
} else {
db . Error = fmt . Errorf ( "%v; %w" , db . Error , err )
}
2020-03-09 15:37:01 +03:00
}
2020-04-19 18:11:56 +03:00
return db . Error
2020-03-09 08:10:48 +03:00
}
2020-06-17 14:56:03 +03:00
// DB returns `*sql.DB`
func ( db * DB ) DB ( ) ( * sql . DB , error ) {
connPool := db . ConnPool
2023-08-19 16:33:31 +03:00
if db . Statement != nil && db . Statement . ConnPool != nil {
connPool = db . Statement . ConnPool
}
if tx , ok := connPool . ( * sql . Tx ) ; ok && tx != nil {
return ( * sql . DB ) ( reflect . ValueOf ( tx ) . Elem ( ) . FieldByName ( "db" ) . UnsafePointer ( ) ) , nil
2023-06-05 11:25:05 +03:00
}
2021-03-19 10:54:32 +03:00
if dbConnector , ok := connPool . ( GetDBConnector ) ; ok && dbConnector != nil {
2023-06-05 11:24:00 +03:00
if sqldb , err := dbConnector . GetDBConn ( ) ; sqldb != nil || err != nil {
return sqldb , err
}
2020-06-17 14:56:03 +03:00
}
2023-06-05 11:24:00 +03:00
if sqldb , ok := connPool . ( * sql . DB ) ; ok && sqldb != nil {
2020-06-17 14:56:03 +03:00
return sqldb , nil
}
2021-04-19 16:03:39 +03:00
return nil , ErrInvalidDB
2020-06-17 14:56:03 +03:00
}
2020-03-09 15:37:01 +03:00
func ( db * DB ) getInstance ( ) * DB {
2020-05-31 18:55:56 +03:00
if db . clone > 0 {
2021-04-09 06:07:14 +03:00
tx := & DB { Config : db . Config , Error : db . Error }
2020-05-31 18:55:56 +03:00
2020-06-05 16:23:20 +03:00
if db . clone == 1 {
// clone with new statement
2020-06-05 05:08:22 +03:00
tx . Statement = & Statement {
2023-08-04 05:35:59 +03:00
DB : tx ,
ConnPool : db . Statement . ConnPool ,
Context : db . Statement . Context ,
Clauses : map [ string ] clause . Clause { } ,
Vars : make ( [ ] interface { } , 0 , 8 ) ,
SkipHooks : db . Statement . SkipHooks ,
2020-06-05 05:08:22 +03:00
}
2024-06-17 06:59:06 +03:00
if db . Config . PropagateUnscoped {
tx . Statement . Unscoped = db . Statement . Unscoped
}
2020-06-05 16:23:20 +03:00
} else {
// with clone statement
tx . Statement = db . Statement . clone ( )
tx . Statement . DB = tx
2020-05-31 18:55:56 +03:00
}
return tx
2020-01-29 14:22:44 +03:00
}
return db
}
2020-05-30 11:47:16 +03:00
2022-02-09 12:39:01 +03:00
// Expr returns clause.Expr, which can be used to pass SQL expression as params
2020-05-30 11:47:16 +03:00
func Expr ( expr string , args ... interface { } ) clause . Expr {
return clause . Expr { SQL : expr , Vars : args }
}
2020-06-08 08:45:41 +03:00
2022-02-09 12:39:01 +03:00
// SetupJoinTable setup join table schema
2020-06-08 08:45:41 +03:00
func ( db * DB ) SetupJoinTable ( model interface { } , field string , joinTable interface { } ) error {
var (
tx = db . getInstance ( )
stmt = tx . Statement
modelSchema , joinSchema * schema . Schema
)
2021-10-08 06:05:50 +03:00
err := stmt . Parse ( model )
if err != nil {
2020-06-08 08:45:41 +03:00
return err
}
2021-10-08 06:05:50 +03:00
modelSchema = stmt . Schema
2020-06-08 08:45:41 +03:00
2021-10-08 06:05:50 +03:00
err = stmt . Parse ( joinTable )
if err != nil {
2020-06-08 08:45:41 +03:00
return err
}
2021-10-08 06:05:50 +03:00
joinSchema = stmt . Schema
2020-06-08 08:45:41 +03:00
2021-10-08 06:05:50 +03:00
relation , ok := modelSchema . Relationships . Relations [ field ]
isRelation := ok && relation . JoinTable != nil
if ! isRelation {
2022-08-12 16:46:18 +03:00
return fmt . Errorf ( "failed to find relation: %s" , field )
2021-10-08 06:05:50 +03:00
}
for _ , ref := range relation . References {
f := joinSchema . LookUpField ( ref . ForeignKey . DBName )
if f == nil {
return fmt . Errorf ( "missing field %s for join table" , ref . ForeignKey . DBName )
2020-06-08 08:45:41 +03:00
}
2021-10-08 06:05:50 +03:00
f . DataType = ref . ForeignKey . DataType
f . GORMDataType = ref . ForeignKey . GORMDataType
if f . Size == 0 {
f . Size = ref . ForeignKey . Size
2020-06-21 05:19:16 +03:00
}
2021-10-08 06:05:50 +03:00
ref . ForeignKey = f
}
2020-06-21 05:19:16 +03:00
2021-10-08 06:05:50 +03:00
for name , rel := range relation . JoinTable . Relationships . Relations {
if _ , ok := joinSchema . Relationships . Relations [ name ] ; ! ok {
rel . Schema = joinSchema
joinSchema . Relationships . Relations [ name ] = rel
}
2020-06-08 08:45:41 +03:00
}
2021-10-08 06:05:50 +03:00
relation . JoinTable = joinSchema
2020-06-08 08:45:41 +03:00
return nil
}
2020-06-23 04:38:51 +03:00
2022-02-09 12:39:01 +03:00
// Use use plugin
2021-01-10 05:15:48 +03:00
func ( db * DB ) Use ( plugin Plugin ) error {
2020-06-23 04:38:51 +03:00
name := plugin . Name ( )
2021-01-10 05:15:48 +03:00
if _ , ok := db . Plugins [ name ] ; ok {
2020-06-23 04:38:51 +03:00
return ErrRegistered
}
2021-01-10 05:15:48 +03:00
if err := plugin . Initialize ( db ) ; err != nil {
return err
}
db . Plugins [ name ] = plugin
return nil
2020-06-23 04:38:51 +03:00
}
2021-11-01 12:08:54 +03:00
// ToSQL for generate SQL string.
//
2022-12-25 06:37:23 +03:00
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5)
// .Order("name ASC")
// .First(&User{})
// })
2021-11-01 12:08:54 +03:00
func ( db * DB ) ToSQL ( queryFn func ( tx * DB ) * DB ) string {
2022-03-03 05:17:29 +03:00
tx := queryFn ( db . Session ( & Session { DryRun : true , SkipDefaultTransaction : true } ) )
2021-11-01 12:08:54 +03:00
stmt := tx . Statement
return db . Dialector . Explain ( stmt . SQL . String ( ) , stmt . Vars ... )
}