forked from mirror/gorm
Add GormValuer interface support
This commit is contained in:
parent
0d96f99499
commit
ce8853e7a6
|
@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||
* Context, Prepared Statment Mode, DryRun Mode
|
||||
* Batch Insert, FindInBatches, Find To Map
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
|
||||
* Composite Primary Key
|
||||
* Auto Migrations
|
||||
* Logger
|
||||
|
|
|
@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) {
|
|||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(gorm.BeforeCreateInterface); ok {
|
||||
if i, ok := value.(BeforeCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeCreate(tx))
|
||||
}
|
||||
|
@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) {
|
|||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(gorm.AfterCreateInterface); ok {
|
||||
if i, ok := value.(AfterCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterCreate(tx))
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
|
||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ func Delete(db *gorm.DB) {
|
|||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.AfterDeleteInterface); ok {
|
||||
if i, ok := value.(AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
package callbacks
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*gorm.DB) error
|
||||
}
|
|
@ -214,7 +214,7 @@ func Preload(db *gorm.DB) {
|
|||
func AfterQuery(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.AfterFindInterface); ok {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
db.AddError(i.AfterFind(tx))
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) {
|
|||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeUpdate {
|
||||
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
|
||||
if i, ok := value.(BeforeUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeUpdate(tx))
|
||||
}
|
||||
|
@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) {
|
|||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterUpdate {
|
||||
if i, ok := value.(gorm.AfterUpdateInterface); ok {
|
||||
if i, ok := value.(AfterUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterUpdate(tx))
|
||||
}
|
||||
|
|
|
@ -53,38 +53,7 @@ type TxCommitter interface {
|
|||
Rollback() error
|
||||
}
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*DB) error
|
||||
// Valuer gorm valuer interface
|
||||
type Valuer interface {
|
||||
GormValue(context.Context, *DB) clause.Expr
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package schema
|
||||
|
||||
import "gorm.io/gorm/clause"
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type GormDataTypeInterface interface {
|
||||
GormDataType() string
|
||||
|
|
|
@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||
stmt.Vars = append(stmt.Vars, v.Value)
|
||||
case clause.Column, clause.Table:
|
||||
stmt.QuoteTo(writer, v)
|
||||
case Valuer:
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
case clause.Expr:
|
||||
var varStr strings.Builder
|
||||
var sql = v.SQL
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
|
@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) {
|
|||
type NullString struct {
|
||||
sql.NullString
|
||||
}
|
||||
|
||||
type Point struct {
|
||||
X, Y int
|
||||
}
|
||||
|
||||
func (point *Point) Scan(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (point Point) GormDataType() string {
|
||||
return "geo"
|
||||
}
|
||||
|
||||
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
||||
return clause.Expr{
|
||||
SQL: "ST_PointFromText(?)",
|
||||
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
|
||||
}
|
||||
}
|
||||
|
||||
func TestGORMValuer(t *testing.T) {
|
||||
type UserWithPoint struct {
|
||||
Name string
|
||||
Point Point
|
||||
}
|
||||
|
||||
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
|
||||
stmt := dryRunDB.Create(&UserWithPoint{
|
||||
Name: "jinzhu",
|
||||
Point: Point{X: 100, Y: 100},
|
||||
}).Statement
|
||||
|
||||
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
|
||||
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue