forked from mirror/gorm
Refactor structure
This commit is contained in:
parent
078ba75b9c
commit
a145d7e019
|
@ -90,6 +90,9 @@ func (p *processor) Execute(db *DB) {
|
|||
}
|
||||
|
||||
if stmt := db.Statement; stmt != nil {
|
||||
db.Error = stmt.Error
|
||||
db.RowsAffected = stmt.RowsAffected
|
||||
|
||||
db.Logger.Trace(curTime, func() (string, int64) {
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
|
|
|
@ -50,7 +50,7 @@ func Create(db *gorm.DB) {
|
|||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil {
|
||||
|
|
|
@ -57,7 +57,7 @@ func Delete(db *gorm.DB) {
|
|||
db.Statement.Build("DELETE", "FROM", "WHERE")
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
|
|
@ -14,7 +14,7 @@ func Query(db *gorm.DB) {
|
|||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
)
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
} else {
|
||||
|
|
|
@ -14,8 +14,8 @@ func RowQuery(db *gorm.DB) {
|
|||
}
|
||||
|
||||
if _, ok := db.Get("rows"); ok {
|
||||
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ func Update(db *gorm.DB) {
|
|||
db.Statement.AddClause(ConvertToAssignments(db.Statement))
|
||||
db.Statement.Build("UPDATE", "SET", "WHERE")
|
||||
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/utils"
|
||||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
|
@ -64,7 +65,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
}
|
||||
case string:
|
||||
fields := strings.FieldsFunc(v, isChar)
|
||||
fields := strings.FieldsFunc(v, utils.IsChar)
|
||||
|
||||
// normal field names
|
||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||
|
@ -100,7 +101,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
|||
tx = db.getInstance()
|
||||
|
||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar)
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
|
||||
} else {
|
||||
tx.Statement.Omits = columns
|
||||
}
|
||||
|
|
|
@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector {
|
|||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
|
||||
db.DB, err = sql.Open("sqlserver", dialector.DSN)
|
||||
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector {
|
|||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
db.DB, err = sql.Open("mysql", dialector.DSN)
|
||||
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector {
|
|||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
|
||||
db.DB, err = sql.Open("postgres", dialector.DSN)
|
||||
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ func Open(dsn string) gorm.Dialector {
|
|||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
// register callbacks
|
||||
callbacks.RegisterDefaultCallbacks(db)
|
||||
db.DB, err = sql.Open("sqlite3", dialector.DSN)
|
||||
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -2,8 +2,6 @@ package gorm
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -20,19 +18,3 @@ var (
|
|||
// ErrMissingWhereClause missing where clause
|
||||
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
|
||||
)
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embeded into your model or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `gorm:"index"`
|
||||
}
|
||||
|
||||
func isChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
|
||||
}
|
|
@ -196,14 +196,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||
// Begin begins a transaction
|
||||
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if beginner, ok := tx.DB.(TxBeginner); ok {
|
||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
||||
var opt *sql.TxOptions
|
||||
var err error
|
||||
if len(opts) > 0 {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil {
|
||||
if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil {
|
||||
tx.AddError(err)
|
||||
}
|
||||
} else {
|
||||
|
@ -214,7 +214,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
|
|||
|
||||
// Commit commit a transaction
|
||||
func (db *DB) Commit() *DB {
|
||||
if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil {
|
||||
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
|
||||
db.AddError(comminter.Commit())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
|
@ -224,7 +224,7 @@ func (db *DB) Commit() *DB {
|
|||
|
||||
// Rollback rollback a transaction
|
||||
func (db *DB) Rollback() *DB {
|
||||
if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil {
|
||||
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
|
||||
db.AddError(comminter.Rollback())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
|
|
58
gorm.go
58
gorm.go
|
@ -21,23 +21,25 @@ type Config struct {
|
|||
Logger logger.Interface
|
||||
// NowFunc the function to be used when creating a new timestamp
|
||||
NowFunc func() time.Time
|
||||
}
|
||||
|
||||
type shared struct {
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
// ConnPool db conn pool
|
||||
ConnPool ConnPool
|
||||
// Dialector database dialector
|
||||
Dialector
|
||||
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
quoteChars [2]byte
|
||||
}
|
||||
|
||||
// DB GORM DB definition
|
||||
type DB struct {
|
||||
*Config
|
||||
Dialector
|
||||
Instance
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
DB CommonDB
|
||||
Error error
|
||||
RowsAffected int64
|
||||
Statement *Statement
|
||||
clone bool
|
||||
*shared
|
||||
}
|
||||
|
||||
// Session session config when create session with Session() method
|
||||
|
@ -65,14 +67,17 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
config.NowFunc = func() time.Time { return time.Now().Local() }
|
||||
}
|
||||
|
||||
if dialector != nil {
|
||||
config.Dialector = dialector
|
||||
}
|
||||
|
||||
if config.cacheStore == nil {
|
||||
config.cacheStore = &sync.Map{}
|
||||
}
|
||||
|
||||
db = &DB{
|
||||
Config: config,
|
||||
Dialector: dialector,
|
||||
ClauseBuilders: map[string]clause.ClauseBuilder{},
|
||||
clone: true,
|
||||
shared: &shared{
|
||||
cacheStore: &sync.Map{},
|
||||
},
|
||||
}
|
||||
|
||||
db.callbacks = initializeCallbacks(db)
|
||||
|
@ -91,7 +96,7 @@ func (db *DB) Session(config *Session) *DB {
|
|||
)
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Context = config.Context
|
||||
tx.Statement.Context = config.Context
|
||||
}
|
||||
|
||||
if config.Logger != nil {
|
||||
|
@ -142,23 +147,26 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
|
|||
return db.Migrator().AutoMigrate(dst...)
|
||||
}
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) {
|
||||
db.Statement.AddError(err)
|
||||
}
|
||||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone {
|
||||
ctx := db.Instance.Context
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
ctx := context.Background()
|
||||
if db.Statement != nil {
|
||||
ctx = db.Statement.Context
|
||||
}
|
||||
|
||||
return &DB{
|
||||
Instance: Instance{
|
||||
Context: ctx,
|
||||
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
|
||||
},
|
||||
Config: db.Config,
|
||||
Dialector: db.Dialector,
|
||||
ClauseBuilders: db.ClauseBuilders,
|
||||
DB: db.DB,
|
||||
shared: db.shared,
|
||||
Statement: &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Context: ctx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@ type Dialector interface {
|
|||
Explain(sql string, vars ...interface{}) string
|
||||
}
|
||||
|
||||
// CommonDB common db interface
|
||||
type CommonDB interface {
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embeded into your model or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time `gorm:"index"`
|
||||
}
|
36
statement.go
36
statement.go
|
@ -14,30 +14,6 @@ import (
|
|||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
// Instance db instance
|
||||
type Instance struct {
|
||||
Error error
|
||||
RowsAffected int64
|
||||
Context context.Context
|
||||
Statement *Statement
|
||||
}
|
||||
|
||||
func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
|
||||
if len(clauses) > 0 {
|
||||
instance.Statement.Build(clauses...)
|
||||
}
|
||||
return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
|
||||
}
|
||||
|
||||
// AddError add error to instance
|
||||
func (inst *Instance) AddError(err error) {
|
||||
if inst.Error == nil {
|
||||
inst.Error = err
|
||||
} else if err != nil {
|
||||
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Statement statement
|
||||
type Statement struct {
|
||||
Table string
|
||||
|
@ -48,8 +24,12 @@ type Statement struct {
|
|||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
Settings sync.Map
|
||||
ConnPool ConnPool
|
||||
DB *DB
|
||||
Schema *schema.Schema
|
||||
Context context.Context
|
||||
Error error
|
||||
RowsAffected int64
|
||||
RaiseErrorOnNotFound bool
|
||||
|
||||
// SQL Builder
|
||||
|
@ -246,6 +226,14 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
|
|||
return conditions
|
||||
}
|
||||
|
||||
func (stmt *Statement) AddError(err error) {
|
||||
if stmt.Error == nil {
|
||||
stmt.Error = err
|
||||
} else if err != nil {
|
||||
stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
func (stmt *Statement) Build(clauses ...string) {
|
||||
var firstClauseWritten bool
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`)
|
||||
|
@ -18,3 +19,7 @@ func FileWithLineNum() string {
|
|||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func IsChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue