Refactor based on golint

This commit is contained in:
Jinzhu 2016-03-07 14:54:20 +08:00
parent 3055bad1e8
commit ec110657da
20 changed files with 185 additions and 109 deletions

View File

@ -4,7 +4,7 @@ import (
"fmt"
)
// defaultCallback hold default callbacks defined by gorm
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
// Callback contains callbacks that used when CURD objects

View File

@ -85,14 +85,14 @@ func createCallback(scope *Scope) {
returningColumn = scope.Quote(primaryField.DBName)
}
lastInsertIdReturningSuffix := scope.Dialect().LastInsertIdReturningSuffix(quotedTableName, returningColumn)
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v DEFAULT VALUES%v%v",
quotedTableName,
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
} else {
scope.Raw(fmt.Sprintf(
@ -101,13 +101,13 @@ func createCallback(scope *Scope) {
strings.Join(columns, ","),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}
// execute create sql
if lastInsertIdReturningSuffix == "" || primaryField == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
@ -119,7 +119,7 @@ func createCallback(scope *Scope) {
}
}
} else {
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
scope.db.RowsAffected = 1
}
}

View File

@ -48,15 +48,15 @@ func queryCallback(scope *Scope) {
return
}
scope.prepareQuerySql()
scope.prepareQuerySQL()
if !scope.HasError() {
scope.db.RowsAffected = 0
if str, ok := scope.Get("gorm:query_option"); ok {
scope.Sql += addExtraSpaceIfExist(fmt.Sprint(str))
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}
if rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()
columns, _ := rows.Columns()
@ -80,7 +80,7 @@ func queryCallback(scope *Scope) {
}
if scope.db.RowsAffected == 0 && !isSlice {
scope.Err(RecordNotFound)
scope.Err(ErrRecordNotFound)
}
}
}

View File

@ -39,7 +39,7 @@ type Dialect interface {
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIdReturningSuffix(tableName, columnName string) string
LastInsertIDReturningSuffix(tableName, columnName string) string
}
var dialectsMap = map[string]Dialect{}

View File

