forked from mirror/gorm
add an override on the DB instance instead of using the global NowFunc. (#2142)
This commit is contained in:
parent
fec06da6a3
commit
c44c6027fb
|
@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) {
|
||||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
now := NowFunc()
|
now := scope.db.nowFunc()
|
||||||
|
|
||||||
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
||||||
if createdAtField.IsBlank {
|
if createdAtField.IsBlank {
|
||||||
|
@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
|
||||||
// createCallback the callback used to insert data into database
|
// createCallback the callback used to insert data into database
|
||||||
func createCallback(scope *Scope) {
|
func createCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
columns, placeholders []string
|
columns, placeholders []string
|
||||||
|
|
|
@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) {
|
||||||
"UPDATE %v SET %v=%v%v%v",
|
"UPDATE %v SET %v=%v%v%v",
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
scope.Quote(deletedAtField.DBName),
|
scope.Quote(deletedAtField.DBName),
|
||||||
scope.AddToVars(NowFunc()),
|
scope.AddToVars(scope.db.nowFunc()),
|
||||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
)).Exec()
|
)).Exec()
|
||||||
|
|
|
@ -24,7 +24,7 @@ func queryCallback(scope *Scope) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
isSlice, isPtr bool
|
isSlice, isPtr bool
|
||||||
|
|
|
@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) {
|
||||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
scope.SetColumn("UpdatedAt", NowFunc())
|
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateWithNowFuncOverride(t *testing.T) {
|
||||||
|
user1 := User{Name: "CreateUserTimestampOverride"}
|
||||||
|
|
||||||
|
timeA := now.MustParse("2016-01-01")
|
||||||
|
|
||||||
|
// do DB.New() because we don't want this test to affect other tests
|
||||||
|
db1 := DB.New()
|
||||||
|
// set the override to use static timeA
|
||||||
|
db1.SetNowFuncOverride(func() time.Time {
|
||||||
|
return timeA
|
||||||
|
})
|
||||||
|
// call .New again to check the override is carried over as well during clone
|
||||||
|
db1 = db1.New()
|
||||||
|
|
||||||
|
db1.Save(&user1)
|
||||||
|
|
||||||
|
if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("CreatedAt be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("UpdatedAt be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
|
||||||
|
// now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set
|
||||||
|
// to make sure that setting it only affected the above instance
|
||||||
|
|
||||||
|
user2 := User{Name: "CreateUserTimestampOverrideNoMore"}
|
||||||
|
|
||||||
|
db2 := DB.New()
|
||||||
|
|
||||||
|
db2.Save(&user2)
|
||||||
|
|
||||||
|
if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("CreatedAt no longer be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
|
||||||
|
t.Errorf("UpdatedAt no longer be using the nowFuncOverride")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type AutoIncrementUser struct {
|
type AutoIncrementUser struct {
|
||||||
User
|
User
|
||||||
Sequence uint `gorm:"AUTO_INCREMENT"`
|
Sequence uint `gorm:"AUTO_INCREMENT"`
|
||||||
|
|
20
main.go
20
main.go
|
@ -31,6 +31,9 @@ type DB struct {
|
||||||
callbacks *Callback
|
callbacks *Callback
|
||||||
dialect Dialect
|
dialect Dialect
|
||||||
singularTable bool
|
singularTable bool
|
||||||
|
|
||||||
|
// function to be used to override the creating of a new timestamp
|
||||||
|
nowFuncOverride func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type logModeValue int
|
type logModeValue int
|
||||||
|
@ -158,6 +161,22 @@ func (s *DB) LogMode(enable bool) *DB {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNowFuncOverride set the function to be used when creating a new timestamp
|
||||||
|
func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB {
|
||||||
|
s.nowFuncOverride = nowFuncOverride
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
|
||||||
|
// otherwise defaults to the global NowFunc()
|
||||||
|
func (s *DB) nowFunc() time.Time {
|
||||||
|
if s.nowFuncOverride != nil {
|
||||||
|
return s.nowFuncOverride()
|
||||||
|
}
|
||||||
|
|
||||||
|
return NowFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
|
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
|
||||||
// This is to prevent eventual error with empty objects updates/deletions
|
// This is to prevent eventual error with empty objects updates/deletions
|
||||||
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
|
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
|
||||||
|
@ -800,6 +819,7 @@ func (s *DB) clone() *DB {
|
||||||
Error: s.Error,
|
Error: s.Error,
|
||||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||||
dialect: newDialect(s.dialect.GetName(), s.db),
|
dialect: newDialect(s.dialect.GetName(), s.db),
|
||||||
|
nowFuncOverride: s.nowFuncOverride,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.values.Range(func(k, v interface{}) bool {
|
s.values.Range(func(k, v interface{}) bool {
|
||||||
|
|
6
scope.go
6
scope.go
|
@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope {
|
||||||
|
|
||||||
// Exec perform generated SQL
|
// Exec perform generated SQL
|
||||||
func (scope *Scope) Exec() *Scope {
|
func (scope *Scope) Exec() *Scope {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
|
@ -932,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) row() *sql.Row {
|
func (scope *Scope) row() *sql.Row {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
result := &RowQueryResult{}
|
result := &RowQueryResult{}
|
||||||
scope.InstanceSet("row_query_result", result)
|
scope.InstanceSet("row_query_result", result)
|
||||||
|
@ -942,7 +942,7 @@ func (scope *Scope) row() *sql.Row {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(scope.db.nowFunc())
|
||||||
|
|
||||||
result := &RowsQueryResult{}
|
result := &RowsQueryResult{}
|
||||||
scope.InstanceSet("row_query_result", result)
|
scope.InstanceSet("row_query_result", result)
|
||||||
|
|
Loading…
Reference in New Issue