diff --git a/callback.go b/callback.go index a17fb2aa..9704e584 100644 --- a/callback.go +++ b/callback.go @@ -4,7 +4,7 @@ import ( "fmt" ) -// defaultCallback hold default callbacks defined by gorm +// DefaultCallback default callbacks defined by gorm var DefaultCallback = &Callback{} // Callback contains callbacks that used when CURD objects diff --git a/callback_create.go b/callback_create.go index 2a0b9b2a..e3cd2f0b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -85,14 +85,14 @@ func createCallback(scope *Scope) { returningColumn = scope.Quote(primaryField.DBName) } - lastInsertIdReturningSuffix := scope.Dialect().LastInsertIdReturningSuffix(quotedTableName, returningColumn) + lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) if len(columns) == 0 { scope.Raw(fmt.Sprintf( "INSERT INTO %v DEFAULT VALUES%v%v", quotedTableName, addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIdReturningSuffix), + addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } else { scope.Raw(fmt.Sprintf( @@ -101,13 +101,13 @@ func createCallback(scope *Scope) { strings.Join(columns, ","), strings.Join(placeholders, ","), addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIdReturningSuffix), + addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } // execute create sql - if lastInsertIdReturningSuffix == "" || primaryField == nil { - if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + if lastInsertIDReturningSuffix == "" || primaryField == nil { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -119,7 +119,7 @@ func createCallback(scope *Scope) { } } } else { - if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { scope.db.RowsAffected = 1 } } diff --git a/callback_query.go b/callback_query.go index acdf149b..08678f92 100644 --- a/callback_query.go +++ b/callback_query.go @@ -48,15 +48,15 @@ func queryCallback(scope *Scope) { return } - scope.prepareQuerySql() + scope.prepareQuerySQL() if !scope.HasError() { scope.db.RowsAffected = 0 if str, ok := scope.Get("gorm:query_option"); ok { - scope.Sql += addExtraSpaceIfExist(fmt.Sprint(str)) + scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } - if rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { defer rows.Close() columns, _ := rows.Columns() @@ -80,7 +80,7 @@ func queryCallback(scope *Scope) { } if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(RecordNotFound) + scope.Err(ErrRecordNotFound) } } } diff --git a/dialect.go b/dialect.go index adaf43af..1d757078 100644 --- a/dialect.go +++ b/dialect.go @@ -39,7 +39,7 @@ type Dialect interface { // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIdReturningSuffix(tableName, columnName string) string + LastInsertIDReturningSuffix(tableName, columnName string) string } var dialectsMap = map[string]Dialect{} diff --git a/dialect_common.go b/dialect_common.go index 95553c97..f009271b 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -132,6 +132,6 @@ func (commonDialect) SelectFromDummyTable() string { return "" } -func (commonDialect) LastInsertIdReturningSuffix(tableName, columnName string) string { +func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } diff --git a/dialect_postgres.go b/dialect_postgres.go index baf76659..18cbe028 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -106,11 +106,11 @@ func (s postgres) currentDatabase() (name string) { return } -func (s postgres) LastInsertIdReturningSuffix(tableName, key string) string { +func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } -func (postgres) SupportLastInsertId() bool { +func (postgres) SupportLastInsertID() bool { return false } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 657bada9..5b994f9d 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -145,6 +145,6 @@ func (mssql) SelectFromDummyTable() string { return "" } -func (mssql) LastInsertIdReturningSuffix(tableName, columnName string) string { +func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } diff --git a/errors.go b/errors.go index c59dd968..32b7f220 100644 --- a/errors.go +++ b/errors.go @@ -6,24 +6,31 @@ import ( ) var ( - RecordNotFound = errors.New("record not found") - InvalidSql = errors.New("invalid sql") - NoValidTransaction = errors.New("no valid transaction") - CantStartTransaction = errors.New("can't start transaction") + // ErrRecordNotFound record not found, happens when you are looking up with a struct, and haven't find any matched data + ErrRecordNotFound = errors.New("record not found") + // ErrInvalidSQL invalid SQL, happens when you passed invalid SQL + ErrInvalidSQL = errors.New("invalid SQL") + // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + ErrInvalidTransaction = errors.New("no valid transaction") + // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` + ErrCantStartTransaction = errors.New("can't start transaction") ) type errorsInterface interface { GetErrors() []error } +// Errors contains all happened errors type Errors struct { errors []error } +// GetErrors get all happened errors func (errs Errors) GetErrors() []error { return errs.errors } +// Add add an error func (errs *Errors) Add(err error) { if errors, ok := err.(errorsInterface); ok { for _, err := range errors.GetErrors() { @@ -39,6 +46,7 @@ func (errs *Errors) Add(err error) { } } +// Error format happened errors func (errs Errors) Error() string { var errors = []string{} for _, e := range errs.errors { diff --git a/field.go b/field.go index ff252e1a..cadc1a72 100644 --- a/field.go +++ b/field.go @@ -7,12 +7,14 @@ import ( "reflect" ) +// Field model field definition type Field struct { *StructField IsBlank bool Field reflect.Value } +// Set set a value to the field func (field *Field) Set(value interface{}) (err error) { if !field.Field.IsValid() { return errors.New("field value not valid") diff --git a/join_table_handler.go b/join_table_handler.go index 6251cd22..18c12a85 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -7,40 +7,54 @@ import ( "strings" ) +// JoinTableHandlerInterface is an interface for how to handle many2many relations type JoinTableHandlerInterface interface { + // initialize join table handler Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) + // Table return join table's table name Table(db *DB) string + // Add create relationship in join table for source and destination Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error + // Delete delete relationship in join table for sources Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error + // JoinWith query with `Join` conditions JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + // SourceForeignKeys return source foreign keys SourceForeignKeys() []JoinTableForeignKey + // DestinationForeignKeys return destination foreign keys DestinationForeignKeys() []JoinTableForeignKey } +// JoinTableForeignKey join table foreign key struct type JoinTableForeignKey struct { DBName string AssociationDBName string } +// JoinTableSource is a struct that contains model type and foreign keys type JoinTableSource struct { ModelType reflect.Type ForeignKeys []JoinTableForeignKey } +// JoinTableHandler default join table handler type JoinTableHandler struct { TableName string `sql:"-"` Source JoinTableSource `sql:"-"` Destination JoinTableSource `sql:"-"` } +// SourceForeignKeys return source foreign keys func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { return s.Source.ForeignKeys } +// DestinationForeignKeys return destination foreign keys func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { return s.Destination.ForeignKeys } +// Setup initialize a default join table handler func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { s.TableName = tableName @@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } } +// Table return join table's table name func (s JoinTableHandler) Table(db *DB) string { return s.TableName } -func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { +func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { values := map[string]interface{}{} for _, source := range sources { @@ -89,9 +104,10 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin return values } -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { +// Add create relationship in join table for source and destination +func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { scope := db.NewScope("") - searchMap := s.GetSearchMap(db, source1, source2) + searchMap := s.getSearchMap(db, source, destination) var assignColumns, binVars, conditions []string var values []interface{} @@ -120,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 return db.Exec(sql, values...).Error } +// Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( scope = db.NewScope(nil) @@ -127,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour values []interface{} ) - for key, value := range s.GetSearchMap(db, sources...) { + for key, value := range s.getSearchMap(db, sources...) { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } @@ -135,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } +// JoinWith query with `Join` conditions func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { var ( scope = db.NewScope(source) diff --git a/logger.go b/logger.go index 6b948804..f4c79ead 100644 --- a/logger.go +++ b/logger.go @@ -11,23 +11,25 @@ import ( "unicode" ) +var ( + defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} + sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) +) + type logger interface { Print(v ...interface{}) } -type LogWriter interface { +type logWriter interface { Println(v ...interface{}) } +// Logger default logger type Logger struct { - LogWriter + logWriter } -var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - -// Format log -var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) - +// Print format & print log func (logger Logger) Print(values ...interface{}) { if len(values) > 1 { level := values[0] diff --git a/main.go b/main.go index 09b6df74..b4b8a169 100644 --- a/main.go +++ b/main.go @@ -6,19 +6,9 @@ import ( "fmt" "reflect" "strings" - "time" ) -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs -// -// e.g: return time.Now().UTC() -// -var NowFunc = func() time.Time { - return time.Now() -} - +// DB contains information for current db connection type DB struct { Value interface{} Error error @@ -36,6 +26,14 @@ type DB struct { joinTableHandlers map[string]JoinTableHandler } +// Open open a new db connection, need to import driver first, for example: +// +// import _ "github.com/go-sql-driver/mysql" +// func main() { +// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") +// } +// GORM has wrapped some drivers, for easier to remember its name, so you could import the mysql driver with +// import _ "github.com/jinzhu/gorm/dialects/mysql" func Open(dialect string, args ...interface{}) (*DB, error) { var db DB var err error @@ -44,7 +42,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) { err = errors.New("invalid database source") } else { var source string - var dbSql sqlCommon + var dbSQL sqlCommon switch value := args[0].(type) { case string: @@ -55,19 +53,19 @@ func Open(dialect string, args ...interface{}) (*DB, error) { driver = value source = args[1].(string) } - dbSql, err = sql.Open(driver, source) + dbSQL, err = sql.Open(driver, source) case sqlCommon: source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() - dbSql = value + dbSQL = value } db = DB{ - dialect: newDialect(dialect, dbSql.(*sql.DB)), + dialect: newDialect(dialect, dbSQL.(*sql.DB)), logger: defaultLogger, callbacks: DefaultCallback, source: source, values: map[string]interface{}{}, - db: dbSql, + db: dbSQL, } db.parent = &db @@ -79,14 +77,17 @@ func Open(dialect string, args ...interface{}) (*DB, error) { return &db, err } +// Close close current db connection func (s *DB) Close() error { return s.parent.db.(*sql.DB).Close() } +// DB get `*sql.DB` from current connection func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } +// New initialize a new db connection without any search conditions func (s *DB) New() *DB { clone := s.clone() clone.search = nil @@ -94,29 +95,34 @@ func (s *DB) New() *DB { return clone } -// NewScope create scope for callbacks, including DB's search information +// NewScope create a scope for current operation func (s *DB) NewScope(value interface{}) *Scope { dbClone := s.clone() dbClone.Value = value return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} } -// CommonDB Return the underlying sql.DB or sql.Tx instance. +// CommonDB return the underlying sql.DB or sql.Tx instance. // Use of this method is discouraged. It's mainly intended to allow // coexistence with legacy non-GORM code. func (s *DB) CommonDB() sqlCommon { return s.db } +// Callback return Callbacks container, you could add/remove/change callbacks with it +// db.Callback().Create().Register("update_created_at", updateCreated) +// Refer: https://jinzhu.github.io/gorm/development.html#callbacks for more func (s *DB) Callback() *Callback { s.parent.callbacks = s.parent.callbacks.clone() return s.parent.callbacks } -func (s *DB) SetLogger(l logger) { - s.logger = l +// SetLogger replace default logger +func (s *DB) SetLogger(log logger) { + s.logger = log } +// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs func (s *DB) LogMode(enable bool) *DB { if enable { s.logMode = 2 @@ -126,51 +132,65 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { modelStructsMap = newModelStructsMap() s.parent.singularTable = enable } +// Where return a new relation, accepts use `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db } +// Or match before conditions or this one, similar to `Where` func (s *DB) Or(query interface{}, args ...interface{}) *DB { return s.clone().search.Or(query, args...).db } +// Not don't match current conditions, similar to `Where` func (s *DB) Not(query interface{}, args ...interface{}) *DB { return s.clone().search.Not(query, args...).db } +// Limit specify the number of records to be retrieved func (s *DB) Limit(limit int) *DB { return s.clone().search.Limit(limit).db } +// Offset specify the number of records to skip before starting to return the records func (s *DB) Offset(offset int) *DB { return s.clone().search.Offset(offset).db } +// Order specify order when retrieve records from database, pass `true` as the second argument to overwrite `Order` conditions func (s *DB) Order(value string, reorder ...bool) *DB { return s.clone().search.Order(value, reorder...).db } +// Select When querying, specify fields that you want to retrieve from database, by default, will select all fields; +// When creating/updating, specify fields that you want to save to database func (s *DB) Select(query interface{}, args ...interface{}) *DB { return s.clone().search.Select(query, args...).db } +// Omit specify fields that you want to ignore when save to database when creating/updating func (s *DB) Omit(columns ...string) *DB { return s.clone().search.Omit(columns...).db } +// Group specify the group method on the find func (s *DB) Group(query string) *DB { return s.clone().search.Group(query).db } +// Having specify HAVING conditions for GROUP BY func (s *DB) Having(query string, values ...interface{}) *DB { return s.clone().search.Having(query, values...).db } +// Joins specify Joins conditions +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (s *DB) Joins(query string, args ...interface{}) *DB { return s.clone().search.Joins(query, args...).db } @@ -352,7 +372,7 @@ func (s *DB) Begin() *DB { c.db = interface{}(tx).(sqlCommon) c.AddError(err) } else { - c.AddError(CantStartTransaction) + c.AddError(ErrCantStartTransaction) } return c } @@ -361,7 +381,7 @@ func (s *DB) Commit() *DB { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Commit()) } else { - s.AddError(NoValidTransaction) + s.AddError(ErrInvalidTransaction) } return s } @@ -370,7 +390,7 @@ func (s *DB) Rollback() *DB { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Rollback()) } else { - s.AddError(NoValidTransaction) + s.AddError(ErrInvalidTransaction) } return s } @@ -380,7 +400,7 @@ func (s *DB) NewRecord(value interface{}) bool { } func (s *DB) RecordNotFound() bool { - return s.Error == RecordNotFound + return s.Error == ErrRecordNotFound } // CreateTable create table for models @@ -541,7 +561,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join func (s *DB) AddError(err error) error { if err != nil { - if err != RecordNotFound { + if err != ErrRecordNotFound { if s.logMode == 0 { go s.print(fileWithLineNum(), err) } else { diff --git a/main_test.go b/main_test.go index c732e7fc..2da3de94 100644 --- a/main_test.go +++ b/main_test.go @@ -479,7 +479,7 @@ func TestRaw(t *testing.T) { } DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound { + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { t.Error("Raw sql to update records") } } @@ -709,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) { } var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { + if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { t.Errorf("Should have found existing record") } } diff --git a/model.go b/model.go index ffa68b07..1ffdf2ef 100644 --- a/model.go +++ b/model.go @@ -2,6 +2,7 @@ package gorm import "time" +// Model base model definition, including `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models type Model struct { ID uint `gorm:"primary_key"` CreatedAt time.Time diff --git a/model_struct.go b/model_struct.go index a17d2257..7773a1bf 100644 --- a/model_struct.go +++ b/model_struct.go @@ -12,6 +12,7 @@ import ( "github.com/jinzhu/inflection" ) +// DefaultTableNameHandler default table name handler var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } @@ -39,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap { var modelStructsMap = newModelStructsMap() +// ModelStruct model definition type ModelStruct struct { PrimaryFields []*StructField StructFields []*StructField @@ -46,10 +48,12 @@ type ModelStruct struct { defaultTableName string } +// TableName get model's table name func (s *ModelStruct) TableName(db *DB) string { return DefaultTableNameHandler(db, s.defaultTableName) } +// StructField model field's struct definition type StructField struct { DBName string Name string @@ -506,6 +510,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { return &modelStruct } +// GetStructFields get model's field structs func (scope *Scope) GetStructFields() (fields []*StructField) { return scope.GetModelStruct().StructFields } diff --git a/preload_test.go b/preload_test.go index 8f21bc97..cde8e800 100644 --- a/preload_test.go +++ b/preload_test.go @@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -981,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -1038,7 +1038,7 @@ func TestNestedManyToManyPreload2(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } diff --git a/scope.go b/scope.go index 9f4e821d..8d76489a 100644 --- a/scope.go +++ b/scope.go @@ -10,11 +10,12 @@ import ( "reflect" ) +// Scope contain any information of current operation when you perform any operation on the database type Scope struct { Search *search Value interface{} - Sql string - SqlVars []interface{} + SQL string + SQLVars []interface{} db *DB instanceID string primaryKeyField *Field @@ -23,6 +24,7 @@ type Scope struct { selectAttrs *[]string } +// IndirectValue return scope's reflect value's indirect value func (scope *Scope) IndirectValue() reflect.Value { return indirect(reflect.ValueOf(scope.Value)) } @@ -43,12 +45,13 @@ func (scope *Scope) NewDB() *DB { return nil } +// DB return scope's DB connection func (scope *Scope) DB() *DB { return scope.db } -// SqlDB return *sql.DB -func (scope *Scope) SqlDB() sqlCommon { +// SQLDB return *sql.DB +func (scope *Scope) SQLDB() sqlCommon { return scope.db.db } @@ -100,6 +103,7 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } +// PrimaryFields return scope's primary fields func (scope *Scope) PrimaryFields() (fields []*Field) { for _, field := range scope.Fields() { if field.IsPrimaryKey { @@ -109,6 +113,7 @@ func (scope *Scope) PrimaryFields() (fields []*Field) { return fields } +// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if len(primaryFields) > 1 { @@ -241,8 +246,8 @@ func (scope *Scope) AddToVars(value interface{}) string { return exp } - scope.SqlVars = append(scope.SqlVars, value) - return scope.Dialect().BindVar(len(scope.SqlVars)) + scope.SQLVars = append(scope.SQLVars, value) + return scope.Dialect().BindVar(len(scope.SQLVars)) } type tabler interface { @@ -282,10 +287,10 @@ func (scope *Scope) QuotedTableName() (name string) { return scope.Quote(scope.TableName()) } -// CombinedConditionSql get combined condition sql +// CombinedConditionSql return combined condition sql func (scope *Scope) CombinedConditionSql() string { - return scope.joinsSql() + scope.whereSql() + scope.groupSql() + - scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql() + return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() + + scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() } // FieldByName find gorm.Field with name and db name @@ -308,7 +313,7 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { // Raw set sql func (scope *Scope) Raw(sql string) *Scope { - scope.Sql = strings.Replace(sql, "$$", "?", -1) + scope.SQL = strings.Replace(sql, "$$", "?", -1) return scope } @@ -317,7 +322,7 @@ func (scope *Scope) Exec() *Scope { defer scope.trace(NowFunc()) if !scope.HasError() { - if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count } @@ -345,17 +350,19 @@ func (scope *Scope) InstanceID() string { return scope.instanceID } +// InstanceSet set value for current instance, but not for associations func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { return scope.Set(name+scope.InstanceID(), value) } +// InstanceGet get setting from current instance func (scope *Scope) InstanceGet(name string) (interface{}, bool) { return scope.Get(name + scope.InstanceID()) } // Begin start a transaction func (scope *Scope) Begin() *Scope { - if db, ok := scope.SqlDB().(sqlDb); ok { + if db, ok := scope.SQLDB().(sqlDb); ok { if tx, err := db.Begin(); err == nil { scope.db.db = interface{}(tx).(sqlCommon) scope.InstanceSet("gorm:started_transaction", true) @@ -379,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope { return scope } +// SelectAttrs retur nselected attributes func (scope *Scope) SelectAttrs() []string { if scope.selectAttrs == nil { attrs := []string{} @@ -398,6 +406,7 @@ func (scope *Scope) SelectAttrs() []string { return *scope.selectAttrs } +// OmitAttrs return omited attributes func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } diff --git a/scope_private.go b/scope_private.go index 31db4a0b..9309b6f4 100644 --- a/scope_private.go +++ b/scope_private.go @@ -76,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSql string + var notEqualSQL string var primaryKey = scope.PrimaryKey() switch value := clause["query"].(type) { @@ -87,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSql = fmt.Sprintf("NOT (%v)", value) + notEqualSQL = fmt.Sprintf("NOT (%v)", value) } else { str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) - notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) + notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) @@ -139,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value() } - str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1) + str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) } } return @@ -173,7 +173,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) return } -func (scope *Scope) whereSql() (sql string) { +func (scope *Scope) whereSQL() (sql string) { var ( quotedTableName = scope.QuotedTableName() primaryConditions, andConditions, orConditions []string @@ -209,28 +209,28 @@ func (scope *Scope) whereSql() (sql string) { } } - orSql := strings.Join(orConditions, " OR ") - combinedSql := strings.Join(andConditions, " AND ") - if len(combinedSql) > 0 { - if len(orSql) > 0 { - combinedSql = combinedSql + " OR " + orSql + orSQL := strings.Join(orConditions, " OR ") + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) > 0 { + if len(orSQL) > 0 { + combinedSQL = combinedSQL + " OR " + orSQL } } else { - combinedSql = orSql + combinedSQL = orSQL } if len(primaryConditions) > 0 { sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSql) > 0 { - sql = sql + " AND (" + combinedSql + ")" + if len(combinedSQL) > 0 { + sql = sql + " AND (" + combinedSQL + ")" } - } else if len(combinedSql) > 0 { - sql = "WHERE " + combinedSql + } else if len(combinedSQL) > 0 { + sql = "WHERE " + combinedSQL } return } -func (scope *Scope) selectSql() string { +func (scope *Scope) selectSQL() string { if len(scope.Search.selects) == 0 { if len(scope.Search.joinConditions) > 0 { return fmt.Sprintf("%v.*", scope.QuotedTableName()) @@ -240,25 +240,25 @@ func (scope *Scope) selectSql() string { return scope.buildSelectQuery(scope.Search.selects) } -func (scope *Scope) orderSql() string { +func (scope *Scope) orderSQL() string { if len(scope.Search.orders) == 0 || scope.Search.countingQuery { return "" } return " ORDER BY " + strings.Join(scope.Search.orders, ",") } -func (scope *Scope) limitAndOffsetSql() string { +func (scope *Scope) limitAndOffsetSQL() string { return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) } -func (scope *Scope) groupSql() string { +func (scope *Scope) groupSQL() string { if len(scope.Search.group) == 0 { return "" } return " GROUP BY " + scope.Search.group } -func (scope *Scope) havingSql() string { +func (scope *Scope) havingSQL() string { if len(scope.Search.havingConditions) == 0 { return "" } @@ -270,15 +270,15 @@ func (scope *Scope) havingSql() string { } } - combinedSql := strings.Join(andConditions, " AND ") - if len(combinedSql) == 0 { + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) == 0 { return "" } - return " HAVING " + combinedSql + return " HAVING " + combinedSQL } -func (scope *Scope) joinsSql() string { +func (scope *Scope) joinsSQL() string { var joinConditions []string for _, clause := range scope.Search.joinConditions { if sql := scope.buildWhereCondition(clause); sql != "" { @@ -289,11 +289,11 @@ func (scope *Scope) joinsSql() string { return strings.Join(joinConditions, " ") + " " } -func (scope *Scope) prepareQuerySql() { +func (scope *Scope) prepareQuerySQL() { if scope.Search.raw { scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) } return } @@ -345,15 +345,15 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (resul func (scope *Scope) row() *sql.Row { defer scope.trace(NowFunc()) scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySql() - return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) + scope.prepareQuerySQL() + return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) } func (scope *Scope) rows() (*sql.Rows, error) { defer scope.trace(NowFunc()) scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySql() - return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) + scope.prepareQuerySQL() + return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) } func (scope *Scope) initialize() *Scope { @@ -404,8 +404,8 @@ func (scope *Scope) typeName() string { // trace print sql log func (scope *Scope) trace(t time.Time) { - if len(scope.Sql) > 0 { - scope.db.slog(scope.Sql, t, scope.SqlVars...) + if len(scope.SQL) > 0 { + scope.db.slog(scope.SQL, t, scope.SQLVars...) } } @@ -599,7 +599,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { sqlCreate = "CREATE UNIQUE INDEX" } - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() + scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { diff --git a/search.go b/search.go index 4e31ae03..078bd429 100644 --- a/search.go +++ b/search.go @@ -93,7 +93,7 @@ func (s *search) Offset(offset int) *search { } func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSql(query) + s.group = s.getInterfaceAsSQL(query) return s } @@ -134,12 +134,12 @@ func (s *search) Table(name string) *search { return s } -func (s *search) getInterfaceAsSql(value interface{}) (str string) { +func (s *search) getInterfaceAsSQL(value interface{}) (str string) { switch value.(type) { case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: str = fmt.Sprintf("%v", value) default: - s.db.AddError(InvalidSql) + s.db.AddError(ErrInvalidSQL) } if str == "-1" { diff --git a/utils.go b/utils.go index 7fc53fa5..4ac2ab10 100644 --- a/utils.go +++ b/utils.go @@ -9,8 +9,19 @@ import ( "runtime" "strings" "sync" + "time" ) +// NowFunc returns current time, this function is exported in order to be able +// to give the flexibility to the developer to customize it according to their +// needs +// +// e.g: return time.Now().UTC() +// +var NowFunc = func() time.Time { + return time.Now() +} + // Copied from golint var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer