mirror of https://github.com/go-gorm/gorm.git
make callback create works
This commit is contained in:
parent
261ece5696
commit
8dd7b4ed91
|
@ -1,15 +1,40 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BeforeCreate(scope *Scope) {
|
||||
scope.CallMethod("BeforeSave")
|
||||
scope.CallMethod("BeforeCreate")
|
||||
}
|
||||
|
||||
func SaveBeforeAssociations(scope *Scope) {
|
||||
func Create(scope *Scope) {
|
||||
defer scope.Trace(time.Now())
|
||||
|
||||
if !scope.HasError() {
|
||||
// set create sql
|
||||
var sqls, columns []string
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsBlank || len(field.SqlTag) == 0 {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, scope.quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Value))
|
||||
}
|
||||
|
||||
func Create(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) VALUES (%v) %v",
|
||||
scope.TableName(),
|
||||
strings.Join(columns, ","),
|
||||
strings.Join(sqls, ","),
|
||||
scope.Dialect().ReturningStr(scope.PrimaryKey()),
|
||||
))
|
||||
|
||||
// execute create sql
|
||||
var id interface{}
|
||||
if scope.Dialect().SupportLastInsertId() {
|
||||
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
|
@ -20,22 +45,23 @@ func Create(scope *Scope) {
|
|||
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
|
||||
}
|
||||
|
||||
if !scope.HasError() {
|
||||
scope.SetColumn(scope.PrimaryKey(), id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterCreate(scope *Scope) {
|
||||
scope.CallMethod("AfterCreate")
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(scope *Scope) {
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultCallback.Create().Register("begin_transaction", BeginTransaction)
|
||||
DefaultCallback.Create().Register("before_create", BeforeCreate)
|
||||
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
|
||||
DefaultCallback.Create().Register("create", Create)
|
||||
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
|
||||
DefaultCallback.Create().Register("after_create", AfterCreate)
|
||||
DefaultCallback.Create().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
}
|
||||
|
|
|
@ -12,10 +12,7 @@ func BeforeDelete(scope *Scope) {
|
|||
func Delete(scope *Scope) {
|
||||
defer scope.Trace(time.Now())
|
||||
|
||||
if scope.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
if !scope.HasError() {
|
||||
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
|
||||
scope.Raw(
|
||||
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
|
||||
|
@ -29,6 +26,7 @@ func Delete(scope *Scope) {
|
|||
|
||||
scope.Exec()
|
||||
}
|
||||
}
|
||||
|
||||
func AfterDelete(scope *Scope) {
|
||||
scope.CallMethod("AfterDelete")
|
||||
|
|
|
@ -7,3 +7,9 @@ func BeginTransaction(scope *Scope) {
|
|||
func CommitOrRollbackTransaction(scope *Scope) {
|
||||
scope.CommitOrRollback()
|
||||
}
|
||||
|
||||
func SaveBeforeAssociations(scope *Scope) {
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(scope *Scope) {
|
||||
}
|
||||
|
|
15
do.go
15
do.go
|
@ -311,21 +311,6 @@ func (s *Do) update() *Do {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *Do) delete() *Do {
|
||||
s.model.callMethod("BeforeDelete")
|
||||
|
||||
if !s.db.hasError() {
|
||||
if !s.search.unscope && s.model.hasColumn("DeletedAt") {
|
||||
s.setSql(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.table(), s.addToVars(time.Now()), s.combinedSql()))
|
||||
} else {
|
||||
s.setSql(fmt.Sprintf("DELETE FROM %v %v", s.table(), s.combinedSql()))
|
||||
}
|
||||
s.exec()
|
||||
s.model.callMethod("AfterDelete")
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) prepareQuerySql() {
|
||||
if s.search.raw {
|
||||
s.setSql(strings.TrimLeft(s.combinedSql(), "WHERE "))
|
||||
|
|
22
field.go
22
field.go
|
@ -10,9 +10,17 @@ import (
|
|||
|
||||
type Field struct {
|
||||
Name string
|
||||
DBName string
|
||||
Value interface{}
|
||||
model *Model
|
||||
IsBlank bool
|
||||
IsIgnored bool
|
||||
Tag string
|
||||
AddationalTag string
|
||||
Size int
|
||||
SqlTag string
|
||||
|
||||
dbName string
|
||||
model *Model
|
||||
isBlank bool
|
||||
ignoreField bool
|
||||
isPrimaryKey bool
|
||||
|
@ -25,6 +33,16 @@ type Field struct {
|
|||
structField reflect.StructField
|
||||
}
|
||||
|
||||
func (f *Field) IsScanner() bool {
|
||||
_, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
|
||||
return is_scanner
|
||||
}
|
||||
|
||||
func (f *Field) IsTime() bool {
|
||||
_, is_time := f.Value.(time.Time)
|
||||
return is_time
|
||||
}
|
||||
|
||||
func (f *Field) parseBlank() {
|
||||
f.isBlank = isBlank(f.reflectValue)
|
||||
}
|
||||
|
@ -38,7 +56,7 @@ func (f *Field) parseIgnore() {
|
|||
}
|
||||
|
||||
func (f *Field) isScanner() bool {
|
||||
_, is_scanner := reflect.New(f.reflectValue.Type()).Interface().(sql.Scanner)
|
||||
_, is_scanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
|
||||
return is_scanner
|
||||
}
|
||||
|
||||
|
|
5
main.go
5
main.go
|
@ -178,8 +178,13 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
|
|||
}
|
||||
|
||||
func (s *DB) Save(value interface{}) *DB {
|
||||
scope := s.clone().newScope(value)
|
||||
if scope.PrimaryKeyZero() {
|
||||
return scope.callCallbacks(s.parent.callback.creates).db
|
||||
} else {
|
||||
return s.clone().do(value).begin().save().commit_or_rollback().db
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DB) Delete(value interface{}) *DB {
|
||||
return s.clone().newScope(value).callCallbacks(s.parent.callback.deletes).db
|
||||
|
|
72
scope.go
72
scope.go
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"github.com/jinzhu/gorm/dialect"
|
||||
"go/ast"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -149,8 +150,77 @@ func (s *Scope) CombinedConditionSql() string {
|
|||
return s.joinsSql() + s.whereSql() + s.groupSql() + s.havingSql() + s.orderSql() + s.limitSql() + s.offsetSql()
|
||||
}
|
||||
|
||||
func (scope *Scope) SqlTagForField(field *Field) (tag string) {
|
||||
value := field.Value
|
||||
reflect_value := reflect.ValueOf(value)
|
||||
|
||||
if field.IsScanner() {
|
||||
value = reflect_value.Field(0).Interface()
|
||||
}
|
||||
|
||||
switch reflect_value.Kind() {
|
||||
case reflect.Slice:
|
||||
if _, ok := value.([]byte); !ok {
|
||||
return
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !field.IsTime() && !field.IsScanner() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if tag = field.Tag; len(tag) == 0 && tag != "-" {
|
||||
if field.isPrimaryKey {
|
||||
tag = scope.Dialect().PrimaryKeyTag(value, field.Size)
|
||||
} else {
|
||||
tag = scope.Dialect().SqlTag(value, field.Size)
|
||||
}
|
||||
|
||||
if len(field.AddationalTag) > 0 {
|
||||
tag = tag + " " + field.AddationalTag
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) Fields() []*Field {
|
||||
return []*Field{}
|
||||
indirect_value := reflect.Indirect(reflect.ValueOf(scope.Value))
|
||||
fields := []*Field{}
|
||||
|
||||
if !indirect_value.IsValid() {
|
||||
return fields
|
||||
}
|
||||
|
||||
typ := indirect_value.Type()
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field_struct := typ.Field(i)
|
||||
if field_struct.Anonymous || !ast.IsExported(field_struct.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
var field Field
|
||||
field.Name = field_struct.Name
|
||||
field.DBName = toSnake(field_struct.Name)
|
||||
|
||||
value := indirect_value.FieldByName(field_struct.Name)
|
||||
field.Value = value.Interface()
|
||||
field.IsBlank = isBlank(value)
|
||||
|
||||
tag, addational_tag, size := parseSqlTag(field_struct.Tag.Get(scope.db.parent.tagIdentifier))
|
||||
field.Tag = tag
|
||||
field.AddationalTag = addational_tag
|
||||
field.Size = size
|
||||
field.SqlTag = scope.SqlTagForField(&field)
|
||||
|
||||
if tag == "-" {
|
||||
field.IsIgnored = true
|
||||
}
|
||||
|
||||
field.parseAssociation()
|
||||
fields = append(fields, &field)
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func (scope *Scope) Raw(sql string) {
|
||||
|
|
Loading…
Reference in New Issue