From a145d7e01946a4f0777b0c1764bd8e24d3425789 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 9 Mar 2020 13:10:48 +0800 Subject: [PATCH] Refactor structure --- callbacks.go | 3 ++ callbacks/create.go | 2 +- callbacks/delete.go | 2 +- callbacks/query.go | 2 +- callbacks/raw.go | 2 +- callbacks/row.go | 4 +-- callbacks/update.go | 2 +- chainable_api.go | 5 +-- dialects/mssql/mssql.go | 3 +- dialects/mysql/mysql.go | 2 +- dialects/postgres/postgres.go | 3 +- dialects/sqlite/sqlite.go | 2 +- helpers.go => errors.go | 18 ---------- finisher_api.go | 8 ++--- gorm.go | 64 ++++++++++++++++++++--------------- interfaces.go | 4 +-- model.go | 15 ++++++++ statement.go | 36 +++++++------------- utils/utils.go | 5 +++ 19 files changed, 91 insertions(+), 91 deletions(-) rename helpers.go => errors.go (60%) create mode 100644 model.go diff --git a/callbacks.go b/callbacks.go index db8261c4..d1164019 100644 --- a/callbacks.go +++ b/callbacks.go @@ -90,6 +90,9 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { + db.Error = stmt.Error + db.RowsAffected = stmt.RowsAffected + db.Logger.Trace(curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) diff --git a/callbacks/create.go b/callbacks/create.go index 2e1b3381..42dcda27 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -50,7 +50,7 @@ func Create(db *gorm.DB) { db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { if db.Statement.Schema != nil { diff --git a/callbacks/delete.go b/callbacks/delete.go index 05d00d0a..50b2880a 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -57,7 +57,7 @@ func Delete(db *gorm.DB) { db.Statement.Build("DELETE", "FROM", "WHERE") } - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { db.RowsAffected, _ = result.RowsAffected() diff --git a/callbacks/query.go b/callbacks/query.go index 26c0e0ad..00820bfd 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -14,7 +14,7 @@ func Query(db *gorm.DB) { db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) return diff --git a/callbacks/raw.go b/callbacks/raw.go index e8cad25d..ce125e61 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -5,7 +5,7 @@ import ( ) func RawExec(db *gorm.DB) { - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) } else { diff --git a/callbacks/row.go b/callbacks/row.go index f7d6752d..b84cf694 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -14,8 +14,8 @@ func RowQuery(db *gorm.DB) { } if _, ok := db.Get("rows"); ok { - db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { - db.Statement.Dest = db.DB.QueryRowContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } } diff --git a/callbacks/update.go b/callbacks/update.go index ca31bf18..eab9f929 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -47,7 +47,7 @@ func Update(db *gorm.DB) { db.Statement.AddClause(ConvertToAssignments(db.Statement)) db.Statement.Build("UPDATE", "SET", "WHERE") - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { db.RowsAffected, _ = result.RowsAffected() diff --git a/chainable_api.go b/chainable_api.go index 6f80d4be..98c1898e 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/utils" ) // Model specify the model you would like to run db operations @@ -64,7 +65,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } } case string: - fields := strings.FieldsFunc(v, isChar) + fields := strings.FieldsFunc(v, utils.IsChar) // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { @@ -100,7 +101,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) } else { tx.Statement.Omits = columns } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 91574787..7e51de75 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - - db.DB, err = sql.Open("sqlserver", dialector.DSN) + db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) return } diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 9d16507e..55b5a53f 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("mysql", dialector.DSN) + db.ConnPool, err = sql.Open("mysql", dialector.DSN) return } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 0005f7ed..e90fa4ae 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -26,8 +26,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - - db.DB, err = sql.Open("postgres", dialector.DSN) + db.ConnPool, err = sql.Open("postgres", dialector.DSN) return } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 91762343..8e3cc058 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -23,7 +23,7 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("sqlite3", dialector.DSN) + db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) return } diff --git a/helpers.go b/errors.go similarity index 60% rename from helpers.go rename to errors.go index 241d3fbd..32f55e01 100644 --- a/helpers.go +++ b/errors.go @@ -2,8 +2,6 @@ package gorm import ( "errors" - "time" - "unicode" ) var ( @@ -20,19 +18,3 @@ var ( // ErrMissingWhereClause missing where clause ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") ) - -// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt -// It may be embeded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primarykey"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `gorm:"index"` -} - -func isChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) -} diff --git a/finisher_api.go b/finisher_api.go index 51d9b409..62c1af30 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -196,14 +196,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er // Begin begins a transaction func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { tx = db.getInstance() - if beginner, ok := tx.DB.(TxBeginner); ok { + if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { var opt *sql.TxOptions var err error if len(opts) > 0 { opt = opts[0] } - if tx.DB, err = beginner.BeginTx(db.Context, opt); err != nil { + if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil { tx.AddError(err) } } else { @@ -214,7 +214,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { // Commit commit a transaction func (db *DB) Commit() *DB { - if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Commit()) } else { db.AddError(ErrInvalidTransaction) @@ -224,7 +224,7 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { - if comminter, ok := db.DB.(TxCommiter); ok && comminter != nil { + if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { db.AddError(comminter.Rollback()) } else { db.AddError(ErrInvalidTransaction) diff --git a/gorm.go b/gorm.go index eac95868..b238d572 100644 --- a/gorm.go +++ b/gorm.go @@ -21,23 +21,25 @@ type Config struct { Logger logger.Interface // NowFunc the function to be used when creating a new timestamp NowFunc func() time.Time -} -type shared struct { + // ClauseBuilders clause builder + ClauseBuilders map[string]clause.ClauseBuilder + // ConnPool db conn pool + ConnPool ConnPool + // Dialector database dialector + Dialector + callbacks *callbacks cacheStore *sync.Map - quoteChars [2]byte } // DB GORM DB definition type DB struct { *Config - Dialector - Instance - ClauseBuilders map[string]clause.ClauseBuilder - DB CommonDB - clone bool - *shared + Error error + RowsAffected int64 + Statement *Statement + clone bool } // Session session config when create session with Session() method @@ -65,14 +67,17 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.NowFunc = func() time.Time { return time.Now().Local() } } + if dialector != nil { + config.Dialector = dialector + } + + if config.cacheStore == nil { + config.cacheStore = &sync.Map{} + } + db = &DB{ - Config: config, - Dialector: dialector, - ClauseBuilders: map[string]clause.ClauseBuilder{}, - clone: true, - shared: &shared{ - cacheStore: &sync.Map{}, - }, + Config: config, + clone: true, } db.callbacks = initializeCallbacks(db) @@ -91,7 +96,7 @@ func (db *DB) Session(config *Session) *DB { ) if config.Context != nil { - tx.Context = config.Context + tx.Statement.Context = config.Context } if config.Logger != nil { @@ -142,23 +147,26 @@ func (db *DB) AutoMigrate(dst ...interface{}) error { return db.Migrator().AutoMigrate(dst...) } +// AddError add error to db +func (db *DB) AddError(err error) { + db.Statement.AddError(err) +} + func (db *DB) getInstance() *DB { if db.clone { - ctx := db.Instance.Context - if ctx == nil { - ctx = context.Background() + ctx := context.Background() + if db.Statement != nil { + ctx = db.Statement.Context } return &DB{ - Instance: Instance{ - Context: ctx, - Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, + Config: db.Config, + Statement: &Statement{ + DB: db, + ConnPool: db.ConnPool, + Clauses: map[string]clause.Clause{}, + Context: ctx, }, - Config: db.Config, - Dialector: db.Dialector, - ClauseBuilders: db.ClauseBuilders, - DB: db.DB, - shared: db.shared, } } diff --git a/interfaces.go b/interfaces.go index c89c3624..9859d1fa 100644 --- a/interfaces.go +++ b/interfaces.go @@ -18,8 +18,8 @@ type Dialector interface { Explain(sql string, vars ...interface{}) string } -// CommonDB common db interface -type CommonDB interface { +// ConnPool db conns pool interface +type ConnPool interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) diff --git a/model.go b/model.go new file mode 100644 index 00000000..fdee99dc --- /dev/null +++ b/model.go @@ -0,0 +1,15 @@ +package gorm + +import "time" + +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embeded into your model or you may build your own model without it +// type User struct { +// gorm.Model +// } +type Model struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `gorm:"index"` +} diff --git a/statement.go b/statement.go index f04ea269..10b62567 100644 --- a/statement.go +++ b/statement.go @@ -14,30 +14,6 @@ import ( "github.com/jinzhu/gorm/schema" ) -// Instance db instance -type Instance struct { - Error error - RowsAffected int64 - Context context.Context - Statement *Statement -} - -func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { - if len(clauses) > 0 { - instance.Statement.Build(clauses...) - } - return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars -} - -// AddError add error to instance -func (inst *Instance) AddError(err error) { - if inst.Error == nil { - inst.Error = err - } else if err != nil { - inst.Error = fmt.Errorf("%v; %w", inst.Error, err) - } -} - // Statement statement type Statement struct { Table string @@ -48,8 +24,12 @@ type Statement struct { Selects []string // selected columns Omits []string // omit columns Settings sync.Map + ConnPool ConnPool DB *DB Schema *schema.Schema + Context context.Context + Error error + RowsAffected int64 RaiseErrorOnNotFound bool // SQL Builder @@ -246,6 +226,14 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con return conditions } +func (stmt *Statement) AddError(err error) { + if stmt.Error == nil { + stmt.Error = err + } else if err != nil { + stmt.Error = fmt.Errorf("%v; %w", stmt.Error, err) + } +} + // Build build sql with clauses names func (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool diff --git a/utils/utils.go b/utils/utils.go index e7ed512c..25cd585a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -4,6 +4,7 @@ import ( "fmt" "regexp" "runtime" + "unicode" ) var goSrcRegexp = regexp.MustCompile(`/gorm/.*.go`) @@ -18,3 +19,7 @@ func FileWithLineNum() string { } return "" } + +func IsChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) +}