@ -132,6 +132,6 @@ func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIdReturningSuffix(tableName, columnName string) string {
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

View File

@ -106,11 +106,11 @@ func (s postgres) currentDatabase() (name string) {
return
}
func (s postgres) LastInsertIdReturningSuffix(tableName, key string) string {
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
func (postgres) SupportLastInsertId() bool {
func (postgres) SupportLastInsertID() bool {
return false
}

View File

@ -145,6 +145,6 @@ func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIdReturningSuffix(tableName, columnName string) string {
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

View File

@ -6,24 +6,31 @@ import (
)
var (
RecordNotFound = errors.New("record not found")
InvalidSql = errors.New("invalid sql")
NoValidTransaction = errors.New("no valid transaction")
CantStartTransaction = errors.New("can't start transaction")
// ErrRecordNotFound record not found, happens when you are looking up with a struct, and haven't find any matched data
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
ErrCantStartTransaction = errors.New("can't start transaction")
)
type errorsInterface interface {
GetErrors() []error
}
// Errors contains all happened errors
type Errors struct {
errors []error
}
// GetErrors get all happened errors
func (errs Errors) GetErrors() []error {
return errs.errors
}
// Add add an error
func (errs *Errors) Add(err error) {
if errors, ok := err.(errorsInterface); ok {
for _, err := range errors.GetErrors() {
@ -39,6 +46,7 @@ func (errs *Errors) Add(err error) {
}
}
// Error format happened errors
func (errs Errors) Error() string {
var errors = []string{}
for _, e := range errs.errors {

View File

@ -7,12 +7,14 @@ import (
"reflect"
)
// Field model field definition
type Field struct {
*StructField
IsBlank bool
Field reflect.Value
}
// Set set a value to the field
func (field *Field) Set(value interface{}) (err error) {
if !field.Field.IsValid() {
return errors.New("field value not valid")

View File

@ -7,40 +7,54 @@ import (
"strings"
)
// JoinTableHandlerInterface is an interface for how to handle many2many relations
type JoinTableHandlerInterface interface {
// initialize join table handler
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
// Table return join table's table name
Table(db *DB) string
// Add create relationship in join table for source and destination
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
// Delete delete relationship in join table for sources
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
// JoinWith query with `Join` conditions
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
// SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys
DestinationForeignKeys() []JoinTableForeignKey
}
// JoinTableForeignKey join table foreign key struct
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}
// JoinTableSource is a struct that contains model type and foreign keys
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}
// JoinTableHandler default join table handler
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}
// SourceForeignKeys return source foreign keys
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys
}
// DestinationForeignKeys return destination foreign keys
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys
}
// Setup initialize a default join table handler
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
}
}
// Table return join table's table name
func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
for _, source := range sources {
@ -89,9 +104,10 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
return values
}
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
// Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2)
searchMap := s.getSearchMap(db, source, destination)
var assignColumns, binVars, conditions []string
var values []interface{}
@ -120,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
return db.Exec(sql, values...).Error
}
// Delete delete relationship in join table for sources
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
@ -127,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
values []interface{}
)
for key, value := range s.GetSearchMap(db, sources...) {
for key, value := range s.getSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
@ -135,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
// JoinWith query with `Join` conditions
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var (
scope = db.NewScope(source)

View File

@ -11,23 +11,25 @@ import (
"unicode"
)
var (
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
)
type logger interface {
Print(v ...interface{})
}
type LogWriter interface {
type logWriter interface {
Println(v ...interface{})
}
// Logger default logger
type Logger struct {
LogWriter
logWriter
}
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
// Format log
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
// Print format & print log
func (logger Logger) Print(values ...interface{}) {
if len(values) > 1 {
level := values[0]

70
main.go
View File

@ -6,19 +6,9 @@ import (
"fmt"
"reflect"
"strings"
"time"
)
// NowFunc returns current time, this function is exported in order to be able
// to give the flexibility to the developer to customize it according to their
// needs
//
// e.g: return time.Now().UTC()
//
var NowFunc = func() time.Time {
return time.Now()
}
// DB contains information for current db connection
type DB struct {
Value interface{}
Error error
@ -36,6 +26,14 @@ type DB struct {
joinTableHandlers map[string]JoinTableHandler
}
// Open open a new db connection, need to import driver first, for example:
//
// import _ "github.com/go-sql-driver/mysql"
// func main() {
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
// }
// GORM has wrapped some drivers, for easier to remember its name, so you could import the mysql driver with
// import _ "github.com/jinzhu/gorm/dialects/mysql"
func Open(dialect string, args ...interface{}) (*DB, error) {
var db DB
var err error
@ -44,7 +42,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
err = errors.New("invalid database source")
} else {
var source string
var dbSql sqlCommon
var dbSQL sqlCommon
switch value := args[0].(type) {
case string:
@ -55,19 +53,19 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
driver = value
source = args[1].(string)
}
dbSql, err = sql.Open(driver, source)
dbSQL, err = sql.Open(driver, source)
case sqlCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSql = value
dbSQL = value
}
db = DB{
dialect: newDialect(dialect, dbSql.(*sql.DB)),
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
logger: defaultLogger,
callbacks: DefaultCallback,
source: source,
values: map[string]interface{}{},
db: dbSql,
db: dbSQL,
}
db.parent = &db
@ -79,14 +77,17 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
return &db, err
}
// Close close current db connection
func (s *DB) Close() error {
return s.parent.db.(*sql.DB).Close()
}
// DB get `*sql.DB` from current connection
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
// New initialize a new db connection without any search conditions
func (s *DB) New() *DB {
clone := s.clone()
clone.search = nil
@ -94,29 +95,34 @@ func (s *DB) New() *DB {
return clone
}
// NewScope create scope for callbacks, including DB's search information
// NewScope create a scope for current operation
func (s *DB) NewScope(value interface{}) *Scope {
dbClone := s.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}
// CommonDB Return the underlying sql.DB or sql.Tx instance.
// CommonDB return the underlying sql.DB or sql.Tx instance.
// Use of this method is discouraged. It's mainly intended to allow
// coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon {
return s.db
}
// Callback return Callbacks container, you could add/remove/change callbacks with it
// db.Callback().Create().Register("update_created_at", updateCreated)
// Refer: https://jinzhu.github.io/gorm/development.html#callbacks for more
func (s *DB) Callback() *Callback {
s.parent.callbacks = s.parent.callbacks.clone()
return s.parent.callbacks
}
func (s *DB) SetLogger(l logger) {
s.logger = l
// SetLogger replace default logger
func (s *DB) SetLogger(log logger) {
s.logger = log
}
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
func (s *DB) LogMode(enable bool) *DB {
if enable {
s.logMode = 2
@ -126,51 +132,65 @@ func (s *DB) LogMode(enable bool) *DB {
return s
}
// SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) {
modelStructsMap = newModelStructsMap()
s.parent.singularTable = enable
}
// Where return a new relation, accepts use `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db
}
// Or match before conditions or this one, similar to `Where`
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
return s.clone().search.Or(query, args...).db
}
// Not don't match current conditions, similar to `Where`
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.Not(query, args...).db
}
// Limit specify the number of records to be retrieved
func (s *DB) Limit(limit int) *DB {
return s.clone().search.Limit(limit).db
}
// Offset specify the number of records to skip before starting to return the records
func (s *DB) Offset(offset int) *DB {
return s.clone().search.Offset(offset).db
}
// Order specify order when retrieve records from database, pass `true` as the second argument to overwrite `Order` conditions
func (s *DB) Order(value string, reorder ...bool) *DB {
return s.clone().search.Order(value, reorder...).db
}
// Select When querying, specify fields that you want to retrieve from database, by default, will select all fields;
// When creating/updating, specify fields that you want to save to database
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.Select(query, args...).db
}
// Omit specify fields that you want to ignore when save to database when creating/updating
func (s *DB) Omit(columns ...string) *DB {
return s.clone().search.Omit(columns...).db
}
// Group specify the group method on the find
func (s *DB) Group(query string) *DB {
return s.clone().search.Group(query).db
}
// Having specify HAVING conditions for GROUP BY
func (s *DB) Having(query string, values ...interface{}) *DB {
return s.clone().search.Having(query, values...).db
}
// Joins specify Joins conditions
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (s *DB) Joins(query string, args ...interface{}) *DB {
return s.clone().search.Joins(query, args...).db
}
@ -352,7 +372,7 @@ func (s *DB) Begin() *DB {
c.db = interface{}(tx).(sqlCommon)
c.AddError(err)
} else {
c.AddError(CantStartTransaction)
c.AddError(ErrCantStartTransaction)
}
return c
}
@ -361,7 +381,7 @@ func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Commit())
} else {
s.AddError(NoValidTransaction)
s.AddError(ErrInvalidTransaction)
}
return s
}
@ -370,7 +390,7 @@ func (s *DB) Rollback() *DB {
if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Rollback())
} else {
s.AddError(NoValidTransaction)
s.AddError(ErrInvalidTransaction)
}
return s
}
@ -380,7 +400,7 @@ func (s *DB) NewRecord(value interface{}) bool {
}
func (s *DB) RecordNotFound() bool {
return s.Error == RecordNotFound
return s.Error == ErrRecordNotFound
}
// CreateTable create table for models
@ -541,7 +561,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
func (s *DB) AddError(err error) error {
if err != nil {
if err != RecordNotFound {
if err != ErrRecordNotFound {
if s.logMode == 0 {
go s.print(fileWithLineNum(), err)
} else {

View File

@ -479,7 +479,7 @@ func TestRaw(t *testing.T) {
}
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
t.Error("Raw sql to update records")
}
}
@ -709,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) {
}
var user User
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
t.Errorf("Should have found existing record")
}
}

