forked from mirror/gorm
use callback update when save
This commit is contained in:
parent
3981baf65d
commit
23feade663
|
@ -11,13 +11,17 @@ func BeforeCreate(scope *Scope) {
|
|||
scope.CallMethod("BeforeCreate")
|
||||
}
|
||||
|
||||
func UpdateCreateTimeStamp(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.SetColumn("CreatedAt", time.Now())
|
||||
scope.SetColumn("UpdatedAt", time.Now())
|
||||
}
|
||||
}
|
||||
|
||||
func Create(scope *Scope) {
|
||||
defer scope.Trace(time.Now())
|
||||
|
||||
if !scope.HasError() {
|
||||
scope.SetColumn("CreatedAt", time.Now())
|
||||
scope.SetColumn("UpdatedAt", time.Now())
|
||||
|
||||
// set create sql
|
||||
var sqls, columns []string
|
||||
|
||||
|
@ -62,6 +66,7 @@ func init() {
|
|||
DefaultCallback.Create().Register("begin_transaction", BeginTransaction)
|
||||
DefaultCallback.Create().Register("before_create", BeforeCreate)
|
||||
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
|
||||
DefaultCallback.Create().Register("update_create_time_stamp", UpdateCreateTimeStamp)
|
||||
DefaultCallback.Create().Register("create", Create)
|
||||
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
|
||||
DefaultCallback.Create().Register("after_create", AfterCreate)
|
||||
|
|
|
@ -1,23 +1,38 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BeforeUpdate(scope *Scope) {
|
||||
scope.CallMethod("BeforeSave")
|
||||
scope.CallMethod("BeforeUpdate")
|
||||
}
|
||||
|
||||
func UpdateTimeStampWhenUpdate(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.SetColumn("UpdatedAt", time.Now())
|
||||
}
|
||||
}
|
||||
|
||||
func Update(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var id interface{}
|
||||
if scope.Dialect().SupportLastInsertId() {
|
||||
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
id, err = sql_result.LastInsertId()
|
||||
scope.Err(err)
|
||||
var sqls []string
|
||||
for _, field := range scope.Fields() {
|
||||
if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value)))
|
||||
}
|
||||
} else {
|
||||
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
|
||||
}
|
||||
|
||||
scope.SetColumn(scope.PrimaryKey(), id)
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
scope.TableName(),
|
||||
strings.Join(sqls, ", "),
|
||||
scope.CombinedConditionSql(),
|
||||
))
|
||||
scope.Exec()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -27,9 +42,12 @@ func AfterUpdate(scope *Scope) {
|
|||
}
|
||||
|
||||
func init() {
|
||||
DefaultCallback.Update().Register("begin_transaction", BeginTransaction)
|
||||
DefaultCallback.Update().Register("before_update", BeforeUpdate)
|
||||
DefaultCallback.Update().Register("save_before_associations", SaveBeforeAssociations)
|
||||
DefaultCallback.Update().Register("update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
|
||||
DefaultCallback.Update().Register("update", Update)
|
||||
DefaultCallback.Update().Register("save_after_associations", SaveAfterAssociations)
|
||||
DefaultCallback.Update().Register("after_update", AfterUpdate)
|
||||
DefaultCallback.Update().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
}
|
||||
|
|
|
@ -1318,7 +1318,7 @@ func TestRelated(t *testing.T) {
|
|||
var credit_card CreditCard
|
||||
var user3 User
|
||||
db.First(&credit_card, "number = ?", "1234567890")
|
||||
db.Debug().Model(&credit_card).Related(&user3)
|
||||
db.Model(&credit_card).Related(&user3)
|
||||
if user3.Id != user.Id || user3.Name != user.Name {
|
||||
t.Errorf("Should get user from credit card correctly")
|
||||
}
|
||||
|
|
9
main.go
9
main.go
|
@ -164,8 +164,9 @@ func (s *DB) Update(attrs ...interface{}) *DB {
|
|||
return s.Updates(toSearchableMap(attrs...), true)
|
||||
}
|
||||
|
||||
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
|
||||
return s.clone().do(s.Value).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback().db
|
||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||
return s.clone().do(s.Value).begin().updateAttrs(values, ignoreProtectedAttrs...).update().commit_or_rollback().db
|
||||
// return s.clone().NewScope(s.Value).callCallbacks(s.parent.callback.updates).db
|
||||
}
|
||||
|
||||
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||
|
@ -179,9 +180,9 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
|
|||
func (s *DB) Save(value interface{}) *DB {
|
||||
scope := s.clone().NewScope(value)
|
||||
if scope.PrimaryKeyZero() {
|
||||
return scope.callCallbacks(s.parent.callback.creates).db.do(value).db
|
||||
return scope.callCallbacks(s.parent.callback.creates).db
|
||||
} else {
|
||||
return s.clone().do(value).begin().save().commit_or_rollback().db
|
||||
return scope.callCallbacks(s.parent.callback.updates).db
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue