forked from mirror/gorm
Refactor based on golint
This commit is contained in:
parent
3055bad1e8
commit
ec110657da
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultCallback hold default callbacks defined by gorm
|
// DefaultCallback default callbacks defined by gorm
|
||||||
var DefaultCallback = &Callback{}
|
var DefaultCallback = &Callback{}
|
||||||
|
|
||||||
// Callback contains callbacks that used when CURD objects
|
// Callback contains callbacks that used when CURD objects
|
||||||
|
|
|
@ -85,14 +85,14 @@ func createCallback(scope *Scope) {
|
||||||
returningColumn = scope.Quote(primaryField.DBName)
|
returningColumn = scope.Quote(primaryField.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
lastInsertIdReturningSuffix := scope.Dialect().LastInsertIdReturningSuffix(quotedTableName, returningColumn)
|
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||||
|
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"INSERT INTO %v DEFAULT VALUES%v%v",
|
"INSERT INTO %v DEFAULT VALUES%v%v",
|
||||||
quotedTableName,
|
quotedTableName,
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
|
@ -101,13 +101,13 @@ func createCallback(scope *Scope) {
|
||||||
strings.Join(columns, ","),
|
strings.Join(columns, ","),
|
||||||
strings.Join(placeholders, ","),
|
strings.Join(placeholders, ","),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute create sql
|
// execute create sql
|
||||||
if lastInsertIdReturningSuffix == "" || primaryField == nil {
|
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
// set rows affected count
|
// set rows affected count
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ func createCallback(scope *Scope) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||||
scope.db.RowsAffected = 1
|
scope.db.RowsAffected = 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,15 +48,15 @@ func queryCallback(scope *Scope) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySQL()
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
scope.db.RowsAffected = 0
|
scope.db.RowsAffected = 0
|
||||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||||
scope.Sql += addExtraSpaceIfExist(fmt.Sprint(str))
|
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
|
@ -80,7 +80,7 @@ func queryCallback(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.db.RowsAffected == 0 && !isSlice {
|
if scope.db.RowsAffected == 0 && !isSlice {
|
||||||
scope.Err(RecordNotFound)
|
scope.Err(ErrRecordNotFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ type Dialect interface {
|
||||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
SelectFromDummyTable() string
|
SelectFromDummyTable() string
|
||||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
LastInsertIdReturningSuffix(tableName, columnName string) string
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialectsMap = map[string]Dialect{}
|
var dialectsMap = map[string]Dialect{}
|
||||||
|
|
|
@ -132,6 +132,6 @@ func (commonDialect) SelectFromDummyTable() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIdReturningSuffix(tableName, columnName string) string {
|
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,11 +106,11 @@ func (s postgres) currentDatabase() (name string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) LastInsertIdReturningSuffix(tableName, key string) string {
|
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (postgres) SupportLastInsertId() bool {
|
func (postgres) SupportLastInsertID() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -145,6 +145,6 @@ func (mssql) SelectFromDummyTable() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIdReturningSuffix(tableName, columnName string) string {
|
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
16
errors.go
16
errors.go
|
@ -6,24 +6,31 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
RecordNotFound = errors.New("record not found")
|
// ErrRecordNotFound record not found, happens when you are looking up with a struct, and haven't find any matched data
|
||||||
InvalidSql = errors.New("invalid sql")
|
ErrRecordNotFound = errors.New("record not found")
|
||||||
NoValidTransaction = errors.New("no valid transaction")
|
// ErrInvalidSQL invalid SQL, happens when you passed invalid SQL
|
||||||
CantStartTransaction = errors.New("can't start transaction")
|
ErrInvalidSQL = errors.New("invalid SQL")
|
||||||
|
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||||
|
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||||
|
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||||
|
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||||
)
|
)
|
||||||
|
|
||||||
type errorsInterface interface {
|
type errorsInterface interface {
|
||||||
GetErrors() []error
|
GetErrors() []error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Errors contains all happened errors
|
||||||
type Errors struct {
|
type Errors struct {
|
||||||
errors []error
|
errors []error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetErrors get all happened errors
|
||||||
func (errs Errors) GetErrors() []error {
|
func (errs Errors) GetErrors() []error {
|
||||||
return errs.errors
|
return errs.errors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add add an error
|
||||||
func (errs *Errors) Add(err error) {
|
func (errs *Errors) Add(err error) {
|
||||||
if errors, ok := err.(errorsInterface); ok {
|
if errors, ok := err.(errorsInterface); ok {
|
||||||
for _, err := range errors.GetErrors() {
|
for _, err := range errors.GetErrors() {
|
||||||
|
@ -39,6 +46,7 @@ func (errs *Errors) Add(err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Error format happened errors
|
||||||
func (errs Errors) Error() string {
|
func (errs Errors) Error() string {
|
||||||
var errors = []string{}
|
var errors = []string{}
|
||||||
for _, e := range errs.errors {
|
for _, e := range errs.errors {
|
||||||
|
|
2
field.go
2
field.go
|
@ -7,12 +7,14 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Field model field definition
|
||||||
type Field struct {
|
type Field struct {
|
||||||
*StructField
|
*StructField
|
||||||
IsBlank bool
|
IsBlank bool
|
||||||
Field reflect.Value
|
Field reflect.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set set a value to the field
|
||||||
func (field *Field) Set(value interface{}) (err error) {
|
func (field *Field) Set(value interface{}) (err error) {
|
||||||
if !field.Field.IsValid() {
|
if !field.Field.IsValid() {
|
||||||
return errors.New("field value not valid")
|
return errors.New("field value not valid")
|
||||||
|
|
|
@ -7,40 +7,54 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// JoinTableHandlerInterface is an interface for how to handle many2many relations
|
||||||
type JoinTableHandlerInterface interface {
|
type JoinTableHandlerInterface interface {
|
||||||
|
// initialize join table handler
|
||||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||||
|
// Table return join table's table name
|
||||||
Table(db *DB) string
|
Table(db *DB) string
|
||||||
|
// Add create relationship in join table for source and destination
|
||||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||||
|
// Delete delete relationship in join table for sources
|
||||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||||
|
// JoinWith query with `Join` conditions
|
||||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||||
|
// SourceForeignKeys return source foreign keys
|
||||||
SourceForeignKeys() []JoinTableForeignKey
|
SourceForeignKeys() []JoinTableForeignKey
|
||||||
|
// DestinationForeignKeys return destination foreign keys
|
||||||
DestinationForeignKeys() []JoinTableForeignKey
|
DestinationForeignKeys() []JoinTableForeignKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinTableForeignKey join table foreign key struct
|
||||||
type JoinTableForeignKey struct {
|
type JoinTableForeignKey struct {
|
||||||
DBName string
|
DBName string
|
||||||
AssociationDBName string
|
AssociationDBName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinTableSource is a struct that contains model type and foreign keys
|
||||||
type JoinTableSource struct {
|
type JoinTableSource struct {
|
||||||
ModelType reflect.Type
|
ModelType reflect.Type
|
||||||
ForeignKeys []JoinTableForeignKey
|
ForeignKeys []JoinTableForeignKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinTableHandler default join table handler
|
||||||
type JoinTableHandler struct {
|
type JoinTableHandler struct {
|
||||||
TableName string `sql:"-"`
|
TableName string `sql:"-"`
|
||||||
Source JoinTableSource `sql:"-"`
|
Source JoinTableSource `sql:"-"`
|
||||||
Destination JoinTableSource `sql:"-"`
|
Destination JoinTableSource `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SourceForeignKeys return source foreign keys
|
||||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||||
return s.Source.ForeignKeys
|
return s.Source.ForeignKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DestinationForeignKeys return destination foreign keys
|
||||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||||
return s.Destination.ForeignKeys
|
return s.Destination.ForeignKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup initialize a default join table handler
|
||||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||||
s.TableName = tableName
|
s.TableName = tableName
|
||||||
|
|
||||||
|
@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Table return join table's table name
|
||||||
func (s JoinTableHandler) Table(db *DB) string {
|
func (s JoinTableHandler) Table(db *DB) string {
|
||||||
return s.TableName
|
return s.TableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||||
values := map[string]interface{}{}
|
values := map[string]interface{}{}
|
||||||
|
|
||||||
for _, source := range sources {
|
for _, source := range sources {
|
||||||
|
@ -89,9 +104,10 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
|
// Add create relationship in join table for source and destination
|
||||||
|
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||||
scope := db.NewScope("")
|
scope := db.NewScope("")
|
||||||
searchMap := s.GetSearchMap(db, source1, source2)
|
searchMap := s.getSearchMap(db, source, destination)
|
||||||
|
|
||||||
var assignColumns, binVars, conditions []string
|
var assignColumns, binVars, conditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
|
@ -120,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
|
||||||
return db.Exec(sql, values...).Error
|
return db.Exec(sql, values...).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete delete relationship in join table for sources
|
||||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||||
var (
|
var (
|
||||||
scope = db.NewScope(nil)
|
scope = db.NewScope(nil)
|
||||||
|
@ -127,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
||||||
values []interface{}
|
values []interface{}
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value := range s.GetSearchMap(db, sources...) {
|
for key, value := range s.getSearchMap(db, sources...) {
|
||||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
@ -135,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
|
||||||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinWith query with `Join` conditions
|
||||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||||
var (
|
var (
|
||||||
scope = db.NewScope(source)
|
scope = db.NewScope(source)
|
||||||
|
|
16
logger.go
16
logger.go
|
@ -11,23 +11,25 @@ import (
|
||||||
"unicode"
|
"unicode"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||||
|
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||||
|
)
|
||||||
|
|
||||||
type logger interface {
|
type logger interface {
|
||||||
Print(v ...interface{})
|
Print(v ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogWriter interface {
|
type logWriter interface {
|
||||||
Println(v ...interface{})
|
Println(v ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger default logger
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
LogWriter
|
logWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
// Print format & print log
|
||||||
|
|
||||||
// Format log
|
|
||||||
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
|
||||||
|
|
||||||
func (logger Logger) Print(values ...interface{}) {
|
func (logger Logger) Print(values ...interface{}) {
|
||||||
if len(values) > 1 {
|
if len(values) > 1 {
|
||||||
level := values[0]
|
level := values[0]
|
||||||
|
|
70
main.go
70
main.go
|
@ -6,19 +6,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NowFunc returns current time, this function is exported in order to be able
|
// DB contains information for current db connection
|
||||||
// to give the flexibility to the developer to customize it according to their
|
|
||||||
// needs
|
|
||||||
//
|
|
||||||
// e.g: return time.Now().UTC()
|
|
||||||
//
|
|
||||||
var NowFunc = func() time.Time {
|
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Error error
|
Error error
|
||||||
|
@ -36,6 +26,14 @@ type DB struct {
|
||||||
joinTableHandlers map[string]JoinTableHandler
|
joinTableHandlers map[string]JoinTableHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Open open a new db connection, need to import driver first, for example:
|
||||||
|
//
|
||||||
|
// import _ "github.com/go-sql-driver/mysql"
|
||||||
|
// func main() {
|
||||||
|
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
||||||
|
// }
|
||||||
|
// GORM has wrapped some drivers, for easier to remember its name, so you could import the mysql driver with
|
||||||
|
// import _ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
func Open(dialect string, args ...interface{}) (*DB, error) {
|
func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
var db DB
|
var db DB
|
||||||
var err error
|
var err error
|
||||||
|
@ -44,7 +42,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
err = errors.New("invalid database source")
|
err = errors.New("invalid database source")
|
||||||
} else {
|
} else {
|
||||||
var source string
|
var source string
|
||||||
var dbSql sqlCommon
|
var dbSQL sqlCommon
|
||||||
|
|
||||||
switch value := args[0].(type) {
|
switch value := args[0].(type) {
|
||||||
case string:
|
case string:
|
||||||
|
@ -55,19 +53,19 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
driver = value
|
driver = value
|
||||||
source = args[1].(string)
|
source = args[1].(string)
|
||||||
}
|
}
|
||||||
dbSql, err = sql.Open(driver, source)
|
dbSQL, err = sql.Open(driver, source)
|
||||||
case sqlCommon:
|
case sqlCommon:
|
||||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||||
dbSql = value
|
dbSQL = value
|
||||||
}
|
}
|
||||||
|
|
||||||
db = DB{
|
db = DB{
|
||||||
dialect: newDialect(dialect, dbSql.(*sql.DB)),
|
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
|
||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
callbacks: DefaultCallback,
|
callbacks: DefaultCallback,
|
||||||
source: source,
|
source: source,
|
||||||
values: map[string]interface{}{},
|
values: map[string]interface{}{},
|
||||||
db: dbSql,
|
db: dbSQL,
|
||||||
}
|
}
|
||||||
db.parent = &db
|
db.parent = &db
|
||||||
|
|
||||||
|
@ -79,14 +77,17 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
return &db, err
|
return &db, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close close current db connection
|
||||||
func (s *DB) Close() error {
|
func (s *DB) Close() error {
|
||||||
return s.parent.db.(*sql.DB).Close()
|
return s.parent.db.(*sql.DB).Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB get `*sql.DB` from current connection
|
||||||
func (s *DB) DB() *sql.DB {
|
func (s *DB) DB() *sql.DB {
|
||||||
return s.db.(*sql.DB)
|
return s.db.(*sql.DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New initialize a new db connection without any search conditions
|
||||||
func (s *DB) New() *DB {
|
func (s *DB) New() *DB {
|
||||||
clone := s.clone()
|
clone := s.clone()
|
||||||
clone.search = nil
|
clone.search = nil
|
||||||
|
@ -94,29 +95,34 @@ func (s *DB) New() *DB {
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewScope create scope for callbacks, including DB's search information
|
// NewScope create a scope for current operation
|
||||||
func (s *DB) NewScope(value interface{}) *Scope {
|
func (s *DB) NewScope(value interface{}) *Scope {
|
||||||
dbClone := s.clone()
|
dbClone := s.clone()
|
||||||
dbClone.Value = value
|
dbClone.Value = value
|
||||||
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
// CommonDB return the underlying sql.DB or sql.Tx instance.
|
||||||
// Use of this method is discouraged. It's mainly intended to allow
|
// Use of this method is discouraged. It's mainly intended to allow
|
||||||
// coexistence with legacy non-GORM code.
|
// coexistence with legacy non-GORM code.
|
||||||
func (s *DB) CommonDB() sqlCommon {
|
func (s *DB) CommonDB() sqlCommon {
|
||||||
return s.db
|
return s.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Callback return Callbacks container, you could add/remove/change callbacks with it
|
||||||
|
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||||
|
// Refer: https://jinzhu.github.io/gorm/development.html#callbacks for more
|
||||||
func (s *DB) Callback() *Callback {
|
func (s *DB) Callback() *Callback {
|
||||||
s.parent.callbacks = s.parent.callbacks.clone()
|
s.parent.callbacks = s.parent.callbacks.clone()
|
||||||
return s.parent.callbacks
|
return s.parent.callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) SetLogger(l logger) {
|
// SetLogger replace default logger
|
||||||
s.logger = l
|
func (s *DB) SetLogger(log logger) {
|
||||||
|
s.logger = log
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
||||||
func (s *DB) LogMode(enable bool) *DB {
|
func (s *DB) LogMode(enable bool) *DB {
|
||||||
if enable {
|
if enable {
|
||||||
s.logMode = 2
|
s.logMode = 2
|
||||||
|
@ -126,51 +132,65 @@ func (s *DB) LogMode(enable bool) *DB {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SingularTable use singular table by default
|
||||||
func (s *DB) SingularTable(enable bool) {
|
func (s *DB) SingularTable(enable bool) {
|
||||||
modelStructsMap = newModelStructsMap()
|
modelStructsMap = newModelStructsMap()
|
||||||
s.parent.singularTable = enable
|
s.parent.singularTable = enable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Where return a new relation, accepts use `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
|
||||||
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Where(query, args...).db
|
return s.clone().search.Where(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Or match before conditions or this one, similar to `Where`
|
||||||
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Or(query, args...).db
|
return s.clone().search.Or(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not don't match current conditions, similar to `Where`
|
||||||
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Not(query, args...).db
|
return s.clone().search.Not(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Limit specify the number of records to be retrieved
|
||||||
func (s *DB) Limit(limit int) *DB {
|
func (s *DB) Limit(limit int) *DB {
|
||||||
return s.clone().search.Limit(limit).db
|
return s.clone().search.Limit(limit).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Offset specify the number of records to skip before starting to return the records
|
||||||
func (s *DB) Offset(offset int) *DB {
|
func (s *DB) Offset(offset int) *DB {
|
||||||
return s.clone().search.Offset(offset).db
|
return s.clone().search.Offset(offset).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Order specify order when retrieve records from database, pass `true` as the second argument to overwrite `Order` conditions
|
||||||
func (s *DB) Order(value string, reorder ...bool) *DB {
|
func (s *DB) Order(value string, reorder ...bool) *DB {
|
||||||
return s.clone().search.Order(value, reorder...).db
|
return s.clone().search.Order(value, reorder...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Select When querying, specify fields that you want to retrieve from database, by default, will select all fields;
|
||||||
|
// When creating/updating, specify fields that you want to save to database
|
||||||
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
||||||
return s.clone().search.Select(query, args...).db
|
return s.clone().search.Select(query, args...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Omit specify fields that you want to ignore when save to database when creating/updating
|
||||||
func (s *DB) Omit(columns ...string) *DB {
|
func (s *DB) Omit(columns ...string) *DB {
|
||||||
return s.clone().search.Omit(columns...).db
|
return s.clone().search.Omit(columns...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Group specify the group method on the find
|
||||||
func (s *DB) Group(query string) *DB {
|
func (s *DB) Group(query string) *DB {
|
||||||
return s.clone().search.Group(query).db
|
return s.clone().search.Group(query).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Having specify HAVING conditions for GROUP BY
|
||||||
func (s *DB) Having(query string, values ...interface{}) *DB {
|
func (s *DB) Having(query string, values ...interface{}) *DB {
|
||||||
return s.clone().search.Having(query, values...).db
|
return s.clone().search.Having(query, values...).db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Joins specify Joins conditions
|
||||||
|
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||||
func (s *DB) Joins(query string, args ...interface{}) *DB {
|
func (s *DB) Joins(query string, args ...interface{}) *DB {
|
||||||
return s.clone().search.Joins(query, args...).db
|
return s.clone().search.Joins(query, args...).db
|
||||||
}
|
}
|
||||||
|
@ -352,7 +372,7 @@ func (s *DB) Begin() *DB {
|
||||||
c.db = interface{}(tx).(sqlCommon)
|
c.db = interface{}(tx).(sqlCommon)
|
||||||
c.AddError(err)
|
c.AddError(err)
|
||||||
} else {
|
} else {
|
||||||
c.AddError(CantStartTransaction)
|
c.AddError(ErrCantStartTransaction)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
@ -361,7 +381,7 @@ func (s *DB) Commit() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok {
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
s.AddError(db.Commit())
|
s.AddError(db.Commit())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(NoValidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -370,7 +390,7 @@ func (s *DB) Rollback() *DB {
|
||||||
if db, ok := s.db.(sqlTx); ok {
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
s.AddError(db.Rollback())
|
s.AddError(db.Rollback())
|
||||||
} else {
|
} else {
|
||||||
s.AddError(NoValidTransaction)
|
s.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -380,7 +400,7 @@ func (s *DB) NewRecord(value interface{}) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) RecordNotFound() bool {
|
func (s *DB) RecordNotFound() bool {
|
||||||
return s.Error == RecordNotFound
|
return s.Error == ErrRecordNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTable create table for models
|
// CreateTable create table for models
|
||||||
|
@ -541,7 +561,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
||||||
|
|
||||||
func (s *DB) AddError(err error) error {
|
func (s *DB) AddError(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != RecordNotFound {
|
if err != ErrRecordNotFound {
|
||||||
if s.logMode == 0 {
|
if s.logMode == 0 {
|
||||||
go s.print(fileWithLineNum(), err)
|
go s.print(fileWithLineNum(), err)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -479,7 +479,7 @@ func TestRaw(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
||||||
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
|
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
|
||||||
t.Error("Raw sql to update records")
|
t.Error("Raw sql to update records")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -709,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
|
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
|
||||||
t.Errorf("Should have found existing record")
|
t.Errorf("Should have found existing record")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
1
model.go
1
model.go
|
@ -2,6 +2,7 @@ package gorm
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
// Model base model definition, including `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
|
||||||
type Model struct {
|
type Model struct {
|
||||||
ID uint `gorm:"primary_key"`
|
ID uint `gorm:"primary_key"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/jinzhu/inflection"
|
"github.com/jinzhu/inflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DefaultTableNameHandler default table name handler
|
||||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||||
return defaultTableName
|
return defaultTableName
|
||||||
}
|
}
|
||||||
|
@ -39,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap {
|
||||||
|
|
||||||
var modelStructsMap = newModelStructsMap()
|
var modelStructsMap = newModelStructsMap()
|
||||||
|
|
||||||
|
// ModelStruct model definition
|
||||||
type ModelStruct struct {
|
type ModelStruct struct {
|
||||||
PrimaryFields []*StructField
|
PrimaryFields []*StructField
|
||||||
StructFields []*StructField
|
StructFields []*StructField
|
||||||
|
@ -46,10 +48,12 @@ type ModelStruct struct {
|
||||||
defaultTableName string
|
defaultTableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TableName get model's table name
|
||||||
func (s *ModelStruct) TableName(db *DB) string {
|
func (s *ModelStruct) TableName(db *DB) string {
|
||||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StructField model field's struct definition
|
||||||
type StructField struct {
|
type StructField struct {
|
||||||
DBName string
|
DBName string
|
||||||
Name string
|
Name string
|
||||||
|
@ -506,6 +510,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
return &modelStruct
|
return &modelStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStructFields get model's field structs
|
||||||
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||||
return scope.GetModelStruct().StructFields
|
return scope.GetModelStruct().StructFields
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
|
||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
|
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -981,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
|
||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1038,7 +1038,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
|
||||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
|
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
33
scope.go
33
scope.go
|
@ -10,11 +10,12 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Scope contain any information of current operation when you perform any operation on the database
|
||||||
type Scope struct {
|
type Scope struct {
|
||||||
Search *search
|
Search *search
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Sql string
|
SQL string
|
||||||
SqlVars []interface{}
|
SQLVars []interface{}
|
||||||
db *DB
|
db *DB
|
||||||
instanceID string
|
instanceID string
|
||||||
primaryKeyField *Field
|
primaryKeyField *Field
|
||||||
|
@ -23,6 +24,7 @@ type Scope struct {
|
||||||
selectAttrs *[]string
|
selectAttrs *[]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IndirectValue return scope's reflect value's indirect value
|
||||||
func (scope *Scope) IndirectValue() reflect.Value {
|
func (scope *Scope) IndirectValue() reflect.Value {
|
||||||
return indirect(reflect.ValueOf(scope.Value))
|
return indirect(reflect.ValueOf(scope.Value))
|
||||||
}
|
}
|
||||||
|
@ -43,12 +45,13 @@ func (scope *Scope) NewDB() *DB {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB return scope's DB connection
|
||||||
func (scope *Scope) DB() *DB {
|
func (scope *Scope) DB() *DB {
|
||||||
return scope.db
|
return scope.db
|
||||||
}
|
}
|
||||||
|
|
||||||
// SqlDB return *sql.DB
|
// SQLDB return *sql.DB
|
||||||
func (scope *Scope) SqlDB() sqlCommon {
|
func (scope *Scope) SQLDB() sqlCommon {
|
||||||
return scope.db.db
|
return scope.db.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,6 +103,7 @@ func (scope *Scope) HasError() bool {
|
||||||
return scope.db.Error != nil
|
return scope.db.Error != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrimaryFields return scope's primary fields
|
||||||
func (scope *Scope) PrimaryFields() (fields []*Field) {
|
func (scope *Scope) PrimaryFields() (fields []*Field) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if field.IsPrimaryKey {
|
if field.IsPrimaryKey {
|
||||||
|
@ -109,6 +113,7 @@ func (scope *Scope) PrimaryFields() (fields []*Field) {
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
|
||||||
func (scope *Scope) PrimaryField() *Field {
|
func (scope *Scope) PrimaryField() *Field {
|
||||||
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
||||||
if len(primaryFields) > 1 {
|
if len(primaryFields) > 1 {
|
||||||
|
@ -241,8 +246,8 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
return exp
|
return exp
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.SqlVars = append(scope.SqlVars, value)
|
scope.SQLVars = append(scope.SQLVars, value)
|
||||||
return scope.Dialect().BindVar(len(scope.SqlVars))
|
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||||
}
|
}
|
||||||
|
|
||||||
type tabler interface {
|
type tabler interface {
|
||||||
|
@ -282,10 +287,10 @@ func (scope *Scope) QuotedTableName() (name string) {
|
||||||
return scope.Quote(scope.TableName())
|
return scope.Quote(scope.TableName())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CombinedConditionSql get combined condition sql
|
// CombinedConditionSql return combined condition sql
|
||||||
func (scope *Scope) CombinedConditionSql() string {
|
func (scope *Scope) CombinedConditionSql() string {
|
||||||
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
|
return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() +
|
||||||
scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql()
|
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
|
||||||
}
|
}
|
||||||
|
|
||||||
// FieldByName find gorm.Field with name and db name
|
// FieldByName find gorm.Field with name and db name
|
||||||
|
@ -308,7 +313,7 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||||
|
|
||||||
// Raw set sql
|
// Raw set sql
|
||||||
func (scope *Scope) Raw(sql string) *Scope {
|
func (scope *Scope) Raw(sql string) *Scope {
|
||||||
scope.Sql = strings.Replace(sql, "$$", "?", -1)
|
scope.SQL = strings.Replace(sql, "$$", "?", -1)
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -317,7 +322,7 @@ func (scope *Scope) Exec() *Scope {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||||
scope.db.RowsAffected = count
|
scope.db.RowsAffected = count
|
||||||
}
|
}
|
||||||
|
@ -345,17 +350,19 @@ func (scope *Scope) InstanceID() string {
|
||||||
return scope.instanceID
|
return scope.instanceID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstanceSet set value for current instance, but not for associations
|
||||||
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
|
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
|
||||||
return scope.Set(name+scope.InstanceID(), value)
|
return scope.Set(name+scope.InstanceID(), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstanceGet get setting from current instance
|
||||||
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
|
||||||
return scope.Get(name + scope.InstanceID())
|
return scope.Get(name + scope.InstanceID())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin start a transaction
|
// Begin start a transaction
|
||||||
func (scope *Scope) Begin() *Scope {
|
func (scope *Scope) Begin() *Scope {
|
||||||
if db, ok := scope.SqlDB().(sqlDb); ok {
|
if db, ok := scope.SQLDB().(sqlDb); ok {
|
||||||
if tx, err := db.Begin(); err == nil {
|
if tx, err := db.Begin(); err == nil {
|
||||||
scope.db.db = interface{}(tx).(sqlCommon)
|
scope.db.db = interface{}(tx).(sqlCommon)
|
||||||
scope.InstanceSet("gorm:started_transaction", true)
|
scope.InstanceSet("gorm:started_transaction", true)
|
||||||
|
@ -379,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectAttrs retur nselected attributes
|
||||||
func (scope *Scope) SelectAttrs() []string {
|
func (scope *Scope) SelectAttrs() []string {
|
||||||
if scope.selectAttrs == nil {
|
if scope.selectAttrs == nil {
|
||||||
attrs := []string{}
|
attrs := []string{}
|
||||||
|
@ -398,6 +406,7 @@ func (scope *Scope) SelectAttrs() []string {
|
||||||
return *scope.selectAttrs
|
return *scope.selectAttrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OmitAttrs return omited attributes
|
||||||
func (scope *Scope) OmitAttrs() []string {
|
func (scope *Scope) OmitAttrs() []string {
|
||||||
return scope.Search.omits
|
return scope.Search.omits
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
var notEqualSql string
|
var notEqualSQL string
|
||||||
var primaryKey = scope.PrimaryKey()
|
var primaryKey = scope.PrimaryKey()
|
||||||
|
|
||||||
switch value := clause["query"].(type) {
|
switch value := clause["query"].(type) {
|
||||||
|
@ -87,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
|
||||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
|
||||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||||
notEqualSql = fmt.Sprintf("NOT (%v)", value)
|
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
|
||||||
} else {
|
} else {
|
||||||
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
|
||||||
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
|
||||||
}
|
}
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
|
||||||
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
|
||||||
|
@ -139,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
||||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = scanner.Value()
|
arg, _ = scanner.Value()
|
||||||
}
|
}
|
||||||
str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
|
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -173,7 +173,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) whereSql() (sql string) {
|
func (scope *Scope) whereSQL() (sql string) {
|
||||||
var (
|
var (
|
||||||
quotedTableName = scope.QuotedTableName()
|
quotedTableName = scope.QuotedTableName()
|
||||||
primaryConditions, andConditions, orConditions []string
|
primaryConditions, andConditions, orConditions []string
|
||||||
|
@ -209,28 +209,28 @@ func (scope *Scope) whereSql() (sql string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
orSql := strings.Join(orConditions, " OR ")
|
orSQL := strings.Join(orConditions, " OR ")
|
||||||
combinedSql := strings.Join(andConditions, " AND ")
|
combinedSQL := strings.Join(andConditions, " AND ")
|
||||||
if len(combinedSql) > 0 {
|
if len(combinedSQL) > 0 {
|
||||||
if len(orSql) > 0 {
|
if len(orSQL) > 0 {
|
||||||
combinedSql = combinedSql + " OR " + orSql
|
combinedSQL = combinedSQL + " OR " + orSQL
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
combinedSql = orSql
|
combinedSQL = orSQL
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(primaryConditions) > 0 {
|
if len(primaryConditions) > 0 {
|
||||||
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
|
||||||
if len(combinedSql) > 0 {
|
if len(combinedSQL) > 0 {
|
||||||
sql = sql + " AND (" + combinedSql + ")"
|
sql = sql + " AND (" + combinedSQL + ")"
|
||||||
}
|
}
|
||||||
} else if len(combinedSql) > 0 {
|
} else if len(combinedSQL) > 0 {
|
||||||
sql = "WHERE " + combinedSql
|
sql = "WHERE " + combinedSQL
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) selectSql() string {
|
func (scope *Scope) selectSQL() string {
|
||||||
if len(scope.Search.selects) == 0 {
|
if len(scope.Search.selects) == 0 {
|
||||||
if len(scope.Search.joinConditions) > 0 {
|
if len(scope.Search.joinConditions) > 0 {
|
||||||
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
return fmt.Sprintf("%v.*", scope.QuotedTableName())
|
||||||
|
@ -240,25 +240,25 @@ func (scope *Scope) selectSql() string {
|
||||||
return scope.buildSelectQuery(scope.Search.selects)
|
return scope.buildSelectQuery(scope.Search.selects)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) orderSql() string {
|
func (scope *Scope) orderSQL() string {
|
||||||
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) limitAndOffsetSql() string {
|
func (scope *Scope) limitAndOffsetSQL() string {
|
||||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) groupSql() string {
|
func (scope *Scope) groupSQL() string {
|
||||||
if len(scope.Search.group) == 0 {
|
if len(scope.Search.group) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return " GROUP BY " + scope.Search.group
|
return " GROUP BY " + scope.Search.group
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) havingSql() string {
|
func (scope *Scope) havingSQL() string {
|
||||||
if len(scope.Search.havingConditions) == 0 {
|
if len(scope.Search.havingConditions) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -270,15 +270,15 @@ func (scope *Scope) havingSql() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
combinedSql := strings.Join(andConditions, " AND ")
|
combinedSQL := strings.Join(andConditions, " AND ")
|
||||||
if len(combinedSql) == 0 {
|
if len(combinedSQL) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return " HAVING " + combinedSql
|
return " HAVING " + combinedSQL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) joinsSql() string {
|
func (scope *Scope) joinsSQL() string {
|
||||||
var joinConditions []string
|
var joinConditions []string
|
||||||
for _, clause := range scope.Search.joinConditions {
|
for _, clause := range scope.Search.joinConditions {
|
||||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||||
|
@ -289,11 +289,11 @@ func (scope *Scope) joinsSql() string {
|
||||||
return strings.Join(joinConditions, " ") + " "
|
return strings.Join(joinConditions, " ") + " "
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) prepareQuerySql() {
|
func (scope *Scope) prepareQuerySQL() {
|
||||||
if scope.Search.raw {
|
if scope.Search.raw {
|
||||||
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -345,15 +345,15 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (resul
|
||||||
func (scope *Scope) row() *sql.Row {
|
func (scope *Scope) row() *sql.Row {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySQL()
|
||||||
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySQL()
|
||||||
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) initialize() *Scope {
|
func (scope *Scope) initialize() *Scope {
|
||||||
|
@ -404,8 +404,8 @@ func (scope *Scope) typeName() string {
|
||||||
|
|
||||||
// trace print sql log
|
// trace print sql log
|
||||||
func (scope *Scope) trace(t time.Time) {
|
func (scope *Scope) trace(t time.Time) {
|
||||||
if len(scope.Sql) > 0 {
|
if len(scope.SQL) > 0 {
|
||||||
scope.db.slog(scope.Sql, t, scope.SqlVars...)
|
scope.db.slog(scope.SQL, t, scope.SQLVars...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -599,7 +599,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||||
sqlCreate = "CREATE UNIQUE INDEX"
|
sqlCreate = "CREATE UNIQUE INDEX"
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
|
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||||
|
|
|
@ -93,7 +93,7 @@ func (s *search) Offset(offset int) *search {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) Group(query string) *search {
|
func (s *search) Group(query string) *search {
|
||||||
s.group = s.getInterfaceAsSql(query)
|
s.group = s.getInterfaceAsSQL(query)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,12 +134,12 @@ func (s *search) Table(name string) *search {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
|
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
|
||||||
switch value.(type) {
|
switch value.(type) {
|
||||||
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
str = fmt.Sprintf("%v", value)
|
str = fmt.Sprintf("%v", value)
|
||||||
default:
|
default:
|
||||||
s.db.AddError(InvalidSql)
|
s.db.AddError(ErrInvalidSQL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if str == "-1" {
|
if str == "-1" {
|
||||||
|
|
11
utils.go
11
utils.go
|
@ -9,8 +9,19 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NowFunc returns current time, this function is exported in order to be able
|
||||||
|
// to give the flexibility to the developer to customize it according to their
|
||||||
|
// needs
|
||||||
|
//
|
||||||
|
// e.g: return time.Now().UTC()
|
||||||
|
//
|
||||||
|
var NowFunc = func() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
// Copied from golint
|
// Copied from golint
|
||||||
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
var commonInitialismsReplacer *strings.Replacer
|
var commonInitialismsReplacer *strings.Replacer
|
||||||
|
|
Loading…
Reference in New Issue