View File

@ -2,6 +2,7 @@ package gorm
import "time"
// Model base model definition, including `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time

View File

@ -12,6 +12,7 @@ import (
"github.com/jinzhu/inflection"
)
// DefaultTableNameHandler default table name handler
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
@ -39,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap {
var modelStructsMap = newModelStructsMap()
// ModelStruct model definition
type ModelStruct struct {
PrimaryFields []*StructField
StructFields []*StructField
@ -46,10 +48,12 @@ type ModelStruct struct {
defaultTableName string
}
// TableName get model's table name
func (s *ModelStruct) TableName(db *DB) string {
return DefaultTableNameHandler(db, s.defaultTableName)
}
// StructField model field's struct definition
type StructField struct {
DBName string
Name string
@ -506,6 +510,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
return &modelStruct
}
// GetStructFields get model's field structs
func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields
}

View File

@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
t.Error(err)
}
}
@ -981,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
t.Error(err)
}
}
@ -1038,7 +1038,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound {
t.Error(err)
}
}

View File

@ -10,11 +10,12 @@ import (
"reflect"
)
// Scope contain any information of current operation when you perform any operation on the database
type Scope struct {
Search *search
Value interface{}
Sql string
SqlVars []interface{}
SQL string
SQLVars []interface{}
db *DB
instanceID string
primaryKeyField *Field
@ -23,6 +24,7 @@ type Scope struct {
selectAttrs *[]string
}
// IndirectValue return scope's reflect value's indirect value
func (scope *Scope) IndirectValue() reflect.Value {
return indirect(reflect.ValueOf(scope.Value))
}
@ -43,12 +45,13 @@ func (scope *Scope) NewDB() *DB {
return nil
}
// DB return scope's DB connection
func (scope *Scope) DB() *DB {
return scope.db
}
// SqlDB return *sql.DB
func (scope *Scope) SqlDB() sqlCommon {
// SQLDB return *sql.DB
func (scope *Scope) SQLDB() sqlCommon {
return scope.db.db
}
@ -100,6 +103,7 @@ func (scope *Scope) HasError() bool {
return scope.db.Error != nil
}
// PrimaryFields return scope's primary fields
func (scope *Scope) PrimaryFields() (fields []*Field) {
for _, field := range scope.Fields() {
if field.IsPrimaryKey {
@ -109,6 +113,7 @@ func (scope *Scope) PrimaryFields() (fields []*Field) {
return fields
}
// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one
func (scope *Scope) PrimaryField() *Field {
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
if len(primaryFields) > 1 {
@ -241,8 +246,8 @@ func (scope *Scope) AddToVars(value interface{}) string {
return exp
}
scope.SqlVars = append(scope.SqlVars, value)
return scope.Dialect().BindVar(len(scope.SqlVars))
scope.SQLVars = append(scope.SQLVars, value)
return scope.Dialect().BindVar(len(scope.SQLVars))
}
type tabler interface {
@ -282,10 +287,10 @@ func (scope *Scope) QuotedTableName() (name string) {
return scope.Quote(scope.TableName())
}
// CombinedConditionSql get combined condition sql
// CombinedConditionSql return combined condition sql
func (scope *Scope) CombinedConditionSql() string {
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql()
return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() +
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
}
// FieldByName find gorm.Field with name and db name
@ -308,7 +313,7 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
// Raw set sql
func (scope *Scope) Raw(sql string) *Scope {
scope.Sql = strings.Replace(sql, "$$", "?", -1)
scope.SQL = strings.Replace(sql, "$$", "?", -1)
return scope
}
@ -317,7 +322,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc())
if !scope.HasError() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count
}
@ -345,17 +350,19 @@ func (scope *Scope) InstanceID() string {
return scope.instanceID
}
// InstanceSet set value for current instance, but not for associations
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
return scope.Set(name+scope.InstanceID(), value)
}
// InstanceGet get setting from current instance
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
return scope.Get(name + scope.InstanceID())
}
// Begin start a transaction
func (scope *Scope) Begin() *Scope {
if db, ok := scope.SqlDB().(sqlDb); ok {
if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.InstanceSet("gorm:started_transaction", true)
@ -379,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope {
return scope
}
// SelectAttrs retur nselected attributes
func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil {
attrs := []string{}
@ -398,6 +406,7 @@ func (scope *Scope) SelectAttrs() []string {
return *scope.selectAttrs
}
// OmitAttrs return omited attributes
func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits
}

View File

@ -76,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
}
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSql string
var notEqualSQL string
var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) {
@ -87,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSql = fmt.Sprintf("NOT (%v)", value)
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
} else {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
@ -139,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
}
str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
}
}
return
@ -173,7 +173,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
return
}
func (scope *Scope) whereSql() (sql string) {
func (scope *Scope) whereSQL() (sql string) {
var (
quotedTableName = scope.QuotedTableName()
primaryConditions, andConditions, orConditions []string
@ -209,28 +209,28 @@ func (scope *Scope) whereSql() (sql string) {
}
}
orSql := strings.Join(orConditions, " OR ")
combinedSql := strings.Join(andConditions, " AND ")
if len(combinedSql) > 0 {
if len(orSql) > 0 {
combinedSql = combinedSql + " OR " + orSql
orSQL := strings.Join(orConditions, " OR ")
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) > 0 {
if len(orSQL) > 0 {
combinedSQL = combinedSQL + " OR " + orSQL
}
} else {
combinedSql = orSql
combinedSQL = orSQL
}
if len(primaryConditions) > 0 {
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
if len(combinedSql) > 0 {
sql = sql + " AND (" + combinedSql + ")"
if len(combinedSQL) > 0 {
sql = sql + " AND (" + combinedSQL + ")"
}
} else if len(combinedSql) > 0 {
sql = "WHERE " + combinedSql
} else if len(combinedSQL) > 0 {
sql = "WHERE " + combinedSQL
}
return
}
func (scope *Scope) selectSql() string {
func (scope *Scope) selectSQL() string {
if len(scope.Search.selects) == 0 {
if len(scope.Search.joinConditions) > 0 {
return fmt.Sprintf("%v.*", scope.QuotedTableName())
@ -240,25 +240,25 @@ func (scope *Scope) selectSql() string {
return scope.buildSelectQuery(scope.Search.selects)
}
func (scope *Scope) orderSql() string {
func (scope *Scope) orderSQL() string {
if len(scope.Search.orders) == 0 || scope.Search.countingQuery {
return ""
}
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
}
func (scope *Scope) limitAndOffsetSql() string {
func (scope *Scope) limitAndOffsetSQL() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
}
func (scope *Scope) groupSql() string {
func (scope *Scope) groupSQL() string {
if len(scope.Search.group) == 0 {
return ""
}
return " GROUP BY " + scope.Search.group
}
func (scope *Scope) havingSql() string {
func (scope *Scope) havingSQL() string {
if len(scope.Search.havingConditions) == 0 {
return ""
}
@ -270,15 +270,15 @@ func (scope *Scope) havingSql() string {
}
}
combinedSql := strings.Join(andConditions, " AND ")
if len(combinedSql) == 0 {
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) == 0 {
return ""
}
return " HAVING " + combinedSql
return " HAVING " + combinedSQL
}
func (scope *Scope) joinsSql() string {
func (scope *Scope) joinsSQL() string {
var joinConditions []string
for _, clause := range scope.Search.joinConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
@ -289,11 +289,11 @@ func (scope *Scope) joinsSql() string {
return strings.Join(joinConditions, " ") + " "
}
func (scope *Scope) prepareQuerySql() {
func (scope *Scope) prepareQuerySQL() {
if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else {
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
}
return
}
@ -345,15 +345,15 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (resul
func (scope *Scope) row() *sql.Row {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySql()
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
scope.prepareQuerySQL()
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
}
func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySql()
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.prepareQuerySQL()
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
}
func (scope *Scope) initialize() *Scope {
@ -404,8 +404,8 @@ func (scope *Scope) typeName() string {
// trace print sql log
func (scope *Scope) trace(t time.Time) {
if len(scope.Sql) > 0 {
scope.db.slog(scope.Sql, t, scope.SqlVars...)
if len(scope.SQL) > 0 {
scope.db.slog(scope.SQL, t, scope.SQLVars...)
}
}
@ -599,7 +599,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
sqlCreate = "CREATE UNIQUE INDEX"
}
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec()
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {

View File

@ -93,7 +93,7 @@ func (s *search) Offset(offset int) *search {
}
func (s *search) Group(query string) *search {
s.group = s.getInterfaceAsSql(query)
s.group = s.getInterfaceAsSQL(query)
return s
}
@ -134,12 +134,12 @@ func (s *search) Table(name string) *search {
return s
}
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
switch value.(type) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
str = fmt.Sprintf("%v", value)
default:
s.db.AddError(InvalidSql)
s.db.AddError(ErrInvalidSQL)
}
if str == "-1" {

View File

@ -9,8 +9,19 @@ import (
"runtime"
"strings"
"sync"
"time"
)
// NowFunc returns current time, this function is exported in order to be able
// to give the flexibility to the developer to customize it according to their
// needs
//
// e.g: return time.Now().UTC()
//
var NowFunc = func() time.Time {
return time.Now()
}
// Copied from golint
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer