forked from mirror/gorm
Add GormValuer interface support
This commit is contained in:
parent
0d96f99499
commit
ce8853e7a6
|
@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||||
* Context, Prepared Statment Mode, DryRun Mode
|
* Context, Prepared Statment Mode, DryRun Mode
|
||||||
* Batch Insert, FindInBatches, Find To Map
|
* Batch Insert, FindInBatches, Find To Map
|
||||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg
|
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
|
||||||
* Composite Primary Key
|
* Composite Primary Key
|
||||||
* Auto Migrations
|
* Auto Migrations
|
||||||
* Logger
|
* Logger
|
||||||
|
|
|
@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.BeforeSave {
|
if db.Statement.Schema.BeforeSave {
|
||||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
if i, ok := value.(BeforeSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeSave(tx))
|
db.AddError(i.BeforeSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.BeforeCreate {
|
if db.Statement.Schema.BeforeCreate {
|
||||||
if i, ok := value.(gorm.BeforeCreateInterface); ok {
|
if i, ok := value.(BeforeCreateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeCreate(tx))
|
db.AddError(i.BeforeCreate(tx))
|
||||||
}
|
}
|
||||||
|
@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.AfterSave {
|
if db.Statement.Schema.AfterSave {
|
||||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterSave(tx))
|
db.AddError(i.AfterSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.AfterCreate {
|
if db.Statement.Schema.AfterCreate {
|
||||||
if i, ok := value.(gorm.AfterCreateInterface); ok {
|
if i, ok := value.(AfterCreateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterCreate(tx))
|
db.AddError(i.AfterCreate(tx))
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
func BeforeDelete(db *gorm.DB) {
|
func BeforeDelete(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
|
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||||
db.AddError(i.BeforeDelete(tx))
|
db.AddError(i.BeforeDelete(tx))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -75,7 +75,7 @@ func Delete(db *gorm.DB) {
|
||||||
func AfterDelete(db *gorm.DB) {
|
func AfterDelete(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(gorm.AfterDeleteInterface); ok {
|
if i, ok := value.(AfterDeleteInterface); ok {
|
||||||
db.AddError(i.AfterDelete(tx))
|
db.AddError(i.AfterDelete(tx))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
package callbacks
|
||||||
|
|
||||||
|
import "gorm.io/gorm"
|
||||||
|
|
||||||
|
type BeforeCreateInterface interface {
|
||||||
|
BeforeCreate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterCreateInterface interface {
|
||||||
|
AfterCreate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type BeforeUpdateInterface interface {
|
||||||
|
BeforeUpdate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterUpdateInterface interface {
|
||||||
|
AfterUpdate(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type BeforeSaveInterface interface {
|
||||||
|
BeforeSave(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterSaveInterface interface {
|
||||||
|
AfterSave(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type BeforeDeleteInterface interface {
|
||||||
|
BeforeDelete(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterDeleteInterface interface {
|
||||||
|
AfterDelete(*gorm.DB) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AfterFindInterface interface {
|
||||||
|
AfterFind(*gorm.DB) error
|
||||||
|
}
|
|
@ -214,7 +214,7 @@ func Preload(db *gorm.DB) {
|
||||||
func AfterQuery(db *gorm.DB) {
|
func AfterQuery(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
if i, ok := value.(gorm.AfterFindInterface); ok {
|
if i, ok := value.(AfterFindInterface); ok {
|
||||||
db.AddError(i.AfterFind(tx))
|
db.AddError(i.AfterFind(tx))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.BeforeSave {
|
if db.Statement.Schema.BeforeSave {
|
||||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
if i, ok := value.(BeforeSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeSave(tx))
|
db.AddError(i.BeforeSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.BeforeUpdate {
|
if db.Statement.Schema.BeforeUpdate {
|
||||||
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
|
if i, ok := value.(BeforeUpdateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.BeforeUpdate(tx))
|
db.AddError(i.BeforeUpdate(tx))
|
||||||
}
|
}
|
||||||
|
@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.AfterSave {
|
if db.Statement.Schema.AfterSave {
|
||||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterSave(tx))
|
db.AddError(i.AfterSave(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.AfterUpdate {
|
if db.Statement.Schema.AfterUpdate {
|
||||||
if i, ok := value.(gorm.AfterUpdateInterface); ok {
|
if i, ok := value.(AfterUpdateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterUpdate(tx))
|
db.AddError(i.AfterUpdate(tx))
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,38 +53,7 @@ type TxCommitter interface {
|
||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type BeforeCreateInterface interface {
|
// Valuer gorm valuer interface
|
||||||
BeforeCreate(*DB) error
|
type Valuer interface {
|
||||||
}
|
GormValue(context.Context, *DB) clause.Expr
|
||||||
|
|
||||||
type AfterCreateInterface interface {
|
|
||||||
AfterCreate(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type BeforeUpdateInterface interface {
|
|
||||||
BeforeUpdate(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterUpdateInterface interface {
|
|
||||||
AfterUpdate(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type BeforeSaveInterface interface {
|
|
||||||
BeforeSave(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterSaveInterface interface {
|
|
||||||
AfterSave(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type BeforeDeleteInterface interface {
|
|
||||||
BeforeDelete(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterDeleteInterface interface {
|
|
||||||
AfterDelete(*DB) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type AfterFindInterface interface {
|
|
||||||
AfterFind(*DB) error
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package schema
|
package schema
|
||||||
|
|
||||||
import "gorm.io/gorm/clause"
|
import (
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
type GormDataTypeInterface interface {
|
type GormDataTypeInterface interface {
|
||||||
GormDataType() string
|
GormDataType() string
|
||||||
|
|
|
@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||||
stmt.Vars = append(stmt.Vars, v.Value)
|
stmt.Vars = append(stmt.Vars, v.Value)
|
||||||
case clause.Column, clause.Table:
|
case clause.Column, clause.Table:
|
||||||
stmt.QuoteTo(writer, v)
|
stmt.QuoteTo(writer, v)
|
||||||
|
case Valuer:
|
||||||
|
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||||
case clause.Expr:
|
case clause.Expr:
|
||||||
var varStr strings.Builder
|
var varStr strings.Builder
|
||||||
var sql = v.SQL
|
var sql = v.SQL
|
||||||
|
|
|
@ -1,16 +1,20 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) {
|
||||||
type NullString struct {
|
type NullString struct {
|
||||||
sql.NullString
|
sql.NullString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Point struct {
|
||||||
|
X, Y int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (point *Point) Scan(v interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (point Point) GormDataType() string {
|
||||||
|
return "geo"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
||||||
|
return clause.Expr{
|
||||||
|
SQL: "ST_PointFromText(?)",
|
||||||
|
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGORMValuer(t *testing.T) {
|
||||||
|
type UserWithPoint struct {
|
||||||
|
Name string
|
||||||
|
Point Point
|
||||||
|
}
|
||||||
|
|
||||||
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
|
||||||
|
stmt := dryRunDB.Create(&UserWithPoint{
|
||||||
|
Name: "jinzhu",
|
||||||
|
Point: Point{X: 100, Y: 100},
|
||||||
|
}).Statement
|
||||||
|
|
||||||
|
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue