Add GormValuer interface support

This commit is contained in:
Jinzhu 2020-08-27 15:03:57 +08:00
parent 0d96f99499
commit ce8853e7a6
10 changed files with 108 additions and 47 deletions

View File

@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
* Context, Prepared Statment Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
* Composite Primary Key
* Auto Migrations
* Logger

View File

@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok {
if i, ok := value.(BeforeCreateInterface); ok {
called = true
db.AddError(i.BeforeCreate(tx))
}
@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}

View File

@ -11,7 +11,7 @@ import (
func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx))
return true
}
@ -75,7 +75,7 @@ func Delete(db *gorm.DB) {
func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx))
return true
}

39
callbacks/interfaces.go Normal file
View File

@ -0,0 +1,39 @@
package callbacks
import "gorm.io/gorm"
type BeforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}
type AfterCreateInterface interface {
AfterCreate(*gorm.DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*gorm.DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*gorm.DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type AfterSaveInterface interface {
AfterSave(*gorm.DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*gorm.DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}

View File

@ -214,7 +214,7 @@ func Preload(db *gorm.DB) {
func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.AfterFindInterface); ok {
if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx))
return true
}

View File

@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
if i, ok := value.(BeforeUpdateInterface); ok {
called = true
db.AddError(i.BeforeUpdate(tx))
}
@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}

View File

@ -53,38 +53,7 @@ type TxCommitter interface {
Rollback() error
}
type BeforeCreateInterface interface {
BeforeCreate(*DB) error
}
type AfterCreateInterface interface {
AfterCreate(*DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*DB) error
}
type AfterSaveInterface interface {
AfterSave(*DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*DB) error
}
type AfterFindInterface interface {
AfterFind(*DB) error
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}

View File

@ -1,6 +1,8 @@
package schema
import "gorm.io/gorm/clause"
import (
"gorm.io/gorm/clause"
)
type GormDataTypeInterface interface {
GormDataType() string

View File

@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
stmt.Vars = append(stmt.Vars, v.Value)
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case Valuer:
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
case clause.Expr:
var varStr strings.Builder
var sql = v.SQL

View File

@ -1,16 +1,20 @@
package tests_test
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) {
type NullString struct {
sql.NullString
}
type Point struct {
X, Y int
}
func (point *Point) Scan(v interface{}) error {
return nil
}
func (point Point) GormDataType() string {
return "geo"
}
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
return clause.Expr{
SQL: "ST_PointFromText(?)",
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
}
}
func TestGORMValuer(t *testing.T) {
type UserWithPoint struct {
Name string
Point Point
}
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Create(&UserWithPoint{
Name: "jinzhu",
Point: Point{X: 100, Y: 100},
}).Statement
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
}
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
}
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
}
}