diff --git a/callback_create.go b/callback_create.go index 38706cf3..b17f596a 100644 --- a/callback_create.go +++ b/callback_create.go @@ -51,9 +51,13 @@ func Create(scope *Scope) { // execute create sql var id interface{} if scope.Dialect().SupportLastInsertId() { - if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - id, err = sql_result.LastInsertId() - scope.Err(err) + if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + id, err = result.LastInsertId() + if scope.Err(err) == nil { + if count, err := result.RowsAffected(); err == nil { + scope.db.RowsAffected = count + } + } } } else { scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) diff --git a/main.go b/main.go index d655855c..e3fd9fe2 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( type DB struct { Value interface{} Error error + RowsAffected int64 callback *callback db sqlCommon parent *DB diff --git a/main_test.go b/main_test.go index 7310dd40..2bdb8471 100644 --- a/main_test.go +++ b/main_test.go @@ -329,24 +329,19 @@ func TestCreateAndUpdate(t *testing.T) { name, name2, new_name := "update", "update2", "new_update" user := User{Name: name, Age: 1, PasswordHash: []byte{'f', 'a', 'k', '4'}} - if !db.NewRecord(user) { + if !db.NewRecord(user) || !db.NewRecord(&user) { t.Error("User should be new record") } - if !db.NewRecord(&user) { - t.Error("User should be new record") + if count := db.Save(&user).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") } - db.Save(&user) if user.Id == 0 { t.Errorf("Should have ID after create") } - if db.NewRecord(user) { - t.Error("User should not new record after save") - } - - if db.NewRecord(&user) { + if db.NewRecord(user) || db.NewRecord(&user) { t.Error("User should not new record after save") } @@ -356,7 +351,9 @@ func TestCreateAndUpdate(t *testing.T) { t.Errorf("User's Password should be saved") } - db.Save(&User{Name: name2, Age: 1}) + if count := db.Save(&User{Name: name2, Age: 1}).RowsAffected; count != 1 { + t.Error("There should be one record be affected when update a record") + } user.Name = new_name db.Save(&user) @@ -1106,6 +1103,12 @@ func TestUpdate(t *testing.T) { if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { t.Error("No error should raise when update_column with CamelCase") } + + var animals []Animal + db.Find(&animals) + if count := db.Model(Animal{}).Update("CreatedAt", time.Now()).RowsAffected; count != int64(len(animals)) { + t.Error("RowsAffected should be correct when do batch update") + } } func TestUpdates(t *testing.T) { diff --git a/scope.go b/scope.go index fbc2ad5a..b6a6c27e 100644 --- a/scope.go +++ b/scope.go @@ -310,8 +310,12 @@ func (scope *Scope) Exec() *Scope { defer scope.Trace(time.Now()) if !scope.HasError() { - _, err := scope.DB().Exec(scope.Sql, scope.SqlVars...) - scope.Err(err) + result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...) + if scope.Err(err) == nil { + if count, err := result.RowsAffected(); err == nil { + scope.db.RowsAffected = count + } + } } return scope }