From b3b19a55773b2c4a004c469960dcac78eb068a96 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 11:34:59 +0800 Subject: [PATCH] Test Override NowFunc --- gorm.go | 24 +++++++++--------------- soft_delete.go | 3 +-- tests/create_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/gorm.go b/gorm.go index 6b2a6d75..70751cb3 100644 --- a/gorm.go +++ b/gorm.go @@ -30,9 +30,8 @@ type Config struct { // Dialector database dialector Dialector - statementPool sync.Pool - callbacks *callbacks - cacheStore *sync.Map + callbacks *callbacks + cacheStore *sync.Map } // DB GORM DB definition @@ -77,17 +76,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.cacheStore = &sync.Map{} } - config.statementPool = sync.Pool{ - New: func() interface{} { - return &Statement{ - DB: db, - ConnPool: db.ConnPool, - Clauses: map[string]clause.Clause{}, - Context: context.Background(), - } - }, - } - db = &DB{ Config: config, clone: true, @@ -179,7 +167,13 @@ func (db *DB) AddError(err error) error { func (db *DB) getInstance() *DB { if db.clone { - stmt := db.Config.statementPool.Get().(*Statement) + stmt := &Statement{ + DB: db, + ConnPool: db.ConnPool, + Clauses: map[string]clause.Clause{}, + Context: context.Background(), + } + if db.Statement != nil { stmt.Context = db.Statement.Context } diff --git a/soft_delete.go b/soft_delete.go index 138c9c63..09cfff37 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -4,7 +4,6 @@ import ( "database/sql" "database/sql/driver" "reflect" - "time" "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" @@ -55,7 +54,7 @@ func (SoftDeleteClause) MergeClause(*clause.Clause) { func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: time.Now()}}) + stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) diff --git a/tests/create_test.go b/tests/create_test.go index 43e2c718..a3b3b598 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -1,9 +1,13 @@ package tests_test import ( + "fmt" "testing" + "time" + "github.com/jinzhu/gorm" . "github.com/jinzhu/gorm/tests" + "github.com/jinzhu/now" ) func TestCreate(t *testing.T) { @@ -201,3 +205,43 @@ func TestCreateEmptyStrut(t *testing.T) { t.Errorf("No error should happen when creating user, but got %v", err) } } + +func TestCreateWithExistingTimestamp(t *testing.T) { + user := User{Name: "CreateUserExistingTimestamp"} + curTime := now.MustParse("2016-01-01") + user.CreatedAt = curTime + user.UpdatedAt = curTime + DB.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, user.UpdatedAt, curTime) + + var newUser User + DB.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +} + +func TestCreateWithNowFuncOverride(t *testing.T) { + user := User{Name: "CreateUserTimestampOverride"} + curTime := now.MustParse("2016-01-01") + + NEW := DB.Session(&gorm.Session{ + NowFunc: func() time.Time { + fmt.Println("11iiiin") + return curTime + }, + }) + + NEW.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, user.UpdatedAt, curTime) + + var newUser User + NEW.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +}