gorm/gorm.go

442 lines
9.8 KiB
Go
Raw Normal View History

2020-01-28 18:01:35 +03:00
package gorm
import (
2020-01-29 14:22:44 +03:00
"context"
2020-06-05 05:08:22 +03:00
"database/sql"
2020-03-09 15:37:01 +03:00
"fmt"
2021-04-09 06:43:24 +03:00
"sort"
2020-02-02 09:40:44 +03:00
"sync"
2020-01-28 18:01:35 +03:00
"time"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
2020-01-28 18:01:35 +03:00
)
// preparedStmtDBKey is the Config.cacheStore key under which Open keeps the
// shared *PreparedStmtDB that later sessions reuse.
const preparedStmtDBKey = "preparedStmt"
2020-01-28 18:01:35 +03:00
// Config GORM config
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
2020-02-02 09:40:44 +03:00
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
2020-01-31 09:17:02 +03:00
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
2020-01-28 18:01:35 +03:00
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
2020-06-01 16:26:23 +03:00
// DryRun generate sql without execute
DryRun bool
2020-06-05 05:08:22 +03:00
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
2020-12-16 14:33:35 +03:00
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
2020-08-23 15:08:23 +03:00
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
2020-12-02 09:59:50 +03:00
// CreateBatchSize default create batch size
CreateBatchSize int
2020-06-05 05:08:22 +03:00
2020-03-09 08:10:48 +03:00
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
Dialector
2020-06-23 04:38:51 +03:00
// Plugins registered plugins
Plugins map[string]Plugin
2020-03-09 08:10:48 +03:00
2020-05-31 06:34:59 +03:00
callbacks *callbacks
cacheStore *sync.Map
2020-02-05 06:14:58 +03:00
}
2021-03-04 12:14:08 +03:00
func (c *Config) Apply(config *Config) error {
2021-03-04 15:37:39 +03:00
if config != c {
*config = *c
}
2021-03-04 12:14:08 +03:00
return nil
}
// AfterInitialize implements Option for *Config. It calls Initialize on
// every plugin in c.Plugins and stops at the first error. Plugins are
// visited in Go map order, which is unspecified. A nil db (Open passes one
// when it fails before building the *DB) is ignored.
func (c *Config) AfterInitialize(db *DB) error {
	if db == nil {
		return nil
	}
	for _, p := range c.Plugins {
		if err := p.Initialize(db); err != nil {
			return err
		}
	}
	return nil
}
// Option customizes Open. Open calls Apply while it assembles the *Config,
// and defers AfterInitialize until Open returns. AfterInitialize may
// therefore receive a nil *DB when Open fails before the DB is created.
type Option interface {
	Apply(config *Config) error
	AfterInitialize(db *DB) error
}
2020-01-28 18:01:35 +03:00
// DB GORM DB definition
type DB struct {
*Config
2020-03-09 08:10:48 +03:00
Error error
RowsAffected int64
Statement *Statement
2020-05-31 18:55:56 +03:00
clone int
2020-01-30 10:14:48 +03:00
}
2020-02-02 09:40:44 +03:00
// Session session config when create session with Session() method
2020-01-30 10:14:48 +03:00
type Session struct {
2020-12-16 14:33:35 +03:00
DryRun bool
PrepareStmt bool
NewDB bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
2020-01-30 10:14:48 +03:00
}
// Open initialize db session based on dialector
2021-03-04 12:14:08 +03:00
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
2021-04-09 06:43:24 +03:00
sort.Slice(opts, func(i, j int) bool {
_, isConfig := opts[i].(*Config)
_, isConfig2 := opts[j].(*Config)
return isConfig && !isConfig2
})
2021-03-04 12:14:08 +03:00
for _, opt := range opts {
if opt != nil {
if err := opt.Apply(config); err != nil {
return nil, err
}
2021-03-26 09:20:42 +03:00
defer func(opt Option) {
2021-03-29 13:36:01 +03:00
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
2021-03-26 09:20:42 +03:00
}(opt)
2021-03-04 12:14:08 +03:00
}
2020-02-02 03:35:01 +03:00
}
2021-03-11 05:29:52 +03:00
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
if err = d.Apply(config); err != nil {
return
}
}
2020-01-31 09:17:02 +03:00
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{}
}
2020-02-02 03:35:01 +03:00
if config.Logger == nil {
config.Logger = logger.Default
}
if config.NowFunc == nil {
config.NowFunc = func() time.Time { return time.Now().Local() }
2020-02-02 03:35:01 +03:00
}
2020-03-09 08:10:48 +03:00
if dialector != nil {
config.Dialector = dialector
}
2020-06-23 05:36:45 +03:00
if config.Plugins == nil {
config.Plugins = map[string]Plugin{}
}
2020-03-09 08:10:48 +03:00
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
2020-05-31 18:55:56 +03:00
db = &DB{Config: config, clone: 1}
2020-02-02 03:35:01 +03:00
2020-03-09 15:37:01 +03:00
db.callbacks = initializeCallbacks(db)
2020-02-02 14:32:27 +03:00
2020-05-29 17:34:35 +03:00
if config.ClauseBuilders == nil {
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
}
if config.Dialector != nil {
err = config.Dialector.Initialize(db)
2020-02-02 03:35:01 +03:00
}
2020-06-02 06:09:17 +03:00
2020-07-28 09:26:09 +03:00
preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: map[string]Stmt{},
Mux: &sync.RWMutex{},
2020-07-28 09:26:09 +03:00
PreparedSQL: make([]string, 0, 100),
}
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
2020-07-28 09:26:09 +03:00
2020-06-05 05:08:22 +03:00
if config.PrepareStmt {
2020-07-28 09:26:09 +03:00
db.ConnPool = preparedStmt
2020-06-05 05:08:22 +03:00
}
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
2020-06-05 05:08:22 +03:00
}
if err == nil && !config.DisableAutomaticPing {
2020-06-02 06:09:17 +03:00
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
}
}
if err != nil {
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
}
2020-02-02 03:35:01 +03:00
return
2020-01-30 10:14:48 +03:00
}
// Session create new db session
2020-03-09 15:37:01 +03:00
func (db *DB) Session(config *Session) *DB {
2020-01-30 10:14:48 +03:00
var (
2020-05-31 18:55:56 +03:00
txConfig = *db.Config
tx = &DB{
Config: &txConfig,
Statement: db.Statement,
2021-01-07 06:45:40 +03:00
Error: db.Error,
2020-05-31 18:55:56 +03:00
clone: 1,
}
2020-01-30 10:14:48 +03:00
)
2020-12-02 09:59:50 +03:00
if config.CreateBatchSize > 0 {
tx.Config.CreateBatchSize = config.CreateBatchSize
}
if config.SkipDefaultTransaction {
tx.Config.SkipDefaultTransaction = true
}
2020-08-23 15:08:23 +03:00
if config.AllowGlobalUpdate {
txConfig.AllowGlobalUpdate = true
}
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
2020-11-17 10:19:58 +03:00
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
2020-11-17 10:19:58 +03:00
}
if config.Context != nil {
2020-05-31 18:55:56 +03:00
tx.Statement.Context = config.Context
}
2020-06-05 05:08:22 +03:00
if config.PrepareStmt {
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
2020-07-28 09:26:09 +03:00
preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
2020-08-03 16:48:36 +03:00
Mux: preparedStmt.Mux,
2020-07-28 09:26:09 +03:00
Stmts: preparedStmt.Stmts,
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
2020-06-05 05:08:22 +03:00
}
}
2020-11-17 10:19:58 +03:00
if config.SkipHooks {
2020-11-17 12:49:43 +03:00
tx.Statement.SkipHooks = true
2020-11-17 10:19:58 +03:00
}
2020-12-16 14:33:35 +03:00
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
}
if !config.NewDB {
tx.clone = 2
2020-01-30 10:14:48 +03:00
}
2020-06-01 16:26:23 +03:00
if config.DryRun {
tx.Config.DryRun = true
}
if config.QueryFields {
tx.Config.QueryFields = true
}
2020-01-30 10:14:48 +03:00
if config.Logger != nil {
2020-05-31 18:55:56 +03:00
tx.Config.Logger = config.Logger
2020-01-30 10:14:48 +03:00
}
if config.NowFunc != nil {
2020-05-31 18:55:56 +03:00
tx.Config.NowFunc = config.NowFunc
2020-01-30 10:14:48 +03:00
}
2020-05-31 18:55:56 +03:00
return tx
2020-01-29 14:22:44 +03:00
}
// WithContext change current instance db's context to ctx
2020-03-09 15:37:01 +03:00
func (db *DB) WithContext(ctx context.Context) *DB {
return db.Session(&Session{Context: ctx})
2020-01-30 10:14:48 +03:00
}
// Debug start debug mode
2020-03-09 15:37:01 +03:00
func (db *DB) Debug() (tx *DB) {
2020-05-31 18:55:56 +03:00
return db.Session(&Session{
Logger: db.Logger.LogMode(logger.Info),
2020-05-31 18:55:56 +03:00
})
2020-01-30 10:14:48 +03:00
}
2020-01-29 14:22:44 +03:00
// Set store value with key into current db instance's context
2020-03-09 15:37:01 +03:00
func (db *DB) Set(key string, value interface{}) *DB {
2020-01-29 14:22:44 +03:00
tx := db.getInstance()
tx.Statement.Settings.Store(key, value)
return tx
}
// Get get value with key from current db instance's context
2020-03-09 15:37:01 +03:00
func (db *DB) Get(key string) (interface{}, bool) {
return db.Statement.Settings.Load(key)
2020-01-29 14:22:44 +03:00
}
2020-05-31 18:55:56 +03:00
// InstanceSet is like Set, but the key is prefixed with the address of the
// instance's Statement. The value is therefore only visible through that
// particular Statement.
func (db *DB) InstanceSet(key string, value interface{}) *DB {
	tx := db.getInstance()
	scopedKey := fmt.Sprintf("%p", tx.Statement) + key
	tx.Statement.Settings.Store(scopedKey, value)
	return tx
}
// InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
2020-05-31 18:55:56 +03:00
}
2020-02-02 03:35:01 +03:00
// Callback returns callback manager
2020-03-09 15:37:01 +03:00
func (db *DB) Callback() *callbacks {
2020-02-02 03:35:01 +03:00
return db.callbacks
}
2020-03-09 08:10:48 +03:00
// AddError add error to db
2020-04-19 18:11:56 +03:00
func (db *DB) AddError(err error) error {
2020-03-09 15:37:01 +03:00
if db.Error == nil {
db.Error = err
} else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
2020-04-19 18:11:56 +03:00
return db.Error
2020-03-09 08:10:48 +03:00
}
2020-06-17 14:56:03 +03:00
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
2021-03-19 10:54:32 +03:00
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
2020-06-17 14:56:03 +03:00
}
if sqldb, ok := connPool.(*sql.DB); ok {
return sqldb, nil
}
2021-04-19 16:03:39 +03:00
return nil, ErrInvalidDB
2020-06-17 14:56:03 +03:00
}
2020-03-09 15:37:01 +03:00
func (db *DB) getInstance() *DB {
2020-05-31 18:55:56 +03:00
if db.clone > 0 {
2021-04-09 06:07:14 +03:00
tx := &DB{Config: db.Config, Error: db.Error}
2020-05-31 18:55:56 +03:00
if db.clone == 1 {
// clone with new statement
2020-06-05 05:08:22 +03:00
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
2020-11-10 13:38:24 +03:00
Vars: make([]interface{}, 0, 8),
2020-06-05 05:08:22 +03:00
}
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
2020-05-31 18:55:56 +03:00
}
return tx
2020-01-29 14:22:44 +03:00
}
return db
}
2020-05-30 11:47:16 +03:00
// Expr builds a raw SQL clause.Expr with expr as its SQL and args as its
// Vars.
func Expr(expr string, args ...interface{}) clause.Expr {
	e := clause.Expr{SQL: expr}
	e.Vars = args
	return e
}
2020-06-08 08:45:41 +03:00
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
2020-07-20 13:59:28 +03:00
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
2020-06-08 08:45:41 +03:00
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
for name, rel := range relation.JoinTable.Relationships.Relations {
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
rel.Schema = joinSchema
joinSchema.Relationships.Relations[name] = rel
}
}
2020-06-08 08:45:41 +03:00
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}
2020-06-23 04:38:51 +03:00
func (db *DB) Use(plugin Plugin) error {
2020-06-23 04:38:51 +03:00
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
2020-06-23 04:38:51 +03:00
return ErrRegistered
}
if err := plugin.Initialize(db); err != nil {
return err
}
db.Plugins[name] = plugin
return nil
2020-06-23 04:38:51 +03:00
}