Refactor structure

This commit is contained in:
Jinzhu 2020-03-09 13:10:48 +08:00
parent 078ba75b9c
commit a145d7e019
19 changed files with 91 additions and 91 deletions

View File

@ -90,6 +90,9 @@ func (p *processor) Execute(db *DB) {
}
if stmt := db.Statement; stmt != nil {
db.Error = stmt.Error
db.RowsAffected = stmt.RowsAffected
db.Logger.Trace(curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)

View File

@ -50,7 +50,7 @@ func Create(db *gorm.DB) {
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
if db.Statement.Schema != nil {

View File

@ -57,7 +57,7 @@ func Delete(db *gorm.DB) {
db.Statement.Build("DELETE", "FROM", "WHERE")
}
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()

View File

@ -14,7 +14,7 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return

View File

@ -5,7 +5,7 @@ import (
)
func RawExec(db *gorm.DB) {
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
} else {

View File

@ -14,8 +14,8 @@ func RowQuery(db *gorm.DB) {
}
if _, ok := db.Get("rows"); ok {
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
}

View File

@ -47,7 +47,7 @@ func Update(db *gorm.DB) {
db.Statement.AddClause(ConvertToAssignments(db.Statement))
db.Statement.Build("UPDATE", "SET", "WHERE")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()

View File

@ -5,6 +5,7 @@ import (
"strings"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/utils"
)
// Model specify the model you would like to run db operations
@ -64,7 +65,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
}
}
case string:
fields := strings.FieldsFunc(v, isChar)
fields := strings.FieldsFunc(v, utils.IsChar)
// normal field names
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
@ -100,7 +101,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar)
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
} else {
tx.Statement.Omits = columns
}

View File

@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)
db.DB, err = sql.Open("sqlserver", dialector.DSN)
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
return
}

View File

@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)
db.DB, err = sql.Open("mysql", dialector.DSN)
db.ConnPool, err = sql.Open("mysql", dialector.DSN)
return
}

View File

@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)
db.DB, err = sql.Open("postgres", dialector.DSN)
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
return
}

View File

@ -23,7 +23,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)
db.DB, err = sql.Open("sqlite3", dialector.DSN)
db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
return
}

View File

@ -2,8 +2,6 @@ package gorm
import (
"errors"
"time"
"unicode"
)
var (
@ -20,19 +18,3 @@ var (
// ErrMissingWhereClause missing where clause
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
)
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}
func isChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
}

View File

@ -196,14 +196,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
// Begin begins a transaction
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
tx = db.getInstance()
if beginner, ok := tx.DB.(TxBeginner); ok {
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
var opt *sql.TxOptions
var err error
if len(opts) > 0 {
opt = opts[0]
}
if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil {
if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil {
tx.AddError(err)
}
} else {
@ -214,7 +214,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
// Commit commit a transaction
func (db *DB) Commit() *DB {
if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Commit())
} else {
db.AddError(ErrInvalidTransaction)
@ -224,7 +224,7 @@ func (db *DB) Commit() *DB {
// Rollback rollback a transaction
func (db *DB) Rollback() *DB {
if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Rollback())
} else {
db.AddError(ErrInvalidTransaction)

64
gorm.go
View File

@ -21,23 +21,25 @@ type Config struct {
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
}
type shared struct {
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
Dialector
callbacks *callbacks
cacheStore *sync.Map
quoteChars [2]byte
}
// DB GORM DB definition
type DB struct {
*Config
Dialector
Instance
ClauseBuilders map[string]clause.ClauseBuilder
DB CommonDB
clone bool
*shared
Error error
RowsAffected int64
Statement *Statement
clone bool
}
// Session session config when create session with Session() method
@ -65,14 +67,17 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
config.NowFunc = func() time.Time { return time.Now().Local() }
}
if dialector != nil {
config.Dialector = dialector
}
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
db = &DB{
Config: config,
Dialector: dialector,
ClauseBuilders: map[string]clause.ClauseBuilder{},
clone: true,
shared: &shared{
cacheStore: &sync.Map{},
},
Config: config,
clone: true,
}
db.callbacks = initializeCallbacks(db)
@ -91,7 +96,7 @@ func (db *DB) Session(config *Session) *DB {
)
if config.Context != nil {
tx.Context = config.Context
tx.Statement.Context = config.Context
}
if config.Logger != nil {
@ -142,23 +147,26 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
// AddError add error to db
func (db *DB) AddError(err error) {
db.Statement.AddError(err)
}
func (db *DB) getInstance() *DB {
if db.clone {
ctx := db.Instance.Context
if ctx == nil {
ctx = context.Background()
ctx := context.Background()
if db.Statement != nil {
ctx = db.Statement.Context
}
return &DB{
Instance: Instance{
Context: ctx,
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
Config: db.Config,
Statement: &Statement{
DB: db,
ConnPool: db.ConnPool,
Clauses: map[string]clause.Clause{},
Context: ctx,
},
Config: db.Config,
Dialector: db.Dialector,
ClauseBuilders: db.ClauseBuilders,
DB: db.DB,
shared: db.shared,
}
}

View File

@ -18,8 +18,8 @@ type Dialector interface {
Explain(sql string, vars ...interface{}) string
}
// CommonDB common db interface
type CommonDB interface {
// ConnPool db conns pool interface
type ConnPool interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)

15
model.go Normal file
View File

@ -0,0 +1,15 @@
package gorm
import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}

View File

@ -14,30 +14,6 @@ import (
"github.com/jinzhu/gorm/schema"
)
// Instance db instance
type Instance struct {
Error error
RowsAffected int64
Context context.Context
Statement *Statement
}
func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
if len(clauses) > 0 {
instance.Statement.Build(clauses...)
}
return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
}
// AddError add error to instance
func (inst *Instance) AddError(err error) {
if inst.Error == nil {
inst.Error = err
} else if err != nil {
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
}
}
// Statement statement
type Statement struct {
Table string
@ -48,8 +24,12 @@ type Statement struct {
Selects []string // selected columns
Omits []string // omit columns
Settings sync.Map
ConnPool ConnPool
DB *DB
Schema *schema.Schema
Context context.Context
Error error
RowsAffected int64
RaiseErrorOnNotFound bool
// SQL Builder
@ -246,6 +226,14 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
return conditions
}
func (stmt *Statement) AddError(err error) {
if stmt.Error == nil {
stmt.Error = err
} else if err != nil {
stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err)
}
}
// Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool

View File

@ -4,6 +4,7 @@ import (
"fmt"
"regexp"
"runtime"
"unicode"
)
var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`)
@ -18,3 +19,7 @@ func FileWithLineNum() string {
}
return ""
}
func IsChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
}