diff --git a/callback_create.go b/callback_create.go index fec5fe43..5ae83d63 100644 --- a/callback_create.go +++ b/callback_create.go @@ -11,13 +11,17 @@ func BeforeCreate(scope *Scope) { scope.CallMethod("BeforeCreate") } +func UpdateCreateTimeStamp(scope *Scope) { + if !scope.HasError() { + scope.SetColumn("CreatedAt", time.Now()) + scope.SetColumn("UpdatedAt", time.Now()) + } +} + func Create(scope *Scope) { defer scope.Trace(time.Now()) if !scope.HasError() { - scope.SetColumn("CreatedAt", time.Now()) - scope.SetColumn("UpdatedAt", time.Now()) - // set create sql var sqls, columns []string @@ -62,6 +66,7 @@ func init() { DefaultCallback.Create().Register("begin_transaction", BeginTransaction) DefaultCallback.Create().Register("before_create", BeforeCreate) DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations) + DefaultCallback.Create().Register("update_create_time_stamp", UpdateCreateTimeStamp) DefaultCallback.Create().Register("create", Create) DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations) DefaultCallback.Create().Register("after_create", AfterCreate) diff --git a/callback_update.go b/callback_update.go index 1cfe6d8d..e5eb9382 100644 --- a/callback_update.go +++ b/callback_update.go @@ -1,23 +1,38 @@ package gorm +import ( + "fmt" + "strings" + "time" +) + func BeforeUpdate(scope *Scope) { scope.CallMethod("BeforeSave") scope.CallMethod("BeforeUpdate") } +func UpdateTimeStampWhenUpdate(scope *Scope) { + if !scope.HasError() { + scope.SetColumn("UpdatedAt", time.Now()) + } +} + func Update(scope *Scope) { if !scope.HasError() { - var id interface{} - if scope.Dialect().SupportLastInsertId() { - if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - id, err = sql_result.LastInsertId() - scope.Err(err) + var sqls []string + for _, field := range scope.Fields() { + if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value))) } - } else { - scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) } - scope.SetColumn(scope.PrimaryKey(), id) + scope.Raw(fmt.Sprintf( + "UPDATE %v SET %v %v", + scope.TableName(), + strings.Join(sqls, ", "), + scope.CombinedConditionSql(), + )) + scope.Exec() } } @@ -27,9 +42,12 @@ func AfterUpdate(scope *Scope) { } func init() { + DefaultCallback.Update().Register("begin_transaction", BeginTransaction) DefaultCallback.Update().Register("before_update", BeforeUpdate) DefaultCallback.Update().Register("save_before_associations", SaveBeforeAssociations) + DefaultCallback.Update().Register("update_time_stamp_when_update", UpdateTimeStampWhenUpdate) DefaultCallback.Update().Register("update", Update) DefaultCallback.Update().Register("save_after_associations", SaveAfterAssociations) DefaultCallback.Update().Register("after_update", AfterUpdate) + DefaultCallback.Update().Register("commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/gorm_test.go b/gorm_test.go index 45dc9b8f..78bf1d3b 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1318,7 +1318,7 @@ func TestRelated(t *testing.T) { var credit_card CreditCard var user3 User db.First(&credit_card, "number = ?", "1234567890") - db.Debug().Model(&credit_card).Related(&user3) + db.Model(&credit_card).Related(&user3) if user3.Id != user.Id || user3.Name != user.Name { t.Errorf("Should get user from credit card correctly") } diff --git a/main.go b/main.go index 717e535c..c3c44601 100644 --- a/main.go +++ b/main.go @@ -164,8 +164,9 @@ func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) } -func (s *DB) Updates(values interface{}, ignore_protected_attrs ...bool) *DB { - return s.clone().do(s.Value).begin().updateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback().db +func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { + return s.clone().do(s.Value).begin().updateAttrs(values, ignoreProtectedAttrs...).update().commit_or_rollback().db + // return s.clone().NewScope(s.Value).callCallbacks(s.parent.callback.updates).db } func (s *DB) UpdateColumn(attrs ...interface{}) *DB { @@ -179,9 +180,9 @@ func (s *DB) UpdateColumns(values interface{}, ignore_protected_attrs ...bool) * func (s *DB) Save(value interface{}) *DB { scope := s.clone().NewScope(value) if scope.PrimaryKeyZero() { - return scope.callCallbacks(s.parent.callback.creates).db.do(value).db + return scope.callCallbacks(s.parent.callback.creates).db } else { - return s.clone().do(value).begin().save().commit_or_rollback().db + return scope.callCallbacks(s.parent.callback.updates).db } } diff --git a/scope.go b/scope.go index 4e5eca3e..3ea4646c 100644 --- a/scope.go +++ b/scope.go @@ -22,6 +22,7 @@ type Scope struct { } func (db *DB) NewScope(value interface{}) *Scope { + db.Value = value return &Scope{db: db, Search: db.search, Value: value} }