forked from mirror/gorm
Refactor based on golint
This commit is contained in:
parent
3055bad1e8
commit
ec110657da
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
// defaultCallback hold default callbacks defined by gorm
|
||||
// DefaultCallback default callbacks defined by gorm
|
||||
var DefaultCallback = &Callback{}
|
||||
|
||||
// Callback contains callbacks that used when CURD objects
|
||||
|
|
|
@ -85,14 +85,14 @@ func createCallback(scope *Scope) {
|
|||
returningColumn = scope.Quote(primaryField.DBName)
|
||||
}
|
||||
|
||||
lastInsertIdReturningSuffix := scope.Dialect().LastInsertIdReturningSuffix(quotedTableName, returningColumn)
|
||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||
|
||||
if len(columns) == 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v DEFAULT VALUES%v%v",
|
||||
quotedTableName,
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
|
@ -101,13 +101,13 @@ func createCallback(scope *Scope) {
|
|||
strings.Join(columns, ","),
|
||||
strings.Join(placeholders, ","),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
}
|
||||
|
||||
// execute create sql
|
||||
if lastInsertIdReturningSuffix == "" || primaryField == nil {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
|
@ -119,7 +119,7 @@ func createCallback(scope *Scope) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = 1
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,15 +48,15 @@ func queryCallback(scope *Scope) {
|
|||
return
|
||||
}
|
||||
|
||||
scope.prepareQuerySql()
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if !scope.HasError() {
|
||||
scope.db.RowsAffected = 0
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.Sql += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
|
||||
if rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
|
@ -80,7 +80,7 @@ func queryCallback(scope *Scope) {
|
|||
}
|
||||
|
||||
if scope.db.RowsAffected == 0 && !isSlice {
|
||||
scope.Err(RecordNotFound)
|
||||
scope.Err(ErrRecordNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ type Dialect interface {
|
|||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||
SelectFromDummyTable() string
|
||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||
LastInsertIdReturningSuffix(tableName, columnName string) string
|
||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||
}
|
||||
|
||||
var dialectsMap = map[string]Dialect{}
|
||||
|
|
|
@ -132,6 +132,6 @@ func (commonDialect) SelectFromDummyTable() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LastInsertIdReturningSuffix(tableName, columnName string) string {
|
||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -106,11 +106,11 @@ func (s postgres) currentDatabase() (name string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s postgres) LastInsertIdReturningSuffix(tableName, key string) string {
|
||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
||||
func (postgres) SupportLastInsertId() bool {
|
||||
func (postgres) SupportLastInsertID() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -145,6 +145,6 @@ func (mssql) SelectFromDummyTable() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (mssql) LastInsertIdReturningSuffix(tableName, columnName string) string {
|
||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
|
|
16
errors.go
16
errors.go
|
@ -6,24 +6,31 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
RecordNotFound = errors.New("record not found")
|
||||
InvalidSql = errors.New("invalid sql")
|
||||
NoValidTransaction = errors.New("no valid transaction")
|
||||
CantStartTransaction = errors.New("can't start transaction")
|
||||
// ErrRecordNotFound record not found, happens when you are looking up with a struct, and haven't find any matched data
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
// ErrInvalidSQL invalid SQL, happens when you passed invalid SQL
|
||||
ErrInvalidSQL = errors.New("invalid SQL")
|
||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||
)
|
||||
|
||||
type errorsInterface interface {
|
||||
GetErrors() []error
|
||||
}
|
||||
|
||||
// Errors contains all happened errors
|
||||
type Errors struct {
|
||||
errors []error
|
||||
}
|
||||
|
||||
// GetErrors get all happened errors
|
||||
func (errs Errors) GetErrors() []error {
|
||||
return errs.errors
|
||||
}
|
||||
|
||||
// Add add an error
|
||||
func (errs *Errors) Add(err error) {
|
||||
if errors, ok := err.(errorsInterface); ok {
|
||||
for _, err := range errors.GetErrors() {
|
||||
|
@ -39,6 +46,7 @@ func (errs *Errors) Add(err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Error format happened errors
|
||||
func (errs Errors) Error() string {
|
||||
var errors = []string{}
|
||||
for _, e := range errs.errors {
|
||||
|
|
2
field.go
2
field.go
|
@ -7,12 +7,14 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
// Field model field definition
|
||||
type Field struct {
|
||||
*StructField
|
||||
IsBlank bool
|
||||
Field reflect.Value
|
||||
}
|
||||
|
||||
// Set set a value to the field
|
||||
func (field *Field) Set(value interface{}) (err error) {
|
||||
if !field.Field.IsValid() {
|
||||
return errors.New("field value not valid")
|
||||
|
|
|
@ -7,40 +7,54 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// JoinTableHandlerInterface is an interface for how to handle many2many relations
|
||||
type JoinTableHandlerInterface interface {
|
||||
// initialize join table handler
|
||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||
// Table return join table's table name
|
||||
Table(db *DB) string
|
||||
// Add create relationship in join table for source and destination
|
||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||
// Delete delete relationship in join table for sources
|
||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||
// JoinWith query with `Join` conditions
|
||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||
// SourceForeignKeys return source foreign keys
|
||||
SourceForeignKeys() []JoinTableForeignKey
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
DestinationForeignKeys() []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableForeignKey join table foreign key struct
|
||||
type JoinTableForeignKey struct {
|
||||
DBName string
|
||||
AssociationDBName string
|
||||
}
|
||||
|
||||
// JoinTableSource is a struct that contains model type and foreign keys
|
||||
type JoinTableSource struct {
|
||||
ModelType reflect.Type
|
||||
ForeignKeys []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableHandler default join table handler
|
||||
type JoinTableHandler struct {
|
||||
TableName string `sql:"-"`
|
||||
Source JoinTableSource `sql:"-"`
|
||||
Destination JoinTableSource `sql:"-"`
|
||||
}
|
||||
|
||||
// SourceForeignKeys return source foreign keys
|
||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||
return s.Source.ForeignKeys
|
||||
}
|
||||
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||
return s.Destination.ForeignKeys
|
||||
}
|
||||
|
||||
// Setup initialize a default join table handler
|
||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||
s.TableName = tableName
|
||||
|
||||
|
@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
|
|||
}
|
||||
}
|
||||
|
||||
// Table return join table's table name
|
||||
func (s JoinTableHandler) Table(db *DB) string {
|
||||
return s.TableName
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||
values := map[string]interface{}{}
|
||||
|
||||
for _, source := range sources {
|
||||
|
@ -89,9 +104,10 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
|
|||
return values
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
|
||||
// Add create relationship in join table for source and destination
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||
scope := db.NewScope("")
|
||||
searchMap := s.GetSearchMap(db, source1, source2)
|
||||
searchMap := s.getSearchMap(db, source, destination)
|
||||
|
||||
var assignColumns, binVars, conditions []string
|
||||
var values []interface{}
|
||||
|
@ -120,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
|
|||
return db.Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
// Delete delete relationship in join table for sources
|
||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||
var (
|
||||
scope = db.NewScope(nil)
|
||||
|
@ -127,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
|||
values []interface{}
|
||||
)
|
||||
|
||||
for key, value := range s.GetSearchMap(db, sources...) {
|
||||
for key, value := range s.getSearchMap(db, sources...) {
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
@ -135,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
|||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
}
|
||||
|
||||
// JoinWith query with `Join` conditions
|
||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||
var (
|
||||
scope = db.NewScope(source)
|
||||
|
|
16
logger.go
16
logger.go
|
@ -11,23 +11,25 @@ import (
|
|||
"unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
Print(v ...interface{})
|
||||
}
|
||||
|
||||
type LogWriter interface {
|
||||
type logWriter interface {
|
||||
Println(v ...interface{})
|
||||
}
|
||||
|
||||
// Logger default logger
|
||||
type Logger struct {
|
||||
LogWriter
|
||||
logWriter
|
||||
}
|
||||
|
||||
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||
|
||||
// Format log
|
||||
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||
|
||||
// Print format & print log
|
||||
func (logger Logger) Print(values ...interface{}) {
|
||||
if len(values) > 1 {
|
||||
level := values[0]
|
||||
|
|
70
main.go
70
main.go
|
@ -6,19 +6,9 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NowFunc returns current time, this function is exported in order to be able
|
||||
// to give the flexibility to the developer to customize it according to their
|
||||
// needs
|
||||
//
|
||||
// e.g: return time.Now().UTC()
|
||||
//
|
||||
var NowFunc = func() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// DB contains information for current db connection
|
||||
type DB struct {
|
||||
Value interface{}
|
||||
Error error
|
||||
|
@ -36,6 +26,14 @@ type DB struct {
|
|||
joinTableHandlers map[string]JoinTableHandler
|
||||
}
|
||||
|
||||
// Open open a new db connection, need to import driver first, for example:
|
||||
//
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
// func main() {
|
||||
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
||||
// }
|
||||
// GORM has wrapped some drivers, for easier to remember its name, so you could import the mysql driver with
|
||||
// import _ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||
var db DB
|
||||
var err error
|
||||
|
@ -44,7 +42,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
|||
err = errors.New("invalid database source")
|
||||
} else {
|
||||
var source string
|
||||
var dbSql sqlCommon
|
||||
var dbSQL sqlCommon
|
||||
|
||||
switch value := args[0].(type) {
|
||||
case string:
|
||||
|
@ -55,19 +53,19 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
|||
driver = value
|
||||
source = args[1].(string)
|
||||
}
|
||||
dbSql, err = sql.Open(driver, source)
|
||||
dbSQL, err = sql.Open(driver, source)
|
||||
case sqlCommon:
|
||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||
dbSql = value
|
||||
dbSQL = value
|
||||
}
|
||||
|
||||
db = DB{
|
||||
dialect: newDialect(dialect, dbSql.(*sql.DB)),
|
||||
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
|
||||
logger: defaultLogger,
|
||||
callbacks: DefaultCallback,
|
||||
source: source,
|
||||
values: map[string]interface{}{},
|
||||
db: dbSql,
|
||||
db: dbSQL,
|
||||
}
|
||||
db.parent = &db
|
||||
|
||||
|
@ -79,14 +77,17 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
|||
return &db, err
|
||||
}
|
||||
|
||||
// Close close current db connection
|
||||
func (s *DB) Close() error {
|
||||
return s.parent.db.(*sql.DB).Close()
|
||||
}
|
||||
|
||||
// DB get `*sql.DB` from current connection
|
||||
func (s *DB) DB() *sql.DB {
|
||||
return s.db.(*sql.DB)
|
||||
}
|
||||
|
||||
// New initialize a new db connection without any search conditions
|
||||
func (s *DB) New() *DB {
|
||||
clone := s.clone()
|
||||
clone.search = nil
|
||||
|
@ -94,29 +95,34 @@ func (s *DB) New() *DB {
|
|||
return clone
|
||||
}
|
||||
|
||||
// NewScope create scope for callbacks, including DB's search information
|
||||
// NewScope create a scope for current operation
|
||||
func (s *DB) NewScope(value interface{}) *Scope {
|
||||
dbClone := s.clone()
|
||||
dbClone.Value = value
|
||||
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
||||
}
|
||||
|
||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
||||
// CommonDB return the underlying sql.DB or sql.Tx instance.
|
||||
// Use of this method is discouraged. It's mainly intended to allow
|
||||
// coexistence with legacy non-GORM code.
|
||||
func (s *DB) CommonDB() sqlCommon {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// Callback return Callbacks container, you could add/remove/change callbacks with it
|
||||
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||
// Refer: https://jinzhu.github.io/gorm/development.html#callbacks for more
|
||||
func (s *DB) Callback() *Callback {
|
||||
s.parent.callbacks = s.parent.callbacks.clone()
|
||||
return s.parent.callbacks
|
||||
}
|
||||
|
||||
func (s *DB) SetLogger(l logger) {
|
||||
s.logger = l
|
||||
// SetLogger replace default logger
|
||||
func (s *DB) SetLogger(log logger) {
|
||||
s.logger = log
|
||||
}
|
||||
|
||||
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
||||
func (s *DB) LogMode(enable bool) *DB {
|
||||
if enable {
|
||||
s.logMode = 2
|
||||
|
@ -126,51 +132,65 @@ func (s *DB) LogMode(enable bool) *DB {
|
|||
return s
|
||||
}
|
||||
|
||||
// SingularTable use singular table by default
|
||||
func (s *DB) SingularTable(enable bool) {
|
||||
modelStructsMap = newModelStructsMap()
|
||||
s.parent.singularTable = enable
|
||||
}
|
||||
|
||||
// Where return a new relation, accepts use `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
|
||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Where(query, args...).db
|
||||
}
|
||||
|
||||
// Or match before conditions or this one, similar to `Where`
|
||||
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Or(query, args...).db
|
||||
}
|
||||
|
||||
// Not don't match current conditions, similar to `Where`
|
||||
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Not(query, args...).db
|
||||
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
|
||||
func (s *DB) Limit(limit int) *DB {
|
||||
return s.clone().search.Limit(limit).db
|
||||
}
|
||||
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
func (s *DB) Offset(offset int) *DB {
|
||||
return s.clone().search.Offset(offset).db
|
||||
}
|
||||
|
||||
// Order specify order when retrieve records from database, pass `true` as the second argument to overwrite `Order` conditions
|
||||
func (s *DB) Order(value string, reorder ...bool) *DB {
|
||||
return s.clone().search.Order(value, reorder...).db
|
||||
}
|
||||
|
||||
// Select When querying, specify fields that you want to retrieve from database, by default, will select all fields;
|
||||
// When creating/updating, specify fields that you want to save to database
|
||||
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
||||
return s.clone().search.Select(query, args...).db
|
||||
}
|
||||
|
||||
// Omit specify fields that you want to ignore when save to database when creating/updating
|
||||
func (s *DB) Omit(columns ...string) *DB {
|
||||
return s.clone().search.Omit(columns...).db
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
func (s *DB) Group(query string) *DB {
|
||||
return s.clone().search.Group(query).db
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
func (s *DB) Having(query string, values ...interface{}) *DB {
|
||||
return s.clone().search.Having(query, values...).db
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
func (s *DB) Joins(query string, args ...interface{}) *DB {
|
||||
return s.clone().search.Joins(query, args...).db
|
||||
}
|
||||
|
@ -352,7 +372,7 @@ func (s *DB) Begin() *DB {
|
|||
c.db = interface{}(tx).(sqlCommon)
|
||||
c.AddError(err)
|
||||
} else {
|
||||
c.AddError(CantStartTransaction)
|
||||
c.AddError(ErrCantStartTransaction)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
@ -361,7 +381,7 @@ func (s *DB) Commit() *DB {
|
|||
if db, ok := s.db.(sqlTx); ok {
|
||||
s.AddError(db.Commit())
|
||||
} else {
|
||||
s.AddError(NoValidTransaction)
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
@ -370,7 +390,7 @@ func (s *DB) Rollback() *DB {
|
|||
if db, ok := s.db.(sqlTx); ok {
|
||||
s.AddError(db.Rollback())
|
||||
} else {
|
||||
s.AddError(NoValidTransaction)
|
||||
s.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
@ -380,7 +400,7 @@ func (s *DB) NewRecord(value interface{}) bool {
|
|||
}
|
||||
|
||||
func (s *DB) RecordNotFound() bool {
|
||||
return s.Error == RecordNotFound
|
||||
return s.Error == ErrRecordNotFound
|
||||
}
|
||||
|
||||
// CreateTable create table for models
|
||||
|
@ -541,7 +561,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
|||
|
||||
func (s *DB) AddError(err error) error {
|
||||
if err != nil {
|
||||
if err != RecordNotFound {
|
||||
if err != ErrRecordNotFound {
|
||||
if s.logMode == 0 {
|
||||
go s.print(fileWithLineNum(), err)
|
||||
} else {
|
||||
|
|
|
@ -479,7 +479,7 @@ func TestRaw(t *testing.T) {
|
|||
}
|
||||
|
||||
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
||||
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
|
||||
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
|
||||
t.Error("Raw sql to update records")
|
||||
}
|
||||
}
|
||||
|
@ -709,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
|
|||
}
|
||||
|
||||
var user User
|
||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
|
||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
|
||||
t.Errorf("Should have found existing record")
|
||||
}
|
||||
}
|
||||
|
|
1
model.go
1
model.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import "time"
|
||||
|
||||
// Model base model definition, including `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/jinzhu/inflection"
|
||||
)
|
||||
|
||||
// DefaultTableNameHandler default table name handler
|
||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||
return defaultTableName
|
||||
}
|
||||
|
@ -39,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap {
|
|||
|
||||
var modelStructsMap = newModelStructsMap()
|
||||
|
||||
// ModelStruct model definition
|
||||
type ModelStruct struct {
|
||||
PrimaryFields []*StructField
|
||||
StructFields []*StructField
|
||||
|
@ -46,10 +48,12 @@ type ModelStruct struct {
|
|||
defaultTableName string
|
||||
}
|
||||
|
||||
// TableName get model's table name
|
||||
func (s *ModelStruct) TableName(db *DB) string {
|
||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||
}
|
||||
|
||||
// StructField model field's struct definition
|
||||
type StructField struct {
|
||||
DBName string
|
||||
Name string
|
||||
|
@ -506,6 +510,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
return &modelStruct
|
||||
}
|
||||
|
||||
// GetStructFields get model's field structs
|
||||
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||
return scope.GetModelStruct().StructFields
|
||||
}
|
||||
|
|
|
@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
|
|||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
@ -981,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
|
|||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
@ -1038,7 +1038,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
|
|||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
|
33
scope.go
33
scope.go
|
@ -10,11 +10,12 @@ import (
|
|||
"reflect"
|
||||
)
|
||||
|
||||
// Scope contain any information of current operation when you perform any operation on the database
|
||||
type Scope struct {
|
||||
Search *search
|
||||
Value interface{}
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
SQL string
|
||||
SQLVars []interface{}
|
||||
db *DB
|
||||
instanceID string
|
||||
primaryKeyField *Field
|
||||
|
@ -23,6 +24,7 @@ type Scope struct {
|
|||
selectAttrs *[]string
|
||||
}
|
||||
|
||||
// IndirectValue return scope's reflect value's indirect value
|
||||
func (scope *Scope) IndirectValue() reflect.Value {
|
||||
return indirect(reflect.ValueOf(scope.Value))
|
||||
}
|
||||
|
@ -43,12 +45,13 @@ func (scope *Scope) NewDB() *DB {
|
|||
return nil
|
||||
}
|
||||
|
||||
// DB return scope's DB connection
|
||||
func (scope *Scope) DB() *DB {
|
||||
return scope.db
|
||||
}
|
||||
|
||||
// SqlDB return *sql.DB
|
||||
func (scope *Scope) SqlDB() sqlCommon {
|
||||
// SQLDB return *sql.DB
|
||||
func (scope *Scope) SQLDB() sqlCommon {
|
||||
return scope.db.db
|
||||
}
|
||||
|
||||
|
@ -100,6 +103,7 @@ func (scope *Scope) HasError() bool {
|
|||
return scope.db.Error != nil
|
||||
}
|
||||
|
||||
// PrimaryFields return scope's primary fields
|
||||
func (scope *Scope) PrimaryFields() (fields []*Field) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsPrimaryKey {
|
||||
|
@ -109,6 +113,7 @@ func (scope *Scope) PrimaryFields() (fields []*Field) {
|
|||
return fields
|
||||
}
|
||||
|
||||
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
|
||||
func (scope *Scope) PrimaryField() *Field {
|
||||
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
||||
if len(primaryFields) > 1 {
|
||||
|
@ -241,8 +246,8 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||
return exp
|
||||
}
|
||||
|
||||
scope.SqlVars = append(scope.SqlVars, value)
|
||||
return scope.Dialect().BindVar(len(scope.SqlVars))
|
||||
scope.SQLVars = append(scope.SQLVars, value)
|
||||
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||
}
|
||||
|
||||
type tabler interface {
|
||||
|
@ -282,10 +287,10 @@ func (scope *Scope) QuotedTableName() (name string) {
|
|||
return scope.Quote(scope.TableName())
|
||||
}
|
||||
|
||||
// CombinedConditionSql get combined condition sql
|
||||
// CombinedConditionSql return combined condition sql
|
||||
func (scope *Scope) CombinedConditionSql() string {
|
||||
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
|
||||
scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql()
|
||||
return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() +
|
||||
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
|
||||
}
|
||||
|
||||
// FieldByName find gorm.Field with name and db name
|
||||
|
@ -308,7 +313,7 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
|||
|
||||
// Raw set sql
|
||||
func (scope *Scope) Raw(sql string) *Scope {
|
||||
scope.Sql = strings.Replace(sql, "$$", "?", -1)
|
||||
scope.SQL = strings.Replace(sql, "$$", "?", -1)
|
||||
return scope
|
||||
}
|
||||
|
||||
|
@ -317,7 +322,7 @@ func (scope *Scope) Exec() *Scope {
|
|||
defer scope.trace(NowFunc())
|
||||
|
||||
if !scope.HasError() {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = count
|
||||
}
|
||||
|
@ -345,17 +350,19 @@ func (scope *Scope) InstanceID() string {
|
|||
return scope.instanceID
|
||||
}
|
||||
|
||||
// InstanceSet set value for current instance, but not for associations
|
||||
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
|
||||
return scope.Set(name+scope.InstanceID(), value)
|
||||
}
|
||||
|
||||
// InstanceGet get setting from current instance
|
||||
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
||||
return scope.Get(name + scope.InstanceID())
|
||||
}
|
||||
|
||||
// Begin start a transaction
|
||||
func (scope *Scope) Begin() *Scope {
|
||||
if db, ok := scope.SqlDB().(sqlDb); ok {
|
||||
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||
if tx, err := db.Begin(); err == nil {
|
||||
scope.db.db = interface{}(tx).(sqlCommon)
|
||||
scope.InstanceSet("gorm:started_transaction", true)
|
||||
|
@ -379,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
|||
return scope
|
||||
}
|
||||
|
||||
// SelectAttrs retur nselected attributes
|
||||
func (scope *Scope) SelectAttrs() []string {
|
||||
if scope.selectAttrs == nil {
|
||||
attrs := []string{}
|
||||
|
@ -398,6 +406,7 @@ func (scope *Scope) SelectAttrs() []string {
|
|||
return *scope.selectAttrs
|
||||
}
|
||||
|
||||
// OmitAttrs return omited attributes
|
||||
func (scope *Scope) OmitAttrs() []string {
|
||||
return scope.Search.omits
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||
}
|
||||
|
||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||
var notEqualSql string
|
||||
var notEqualSQL string
|
||||
var primaryKey = scope.PrimaryKey()
|
||||
|
||||
switch value := clause["query"].(type) {
|
||||
|
@ -87,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||
notEqualSql = fmt.Sprintf("NOT (%v)", value)
|
||||
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
||||
} else {
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||
|
@ -139,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = scanner.Value()
|
||||
}
|
||||
str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
|
||||
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -173,7 +173,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
|||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) whereSql() (sql string) {
|
||||
func (scope *Scope) whereSQL() (sql string) {
|
||||
var (
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryConditions, andConditions, orConditions []string
|
||||
|
@ -209,28 +209,28 @@ func (scope *Scope) whereSql() (sql string) {
|
|||
}
|
||||
}
|
||||
|
||||
orSql := strings.Join(orConditions, " OR ")
|
||||
combinedSql := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSql) > 0 {
|
||||
if len(orSql) > 0 {
|
||||
combinedSql = combinedSql + " OR " + orSql
|
||||
orSQL := strings.Join(orConditions, " OR ")
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) > 0 {
|
||||
if len(orSQL) > 0 {
|
||||
combinedSQL = combinedSQL + " OR " + orSQL
|
||||
}
|
||||
} else {
|
||||
combinedSql = orSql
|
||||
combinedSQL = orSQL
|
||||
}
|
||||
|
||||
if len(primaryConditions) > 0 {
|
||||
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
||||
if len(combinedSql) > 0 {
|
||||
sql = sql + " AND (" + combinedSql + ")"
|
||||
if len(combinedSQL) > 0 {
|
||||
sql = sql + " AND (" + combinedSQL + ")"
|
||||
}
|
||||
} else if len(combinedSql) > 0 {
|
||||
sql = "WHERE " + combinedSql
|
||||
} else if len(combinedSQL) > 0 {
|
||||
sql = "WHERE " + combinedSQL
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) selectSql() string {
|
||||
func (scope *Scope) selectSQL() string {
|
||||
if len(scope.Search.selects) == 0 {
|
||||
if len(scope.Search.joinConditions) > 0 {
|
||||
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
||||
|
@ -240,25 +240,25 @@ func (scope *Scope) selectSql() string {
|
|||
return scope.buildSelectQuery(scope.Search.selects)
|
||||
}
|
||||
|
||||
func (scope *Scope) orderSql() string {
|
||||
func (scope *Scope) orderSQL() string {
|
||||
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
||||
return ""
|
||||
}
|
||||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||
}
|
||||
|
||||
func (scope *Scope) limitAndOffsetSql() string {
|
||||
func (scope *Scope) limitAndOffsetSQL() string {
|
||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
}
|
||||
|
||||
func (scope *Scope) groupSql() string {
|
||||
func (scope *Scope) groupSQL() string {
|
||||
if len(scope.Search.group) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " GROUP BY " + scope.Search.group
|
||||
}
|
||||
|
||||
func (scope *Scope) havingSql() string {
|
||||
func (scope *Scope) havingSQL() string {
|
||||
if len(scope.Search.havingConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
@ -270,15 +270,15 @@ func (scope *Scope) havingSql() string {
|
|||
}
|
||||
}
|
||||
|
||||
combinedSql := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSql) == 0 {
|
||||
combinedSQL := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSQL) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return " HAVING " + combinedSql
|
||||
return " HAVING " + combinedSQL
|
||||
}
|
||||
|
||||
func (scope *Scope) joinsSql() string {
|
||||
func (scope *Scope) joinsSQL() string {
|
||||
var joinConditions []string
|
||||
for _, clause := range scope.Search.joinConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
|
@ -289,11 +289,11 @@ func (scope *Scope) joinsSql() string {
|
|||
return strings.Join(joinConditions, " ") + " "
|
||||
}
|
||||
|
||||
func (scope *Scope) prepareQuerySql() {
|
||||
func (scope *Scope) prepareQuerySQL() {
|
||||
if scope.Search.raw {
|
||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -345,15 +345,15 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (resul
|
|||
func (scope *Scope) row() *sql.Row {
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySql()
|
||||
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||
defer scope.trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||
scope.prepareQuerySql()
|
||||
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
||||
scope.prepareQuerySQL()
|
||||
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) initialize() *Scope {
|
||||
|
@ -404,8 +404,8 @@ func (scope *Scope) typeName() string {
|
|||
|
||||
// trace print sql log
|
||||
func (scope *Scope) trace(t time.Time) {
|
||||
if len(scope.Sql) > 0 {
|
||||
scope.db.slog(scope.Sql, t, scope.SqlVars...)
|
||||
if len(scope.SQL) > 0 {
|
||||
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -599,7 +599,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
|||
sqlCreate = "CREATE UNIQUE INDEX"
|
||||
}
|
||||
|
||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
|
||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||
|
|
|
@ -93,7 +93,7 @@ func (s *search) Offset(offset int) *search {
|
|||
}
|
||||
|
||||
func (s *search) Group(query string) *search {
|
||||
s.group = s.getInterfaceAsSql(query)
|
||||
s.group = s.getInterfaceAsSQL(query)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -134,12 +134,12 @@ func (s *search) Table(name string) *search {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
|
||||
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
|
||||
switch value.(type) {
|
||||
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
str = fmt.Sprintf("%v", value)
|
||||
default:
|
||||
s.db.AddError(InvalidSql)
|
||||
s.db.AddError(ErrInvalidSQL)
|
||||
}
|
||||
|
||||
if str == "-1" {
|
||||
|
|
11
utils.go
11
utils.go
|
@ -9,8 +9,19 @@ import (
|
|||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NowFunc returns current time, this function is exported in order to be able
|
||||
// to give the flexibility to the developer to customize it according to their
|
||||
// needs
|
||||
//
|
||||
// e.g: return time.Now().UTC()
|
||||
//
|
||||
var NowFunc = func() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// Copied from golint
|
||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||
var commonInitialismsReplacer *strings.Replacer
|
||||
|
|
Loading…
Reference in New Issue