use callback update when save

This commit is contained in:
Jinzhu 2014-01-27 11:06:13 +08:00
parent 3981baf65d
commit 23feade663
5 changed files with 41 additions and 16 deletions

View File

@ -11,13 +11,17 @@ func BeforeCreate(scope *Scope) {
scope.CallMethod("BeforeCreate")
}
func UpdateCreateTimeStamp(scope *Scope) {
if !scope.HasError() {
scope.SetColumn("CreatedAt", time.Now())
scope.SetColumn("UpdatedAt", time.Now())
}
}
func Create(scope *Scope) {
defer scope.Trace(time.Now())
if !scope.HasError() {
scope.SetColumn("CreatedAt", time.Now())
scope.SetColumn("UpdatedAt", time.Now())
// set create sql
var sqls, columns []string
@ -62,6 +66,7 @@ func init() {
DefaultCallback.Create().Register("begin_transaction", BeginTransaction)
DefaultCallback.Create().Register("before_create", BeforeCreate)
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("update_create_time_stamp", UpdateCreateTimeStamp)
DefaultCallback.Create().Register("create", Create)
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("after_create", AfterCreate)

View File

@ -1,23 +1,38 @@
package gorm
import (
"fmt"
"strings"
"time"
)
func BeforeUpdate(scope *Scope) {
scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeUpdate")
}
func UpdateTimeStampWhenUpdate(scope *Scope) {
if !scope.HasError() {
scope.SetColumn("UpdatedAt", time.Now())
}
}
func Update(scope *Scope) {
if !scope.HasError() {
var id interface{}
if scope.Dialect().SupportLastInsertId() {
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err = sql_result.LastInsertId()
scope.Err(err)
var sqls []string
for _, field := range scope.Fields() {
if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value)))
}
} else {
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
}
scope.SetColumn(scope.PrimaryKey(), id)
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
scope.TableName(),
strings.Join(sqls, ", "),
scope.CombinedConditionSql(),
))
scope.Exec()
}
}
@ -27,9 +42,12 @@ func AfterUpdate(scope *Scope) {
}
func init() {
DefaultCallback.Update().Register("begin_transaction", BeginTransaction)
DefaultCallback.Update().Register("before_update", BeforeUpdate)
DefaultCallback.Update().Register("save_before_associations", SaveBeforeAssociations)
DefaultCallback.Update().Register("update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
DefaultCallback.Update().Register("update", Update)
DefaultCallback.Update().Register("save_after_associations", SaveAfterAssociations)
DefaultCallback.Update().Register("after_update", AfterUpdate)
DefaultCallback.Update().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction)
}

View File

@ -1318,7 +1318,7 @@ func TestRelated(t *testing.T) {
var credit_card CreditCard
var user3 User
db.First(&credit_card, "number = ?", "1234567890")
db.Debug().Model(&credit_card).Related(&user3)
db.Model(&credit_card).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly")
}

View File

@ -164,8 +164,9 @@ func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true)
}
func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB {
return s.clone().do(s.Value).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback().db
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.clone().do(s.Value).begin().updateAttrs(values, ignoreProtectedAttrs...).update().commit_or_rollback().db
// return s.clone().NewScope(s.Value).callCallbacks(s.parent.callback.updates).db
}
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
@ -179,9 +180,9 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) *
func (s *DB) Save(value interface{}) *DB {
scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db.do(value).db
return scope.callCallbacks(s.parent.callback.creates).db
} else {
return s.clone().do(value).begin().save().commit_or_rollback().db
return scope.callCallbacks(s.parent.callback.updates).db
}
}

View File

@ -22,6 +22,7 @@ type Scope struct {
}
func (db *DB) NewScope(value interface{}) *Scope {
db.Value = value
return &Scope{db: db, Search: db.search, Value: value}
}