Add GormValuer interface support

This commit is contained in:
Jinzhu 2020-08-27 15:03:57 +08:00
parent 0d96f99499
commit ce8853e7a6
10 changed files with 108 additions and 47 deletions

View File

@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
* Context, Prepared Statment Mode, DryRun Mode * Context, Prepared Statment Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map * Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
* Composite Primary Key * Composite Primary Key
* Auto Migrations * Auto Migrations
* Logger * Logger

View File

@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(BeforeSaveInterface); ok {
called = true called = true
db.AddError(i.BeforeSave(tx)) db.AddError(i.BeforeSave(tx))
} }
} }
if db.Statement.Schema.BeforeCreate { if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok { if i, ok := value.(BeforeCreateInterface); ok {
called = true called = true
db.AddError(i.BeforeCreate(tx)) db.AddError(i.BeforeCreate(tx))
} }
@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {
called = true called = true
db.AddError(i.AfterSave(tx)) db.AddError(i.AfterSave(tx))
} }
} }
if db.Statement.Schema.AfterCreate { if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok { if i, ok := value.(AfterCreateInterface); ok {
called = true called = true
db.AddError(i.AfterCreate(tx)) db.AddError(i.AfterCreate(tx))
} }

View File

@ -11,7 +11,7 @@ import (
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.BeforeDeleteInterface); ok { if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx)) db.AddError(i.BeforeDelete(tx))
return true return true
} }
@ -75,7 +75,7 @@ func Delete(db *gorm.DB) {
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.AfterDeleteInterface); ok { if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx)) db.AddError(i.AfterDelete(tx))
return true return true
} }

39
callbacks/interfaces.go Normal file
View File

@ -0,0 +1,39 @@
package callbacks
import "gorm.io/gorm"
type BeforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}
type AfterCreateInterface interface {
AfterCreate(*gorm.DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*gorm.DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*gorm.DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type AfterSaveInterface interface {
AfterSave(*gorm.DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*gorm.DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}

View File

@ -214,7 +214,7 @@ func Preload(db *gorm.DB) {
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(gorm.AfterFindInterface); ok { if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx)) db.AddError(i.AfterFind(tx))
return true return true
} }

View File

@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(BeforeSaveInterface); ok {
called = true called = true
db.AddError(i.BeforeSave(tx)) db.AddError(i.BeforeSave(tx))
} }
} }
if db.Statement.Schema.BeforeUpdate { if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok { if i, ok := value.(BeforeUpdateInterface); ok {
called = true called = true
db.AddError(i.BeforeUpdate(tx)) db.AddError(i.BeforeUpdate(tx))
} }
@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {
called = true called = true
db.AddError(i.AfterSave(tx)) db.AddError(i.AfterSave(tx))
} }
} }
if db.Statement.Schema.AfterUpdate { if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok { if i, ok := value.(AfterUpdateInterface); ok {
called = true called = true
db.AddError(i.AfterUpdate(tx)) db.AddError(i.AfterUpdate(tx))
} }

View File

@ -53,38 +53,7 @@ type TxCommitter interface {
Rollback() error Rollback() error
} }
type BeforeCreateInterface interface { // Valuer gorm valuer interface
BeforeCreate(*DB) error type Valuer interface {
} GormValue(context.Context, *DB) clause.Expr
type AfterCreateInterface interface {
AfterCreate(*DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*DB) error
}
type AfterSaveInterface interface {
AfterSave(*DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*DB) error
}
type AfterFindInterface interface {
AfterFind(*DB) error
} }

View File

@ -1,6 +1,8 @@
package schema package schema
import "gorm.io/gorm/clause" import (
"gorm.io/gorm/clause"
)
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDataType() string GormDataType() string

View File

@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
stmt.Vars = append(stmt.Vars, v.Value) stmt.Vars = append(stmt.Vars, v.Value)
case clause.Column, clause.Table: case clause.Column, clause.Table:
stmt.QuoteTo(writer, v) stmt.QuoteTo(writer, v)
case Valuer:
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
case clause.Expr: case clause.Expr:
var varStr strings.Builder var varStr strings.Builder
var sql = v.SQL var sql = v.SQL

View File

@ -1,16 +1,20 @@
package tests_test package tests_test
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) {
type NullString struct { type NullString struct {
sql.NullString sql.NullString
} }
type Point struct {
X, Y int
}
func (point *Point) Scan(v interface{}) error {
return nil
}
func (point Point) GormDataType() string {
return "geo"
}
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
return clause.Expr{
SQL: "ST_PointFromText(?)",
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
}
}
func TestGORMValuer(t *testing.T) {
type UserWithPoint struct {
Name string
Point Point
}
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Create(&UserWithPoint{
Name: "jinzhu",
Point: Point{X: 100, Y: 100},
}).Statement
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
}
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
}
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
}
}