mirror of https://github.com/go-gorm/gorm.git
make callback create works
This commit is contained in:
parent
261ece5696
commit
8dd7b4ed91
|
@ -1,15 +1,40 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
func BeforeCreate(scope *Scope) {
|
func BeforeCreate(scope *Scope) {
|
||||||
scope.CallMethod("BeforeSave")
|
scope.CallMethod("BeforeSave")
|
||||||
scope.CallMethod("BeforeCreate")
|
scope.CallMethod("BeforeCreate")
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveBeforeAssociations(scope *Scope) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func Create(scope *Scope) {
|
func Create(scope *Scope) {
|
||||||
|
defer scope.Trace(time.Now())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
|
// set create sql
|
||||||
|
var sqls, columns []string
|
||||||
|
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if field.IsBlank || len(field.SqlTag) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
columns = append(columns, scope.quote(field.DBName))
|
||||||
|
sqls = append(sqls, scope.AddToVars(field.Value))
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Raw(fmt.Sprintf(
|
||||||
|
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||||
|
scope.TableName(),
|
||||||
|
strings.Join(columns, ","),
|
||||||
|
strings.Join(sqls, ","),
|
||||||
|
scope.Dialect().ReturningStr(scope.PrimaryKey()),
|
||||||
|
))
|
||||||
|
|
||||||
|
// execute create sql
|
||||||
var id interface{}
|
var id interface{}
|
||||||
if scope.Dialect().SupportLastInsertId() {
|
if scope.Dialect().SupportLastInsertId() {
|
||||||
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||||
|
@ -20,7 +45,9 @@ func Create(scope *Scope) {
|
||||||
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
|
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.SetColumn(scope.PrimaryKey(), id)
|
if !scope.HasError() {
|
||||||
|
scope.SetColumn(scope.PrimaryKey(), id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,13 +56,12 @@ func AfterCreate(scope *Scope) {
|
||||||
scope.CallMethod("AfterSave")
|
scope.CallMethod("AfterSave")
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveAfterAssociations(scope *Scope) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
DefaultCallback.Create().Register("begin_transaction", BeginTransaction)
|
||||||
DefaultCallback.Create().Register("before_create", BeforeCreate)
|
DefaultCallback.Create().Register("before_create", BeforeCreate)
|
||||||
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
|
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
|
||||||
DefaultCallback.Create().Register("create", Create)
|
DefaultCallback.Create().Register("create", Create)
|
||||||
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
|
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
|
||||||
DefaultCallback.Create().Register("after_create", AfterCreate)
|
DefaultCallback.Create().Register("after_create", AfterCreate)
|
||||||
|
DefaultCallback.Create().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,22 +12,20 @@ func BeforeDelete(scope *Scope) {
|
||||||
func Delete(scope *Scope) {
|
func Delete(scope *Scope) {
|
||||||
defer scope.Trace(time.Now())
|
defer scope.Trace(time.Now())
|
||||||
|
|
||||||
if scope.HasError() {
|
if !scope.HasError() {
|
||||||
return
|
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
|
||||||
}
|
scope.Raw(
|
||||||
|
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
|
||||||
|
scope.TableName(),
|
||||||
|
scope.AddToVars(time.Now()),
|
||||||
|
scope.CombinedConditionSql(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql()))
|
||||||
|
}
|
||||||
|
|
||||||
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
|
scope.Exec()
|
||||||
scope.Raw(
|
|
||||||
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
|
|
||||||
scope.TableName(),
|
|
||||||
scope.AddToVars(time.Now()),
|
|
||||||
scope.CombinedConditionSql(),
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
scope.Exec()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AfterDelete(scope *Scope) {
|
func AfterDelete(scope *Scope) {
|
||||||
|
|
|
@ -7,3 +7,9 @@ func BeginTransaction(scope *Scope) {
|
||||||
func CommitOrRollbackTransaction(scope *Scope) {
|
func CommitOrRollbackTransaction(scope *Scope) {
|
||||||
scope.CommitOrRollback()
|
scope.CommitOrRollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveBeforeAssociations(scope *Scope) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveAfterAssociations(scope *Scope) {
|
||||||
|
}
|
||||||
|
|
15
do.go
15
do.go
|
@ -311,21 +311,6 @@ func (s *Do) update() *Do {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) delete() *Do {
|
|
||||||
s.model.callMethod("BeforeDelete")
|
|
||||||
|
|
||||||
if !s.db.hasError() {
|
|
||||||
if !s.search.unscope && s.model.hasColumn("DeletedAt") {
|
|
||||||
s.setSql(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql()))
|
|
||||||
} else {
|
|
||||||
s.setSql(fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql()))
|
|
||||||
}
|
|
||||||
s.exec()
|
|
||||||
s.model.callMethod("AfterDelete")
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Do) prepareQuerySql() {
|
func (s *Do) prepareQuerySql() {
|
||||||
if s.search.raw {
|
if s.search.raw {
|
||||||
s.setSql(strings.TrimLeft(s.combinedSql(), "WHERE "))
|
s.setSql(strings.TrimLeft(s.combinedSql(), "WHERE "))
|
||||||
|
|
26
field.go
26
field.go
|
@ -9,10 +9,18 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Field struct {
|
type Field struct {
|
||||||
Name string
|
Name string
|
||||||
Value interface{}
|
DBName string
|
||||||
model *Model
|
Value interface{}
|
||||||
|
IsBlank bool
|
||||||
|
IsIgnored bool
|
||||||
|
Tag string
|
||||||
|
AddationalTag string
|
||||||
|
Size int
|
||||||
|
SqlTag string
|
||||||
|
|
||||||
dbName string
|
dbName string
|
||||||
|
model *Model
|
||||||
isBlank bool
|
isBlank bool
|
||||||
ignoreField bool
|
ignoreField bool
|
||||||
isPrimaryKey bool
|
isPrimaryKey bool
|
||||||
|
@ -25,6 +33,16 @@ type Field struct {
|
||||||
structField reflect.StructField
|
structField reflect.StructField
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Field) IsScanner() bool {
|
||||||
|
_, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
|
||||||
|
return is_scanner
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Field) IsTime() bool {
|
||||||
|
_, is_time := f.Value.(time.Time)
|
||||||
|
return is_time
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Field) parseBlank() {
|
func (f *Field) parseBlank() {
|
||||||
f.isBlank = isBlank(f.reflectValue)
|
f.isBlank = isBlank(f.reflectValue)
|
||||||
}
|
}
|
||||||
|
@ -38,7 +56,7 @@ func (f *Field) parseIgnore() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Field) isScanner() bool {
|
func (f *Field) isScanner() bool {
|
||||||
_, is_scanner := reflect.New(f.reflectValue.Type()).Interface().(sql.Scanner)
|
_, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
|
||||||
return is_scanner
|
return is_scanner
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
7
main.go
7
main.go
|
@ -178,7 +178,12 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Save(value interface{}) *DB {
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
return s.clone().do(value).begin().save().commit_or_rollback().db
|
scope := s.clone().newScope(value)
|
||||||
|
if scope.PrimaryKeyZero() {
|
||||||
|
return scope.callCallbacks(s.parent.callback.creates).db
|
||||||
|
} else {
|
||||||
|
return s.clone().do(value).begin().save().commit_or_rollback().db
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Delete(value interface{}) *DB {
|
func (s *DB) Delete(value interface{}) *DB {
|
||||||
|
|
72
scope.go
72
scope.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jinzhu/gorm/dialect"
|
"github.com/jinzhu/gorm/dialect"
|
||||||
|
"go/ast"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -149,8 +150,77 @@ func (s *Scope) CombinedConditionSql() string {
|
||||||
return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql()
|
return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) SqlTagForField(field *Field) (tag string) {
|
||||||
|
value := field.Value
|
||||||
|
reflect_value := reflect.ValueOf(value)
|
||||||
|
|
||||||
|
if field.IsScanner() {
|
||||||
|
value = reflect_value.Field(0).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reflect_value.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if _, ok := value.([]byte); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if !field.IsTime() && !field.IsScanner() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tag = field.Tag; len(tag) == 0 && tag != "-" {
|
||||||
|
if field.isPrimaryKey {
|
||||||
|
tag = scope.Dialect().PrimaryKeyTag(value, field.Size)
|
||||||
|
} else {
|
||||||
|
tag = scope.Dialect().SqlTag(value, field.Size)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(field.AddationalTag) > 0 {
|
||||||
|
tag = tag + " " + field.AddationalTag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (scope *Scope) Fields() []*Field {
|
func (scope *Scope) Fields() []*Field {
|
||||||
return []*Field{}
|
indirect_value := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||||
|
fields := []*Field{}
|
||||||
|
|
||||||
|
if !indirect_value.IsValid() {
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := indirect_value.Type()
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field_struct := typ.Field(i)
|
||||||
|
if field_struct.Anonymous || !ast.IsExported(field_struct.Name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var field Field
|
||||||
|
field.Name = field_struct.Name
|
||||||
|
field.DBName = toSnake(field_struct.Name)
|
||||||
|
|
||||||
|
value := indirect_value.FieldByName(field_struct.Name)
|
||||||
|
field.Value = value.Interface()
|
||||||
|
field.IsBlank = isBlank(value)
|
||||||
|
|
||||||
|
tag, addational_tag, size := parseSqlTag(field_struct.Tag.Get(scope.db.parent.tagIdentifier))
|
||||||
|
field.Tag = tag
|
||||||
|
field.AddationalTag = addational_tag
|
||||||
|
field.Size = size
|
||||||
|
field.SqlTag = scope.SqlTagForField(&field)
|
||||||
|
|
||||||
|
if tag == "-" {
|
||||||
|
field.IsIgnored = true
|
||||||
|
}
|
||||||
|
|
||||||
|
field.parseAssociation()
|
||||||
|
fields = append(fields, &field)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) Raw(sql string) {
|
func (scope *Scope) Raw(sql string) {
|
||||||
|
|
Loading…
Reference in New Issue