forked from mirror/gorm
Test Override NowFunc
This commit is contained in:
parent
1e7eb12cba
commit
b3b19a5577
24
gorm.go
24
gorm.go
|
@ -30,9 +30,8 @@ type Config struct {
|
|||
// Dialector database dialector
|
||||
Dialector
|
||||
|
||||
statementPool sync.Pool
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
}
|
||||
|
||||
// DB GORM DB definition
|
||||
|
@ -77,17 +76,6 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
config.cacheStore = &sync.Map{}
|
||||
}
|
||||
|
||||
config.statementPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Context: context.Background(),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
db = &DB{
|
||||
Config: config,
|
||||
clone: true,
|
||||
|
@ -179,7 +167,13 @@ func (db *DB) AddError(err error) error {
|
|||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone {
|
||||
stmt := db.Config.statementPool.Get().(*Statement)
|
||||
stmt := &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Context: context.Background(),
|
||||
}
|
||||
|
||||
if db.Statement != nil {
|
||||
stmt.Context = db.Statement.Context
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
|
@ -55,7 +54,7 @@ func (SoftDeleteClause) MergeClause(*clause.Clause) {
|
|||
|
||||
func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.String() == "" {
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: time.Now()}})
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}})
|
||||
|
||||
if stmt.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
"github.com/jinzhu/now"
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
|
@ -201,3 +205,43 @@ func TestCreateEmptyStrut(t *testing.T) {
|
|||
t.Errorf("No error should happen when creating user, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithExistingTimestamp(t *testing.T) {
|
||||
user := User{Name: "CreateUserExistingTimestamp"}
|
||||
curTime := now.MustParse("2016-01-01")
|
||||
user.CreatedAt = curTime
|
||||
user.UpdatedAt = curTime
|
||||
DB.Save(&user)
|
||||
|
||||
AssertEqual(t, user.CreatedAt, curTime)
|
||||
AssertEqual(t, user.UpdatedAt, curTime)
|
||||
|
||||
var newUser User
|
||||
DB.First(&newUser, user.ID)
|
||||
|
||||
AssertEqual(t, newUser.CreatedAt, curTime)
|
||||
AssertEqual(t, newUser.UpdatedAt, curTime)
|
||||
}
|
||||
|
||||
func TestCreateWithNowFuncOverride(t *testing.T) {
|
||||
user := User{Name: "CreateUserTimestampOverride"}
|
||||
curTime := now.MustParse("2016-01-01")
|
||||
|
||||
NEW := DB.Session(&gorm.Session{
|
||||
NowFunc: func() time.Time {
|
||||
fmt.Println("11iiiin")
|
||||
return curTime
|
||||
},
|
||||
})
|
||||
|
||||
NEW.Save(&user)
|
||||
|
||||
AssertEqual(t, user.CreatedAt, curTime)
|
||||
AssertEqual(t, user.UpdatedAt, curTime)
|
||||
|
||||
var newUser User
|
||||
NEW.First(&newUser, user.ID)
|
||||
|
||||
AssertEqual(t, newUser.CreatedAt, curTime)
|
||||
AssertEqual(t, newUser.UpdatedAt, curTime)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue