Refactor structure

This commit is contained in:
Jinzhu 2020-03-09 13:10:48 +08:00
parent 078ba75b9c
commit a145d7e019
19 changed files with 91 additions and 91 deletions

View File

@ -90,6 +90,9 @@ func (p *processor) Execute(db *DB) {
} }
if stmt := db.Statement; stmt != nil { if stmt := db.Statement; stmt != nil {
db.Error = stmt.Error
db.RowsAffected = stmt.RowsAffected
db.Logger.Trace(curTime, func() (string, int64) { db.Logger.Trace(curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error) }, db.Error)

View File

@ -50,7 +50,7 @@ func Create(db *gorm.DB) {
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if err == nil {
if db.Statement.Schema != nil { if db.Statement.Schema != nil {

View File

@ -57,7 +57,7 @@ func Delete(db *gorm.DB) {
db.Statement.Build("DELETE", "FROM", "WHERE") db.Statement.Build("DELETE", "FROM", "WHERE")
} }
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if err == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()

View File

@ -14,7 +14,7 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
} }
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
return return

View File

@ -5,7 +5,7 @@ import (
) )
func RawExec(db *gorm.DB) { func RawExec(db *gorm.DB) {
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
} else { } else {

View File

@ -14,8 +14,8 @@ func RowQuery(db *gorm.DB) {
} }
if _, ok := db.Get("rows"); ok { if _, ok := db.Get("rows"); ok {
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else { } else {
db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} }
} }

View File

@ -47,7 +47,7 @@ func Update(db *gorm.DB) {
db.Statement.AddClause(ConvertToAssignments(db.Statement)) db.Statement.AddClause(ConvertToAssignments(db.Statement))
db.Statement.Build("UPDATE", "SET", "WHERE") db.Statement.Build("UPDATE", "SET", "WHERE")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if err == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()

View File

@ -5,6 +5,7 @@ import (
"strings" "strings"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/utils"
) )
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
@ -64,7 +65,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
} }
} }
case string: case string:
fields := strings.FieldsFunc(v, isChar) fields := strings.FieldsFunc(v, utils.IsChar)
// normal field names // normal field names
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
@ -100,7 +101,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
} else { } else {
tx.Statement.Omits = columns tx.Statement.Omits = columns
} }

View File

@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db)
db.ConnPool, err = sql.Open("sqlserver", dialector.DSN)
db.DB, err = sql.Open("sqlserver", dialector.DSN)
return return
} }

View File

@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db)
db.DB, err = sql.Open("mysql", dialector.DSN) db.ConnPool, err = sql.Open("mysql", dialector.DSN)
return return
} }

View File

@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db)
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
db.DB, err = sql.Open("postgres", dialector.DSN)
return return
} }

View File

@ -23,7 +23,7 @@ func Open(dsn string) gorm.Dialector {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks // register callbacks
callbacks.RegisterDefaultCallbacks(db) callbacks.RegisterDefaultCallbacks(db)
db.DB, err = sql.Open("sqlite3", dialector.DSN) db.ConnPool, err = sql.Open("sqlite3", dialector.DSN)
return return
} }

View File

@ -2,8 +2,6 @@ package gorm
import ( import (
"errors" "errors"
"time"
"unicode"
) )
var ( var (
@ -20,19 +18,3 @@ var (
// ErrMissingWhereClause missing where clause // ErrMissingWhereClause missing where clause
ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") ErrMissingWhereClause = errors.New("missing WHERE clause while deleting")
) )
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}
func isChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
}

View File

@ -196,14 +196,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
// Begin begins a transaction // Begin begins a transaction
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if beginner, ok := tx.DB.(TxBeginner); ok { if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
var opt *sql.TxOptions var opt *sql.TxOptions
var err error var err error
if len(opts) > 0 { if len(opts) > 0 {
opt = opts[0] opt = opts[0]
} }
if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil { if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil {
tx.AddError(err) tx.AddError(err)
} }
} else { } else {
@ -214,7 +214,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
// Commit commit a transaction // Commit commit a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Commit()) db.AddError(comminter.Commit())
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
@ -224,7 +224,7 @@ func (db *DB) Commit() *DB {
// Rollback rollback a transaction // Rollback rollback a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Rollback()) db.AddError(comminter.Rollback())
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)

64
gorm.go
View File

@ -21,23 +21,25 @@ type Config struct {
Logger logger.Interface Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp // NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time NowFunc func() time.Time
}
type shared struct { // ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
Dialector
callbacks *callbacks callbacks *callbacks
cacheStore *sync.Map cacheStore *sync.Map
quoteChars [2]byte
} }
// DB GORM DB definition // DB GORM DB definition
type DB struct { type DB struct {
*Config *Config
Dialector Error error
Instance RowsAffected int64
ClauseBuilders map[string]clause.ClauseBuilder Statement *Statement
DB CommonDB clone bool
clone bool
*shared
} }
// Session session config when create session with Session() method // Session session config when create session with Session() method
@ -65,14 +67,17 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
config.NowFunc = func() time.Time { return time.Now().Local() } config.NowFunc = func() time.Time { return time.Now().Local() }
} }
if dialector != nil {
config.Dialector = dialector
}
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
db = &DB{ db = &DB{
Config: config, Config: config,
Dialector: dialector, clone: true,
ClauseBuilders: map[string]clause.ClauseBuilder{},
clone: true,
shared: &shared{
cacheStore: &sync.Map{},
},
} }
db.callbacks = initializeCallbacks(db) db.callbacks = initializeCallbacks(db)
@ -91,7 +96,7 @@ func (db *DB) Session(config *Session) *DB {
) )
if config.Context != nil { if config.Context != nil {
tx.Context = config.Context tx.Statement.Context = config.Context
} }
if config.Logger != nil { if config.Logger != nil {
@ -142,23 +147,26 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...) return db.Migrator().AutoMigrate(dst...)
} }
// AddError add error to db
func (db *DB) AddError(err error) {
db.Statement.AddError(err)
}
func (db *DB) getInstance() *DB { func (db *DB) getInstance() *DB {
if db.clone { if db.clone {
ctx := db.Instance.Context ctx := context.Background()
if ctx == nil { if db.Statement != nil {
ctx = context.Background() ctx = db.Statement.Context
} }
return &DB{ return &DB{
Instance: Instance{ Config: db.Config,
Context: ctx, Statement: &Statement{
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, DB: db,
ConnPool: db.ConnPool,
Clauses: map[string]clause.Clause{},
Context: ctx,
}, },
Config: db.Config,
Dialector: db.Dialector,
ClauseBuilders: db.ClauseBuilders,
DB: db.DB,
shared: db.shared,
} }
} }

View File

@ -18,8 +18,8 @@ type Dialector interface {
Explain(sql string, vars ...interface{}) string Explain(sql string, vars ...interface{}) string
} }
// CommonDB common db interface // ConnPool db conns pool interface
type CommonDB interface { type ConnPool interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)

15
model.go Normal file
View File

@ -0,0 +1,15 @@
package gorm
import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embeded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `gorm:"index"`
}

View File

@ -14,30 +14,6 @@ import (
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
) )
// Instance db instance
type Instance struct {
Error error
RowsAffected int64
Context context.Context
Statement *Statement
}
func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
if len(clauses) > 0 {
instance.Statement.Build(clauses...)
}
return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
}
// AddError add error to instance
func (inst *Instance) AddError(err error) {
if inst.Error == nil {
inst.Error = err
} else if err != nil {
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
}
}
// Statement statement // Statement statement
type Statement struct { type Statement struct {
Table string Table string
@ -48,8 +24,12 @@ type Statement struct {
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns
Settings sync.Map Settings sync.Map
ConnPool ConnPool
DB *DB DB *DB
Schema *schema.Schema Schema *schema.Schema
Context context.Context
Error error
RowsAffected int64
RaiseErrorOnNotFound bool RaiseErrorOnNotFound bool
// SQL Builder // SQL Builder
@ -246,6 +226,14 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
return conditions return conditions
} }
func (stmt *Statement) AddError(err error) {
if stmt.Error == nil {
stmt.Error = err
} else if err != nil {
stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err)
}
}
// Build build sql with clauses names // Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) { func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool var firstClauseWritten bool

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"runtime" "runtime"
"unicode"
) )
var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`)
@ -18,3 +19,7 @@ func FileWithLineNum() string {
} }
return "" return ""
} }
func IsChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
}