diff --git a/callback_create.go b/callback_create.go index 763a2dfd..87aba8ee 100644 --- a/callback_create.go +++ b/callback_create.go @@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) { // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { - now := NowFunc() + now := scope.db.nowFunc() if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { if createdAtField.IsBlank { @@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( columns, placeholders []string diff --git a/callback_delete.go b/callback_delete.go index 73d90880..50242e48 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) { "UPDATE %v SET %v=%v%v%v", scope.QuotedTableName(), scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), + scope.AddToVars(scope.db.nowFunc()), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), )).Exec() diff --git a/callback_query.go b/callback_query.go index 7facc42b..e3b3d534 100644 --- a/callback_query.go +++ b/callback_query.go @@ -24,7 +24,7 @@ func queryCallback(scope *Scope) { return } - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( isSlice, isPtr bool diff --git a/callback_update.go b/callback_update.go index c52162c8..56711d37 100644 --- a/callback_update.go +++ b/callback_update.go @@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) { // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) + scope.SetColumn("UpdatedAt", scope.db.nowFunc()) } } diff --git a/create_test.go b/create_test.go index 450dd8a4..c80bdcbb 100644 --- a/create_test.go +++ b/create_test.go @@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) { } } +func TestCreateWithNowFuncOverride(t *testing.T) { + user1 := User{Name: "CreateUserTimestampOverride"} + + timeA := now.MustParse("2016-01-01") + + // do DB.New() because we don't want this test to affect other tests + db1 := DB.New() + // set the override to use static timeA + db1.SetNowFuncOverride(func() time.Time { + return timeA + }) + // call .New again to check the override is carried over as well during clone + db1 = db1.New() + + db1.Save(&user1) + + if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt be using the nowFuncOverride") + } + if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt be using the nowFuncOverride") + } + + // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set + // to make sure that setting it only affected the above instance + + user2 := User{Name: "CreateUserTimestampOverrideNoMore"} + + db2 := DB.New() + + db2.Save(&user2) + + if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt no longer be using the nowFuncOverride") + } + if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt no longer be using the nowFuncOverride") + } +} + type AutoIncrementUser struct { User Sequence uint `gorm:"AUTO_INCREMENT"` diff --git a/main.go b/main.go index 3093ec80..ec84906b 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,9 @@ type DB struct { callbacks *Callback dialect Dialect singularTable bool + + // function to be used to override the creating of a new timestamp + nowFuncOverride func() time.Time } type logModeValue int @@ -158,6 +161,22 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// SetNowFuncOverride set the function to be used when creating a new timestamp +func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { + s.nowFuncOverride = nowFuncOverride + return s +} + +// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, +// otherwise defaults to the global NowFunc() +func (s *DB) nowFunc() time.Time { + if s.nowFuncOverride != nil { + return s.nowFuncOverride() + } + + return NowFunc() +} + // BlockGlobalUpdate if true, generates an error on update/delete without where clause. // This is to prevent eventual error with empty objects updates/deletions func (s *DB) BlockGlobalUpdate(enable bool) *DB { @@ -800,6 +819,7 @@ func (s *DB) clone() *DB { Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), + nowFuncOverride: s.nowFuncOverride, } s.values.Range(func(k, v interface{}) bool { diff --git a/scope.go b/scope.go index 0e639c70..c962c165 100644 --- a/scope.go +++ b/scope.go @@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope { // Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) if !scope.HasError() { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { @@ -932,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin } func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowQueryResult{} scope.InstanceSet("row_query_result", result) @@ -942,7 +942,7 @@ func (scope *Scope) row() *sql.Row { } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowsQueryResult{} scope.InstanceSet("row_query_result", result)