diff --git a/README.md b/README.md index c727e2cf..9c0aded0 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statment Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map -* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key * Auto Migrations * Logger diff --git a/callbacks/create.go b/callbacks/create.go index 7a32ed5c..cc7e2671 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { - if i, ok := value.(gorm.BeforeCreateInterface); ok { + if i, ok := value.(BeforeCreateInterface); ok { called = true db.AddError(i.BeforeCreate(tx)) } @@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { - if i, ok := value.(gorm.AfterCreateInterface); ok { + if i, ok := value.(AfterCreateInterface); ok { called = true db.AddError(i.AfterCreate(tx)) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 76b78fb4..e95117a1 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -11,7 +11,7 @@ import ( func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.BeforeDeleteInterface); ok { + if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) return true } @@ -75,7 +75,7 @@ func Delete(db *gorm.DB) { func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterDeleteInterface); ok { + if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) return true } diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go new file mode 100644 index 00000000..2302470f --- /dev/null +++ b/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/callbacks/query.go b/callbacks/query.go index f6cb32d5..0703b92e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -214,7 +214,7 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterFindInterface); ok { + if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) return true } diff --git a/callbacks/update.go b/callbacks/update.go index bd8a4150..73c062b4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { - if i, ok := value.(gorm.BeforeUpdateInterface); ok { + if i, ok := value.(BeforeUpdateInterface); ok { called = true db.AddError(i.BeforeUpdate(tx)) } @@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { - if i, ok := value.(gorm.AfterUpdateInterface); ok { + if i, ok := value.(AfterUpdateInterface); ok { called = true db.AddError(i.AfterUpdate(tx)) } diff --git a/interfaces.go b/interfaces.go index b2ce59b3..e933952b 100644 --- a/interfaces.go +++ b/interfaces.go @@ -53,38 +53,7 @@ type TxCommitter interface { Rollback() error } -type BeforeCreateInterface interface { - BeforeCreate(*DB) error -} - -type AfterCreateInterface interface { - AfterCreate(*DB) error -} - -type BeforeUpdateInterface interface { - BeforeUpdate(*DB) error -} - -type AfterUpdateInterface interface { - AfterUpdate(*DB) error -} - -type BeforeSaveInterface interface { - BeforeSave(*DB) error -} - -type AfterSaveInterface interface { - AfterSave(*DB) error -} - -type BeforeDeleteInterface interface { - BeforeDelete(*DB) error -} - -type AfterDeleteInterface interface { - AfterDelete(*DB) error -} - -type AfterFindInterface interface { - AfterFind(*DB) error +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr } diff --git a/schema/interfaces.go b/schema/interfaces.go index e8e51e4c..98abffbd 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -1,6 +1,8 @@ package schema -import "gorm.io/gorm/clause" +import ( + "gorm.io/gorm/clause" +) type GormDataTypeInterface interface { GormDataType() string diff --git a/statement.go b/statement.go index 95d23fa5..fba1991d 100644 --- a/statement.go +++ b/statement.go @@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) + case Valuer: + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) case clause.Expr: var varStr strings.Builder var sql = v.SQL diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ce8a2b50..ec16ccf6 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -1,16 +1,20 @@ package tests_test import ( + "context" "database/sql" "database/sql/driver" "encoding/json" "errors" + "fmt" "reflect" + "regexp" "strconv" "testing" "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) { type NullString struct { sql.NullString } + +type Point struct { + X, Y int +} + +func (point *Point) Scan(v interface{}) error { + return nil +} + +func (point Point) GormDataType() string { + return "geo" +} + +func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + return clause.Expr{ + SQL: "ST_PointFromText(?)", + Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, + } +} + +func TestGORMValuer(t *testing.T) { + type UserWithPoint struct { + Name string + Point Point + } + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } +}