mirror of https://github.com/go-gorm/gorm.git
Add RowsAffected for DB
This commit is contained in:
parent
319af32c78
commit
f32fa0cb6e
|
@ -51,9 +51,13 @@ func Create(scope *Scope) {
|
||||||
// execute create sql
|
// execute create sql
|
||||||
var id interface{}
|
var id interface{}
|
||||||
if scope.Dialect().SupportLastInsertId() {
|
if scope.Dialect().SupportLastInsertId() {
|
||||||
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||||
id, err = sql_result.LastInsertId()
|
id, err = result.LastInsertId()
|
||||||
scope.Err(err)
|
if scope.Err(err) == nil {
|
||||||
|
if count, err := result.RowsAffected(); err == nil {
|
||||||
|
scope.db.RowsAffected = count
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
|
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
|
||||||
|
|
1
main.go
1
main.go
|
@ -7,6 +7,7 @@ import (
|
||||||
type DB struct {
|
type DB struct {
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Error error
|
Error error
|
||||||
|
RowsAffected int64
|
||||||
callback *callback
|
callback *callback
|
||||||
db sqlCommon
|
db sqlCommon
|
||||||
parent *DB
|
parent *DB
|
||||||
|
|
23
main_test.go
23
main_test.go
|
@ -329,24 +329,19 @@ func TestCreateAndUpdate(t *testing.T) {
|
||||||
name, name2, new_name := "update", "update2", "new_update"
|
name, name2, new_name := "update", "update2", "new_update"
|
||||||
user := User{Name: name, Age: 1, PasswordHash: []byte{'f', 'a', 'k', '4'}}
|
user := User{Name: name, Age: 1, PasswordHash: []byte{'f', 'a', 'k', '4'}}
|
||||||
|
|
||||||
if !db.NewRecord(user) {
|
if !db.NewRecord(user) || !db.NewRecord(&user) {
|
||||||
t.Error("User should be new record")
|
t.Error("User should be new record")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !db.NewRecord(&user) {
|
if count := db.Save(&user).RowsAffected; count != 1 {
|
||||||
t.Error("User should be new record")
|
t.Error("There should be one record be affected when create record")
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Save(&user)
|
|
||||||
if user.Id == 0 {
|
if user.Id == 0 {
|
||||||
t.Errorf("Should have ID after create")
|
t.Errorf("Should have ID after create")
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.NewRecord(user) {
|
if db.NewRecord(user) || db.NewRecord(&user) {
|
||||||
t.Error("User should not new record after save")
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.NewRecord(&user) {
|
|
||||||
t.Error("User should not new record after save")
|
t.Error("User should not new record after save")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -356,7 +351,9 @@ func TestCreateAndUpdate(t *testing.T) {
|
||||||
t.Errorf("User's Password should be saved")
|
t.Errorf("User's Password should be saved")
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Save(&User{Name: name2, Age: 1})
|
if count := db.Save(&User{Name: name2, Age: 1}).RowsAffected; count != 1 {
|
||||||
|
t.Error("There should be one record be affected when update a record")
|
||||||
|
}
|
||||||
|
|
||||||
user.Name = new_name
|
user.Name = new_name
|
||||||
db.Save(&user)
|
db.Save(&user)
|
||||||
|
@ -1106,6 +1103,12 @@ func TestUpdate(t *testing.T) {
|
||||||
if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
if db.Model(&animal2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||||
t.Error("No error should raise when update_column with CamelCase")
|
t.Error("No error should raise when update_column with CamelCase")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var animals []Animal
|
||||||
|
db.Find(&animals)
|
||||||
|
if count := db.Model(Animal{}).Update("CreatedAt", time.Now()).RowsAffected; count != int64(len(animals)) {
|
||||||
|
t.Error("RowsAffected should be correct when do batch update")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdates(t *testing.T) {
|
func TestUpdates(t *testing.T) {
|
||||||
|
|
8
scope.go
8
scope.go
|
@ -310,8 +310,12 @@ func (scope *Scope) Exec() *Scope {
|
||||||
defer scope.Trace(time.Now())
|
defer scope.Trace(time.Now())
|
||||||
|
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
_, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
|
result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
|
||||||
scope.Err(err)
|
if scope.Err(err) == nil {
|
||||||
|
if count, err := result.RowsAffected(); err == nil {
|
||||||
|
scope.db.RowsAffected = count
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue