forked from mirror/gorm
Fix Scopes with Row, close #4465
This commit is contained in:
parent
3226937f68
commit
8e67a08774
|
@ -373,7 +373,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
|
||||||
})
|
})
|
||||||
|
|
||||||
if tx.Statement.FullSaveAssociations {
|
if tx.Statement.FullSaveAssociations {
|
||||||
tx = tx.InstanceSet("gorm:update_track_time", true)
|
tx = tx.Set("gorm:update_track_time", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(selects) > 0 {
|
if len(selects) > 0 {
|
||||||
|
|
|
@ -243,9 +243,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||||
|
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
||||||
curTime = stmt.DB.NowFunc()
|
curTime = stmt.DB.NowFunc()
|
||||||
isZero bool
|
isZero bool
|
||||||
)
|
)
|
||||||
|
stmt.Settings.Delete("gorm:update_track_time")
|
||||||
|
|
||||||
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
||||||
|
|
||||||
for _, db := range stmt.Schema.DBNames {
|
for _, db := range stmt.Schema.DBNames {
|
||||||
|
@ -284,11 +287,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
field.Set(rv, curTime)
|
field.Set(rv, curTime)
|
||||||
values.Values[i][idx], _ = field.ValueOf(rv)
|
values.Values[i][idx], _ = field.ValueOf(rv)
|
||||||
}
|
}
|
||||||
} else if field.AutoUpdateTime > 0 {
|
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||||
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
|
field.Set(rv, curTime)
|
||||||
field.Set(rv, curTime)
|
values.Values[i][idx], _ = field.ValueOf(rv)
|
||||||
values.Values[i][idx], _ = field.ValueOf(rv)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,11 +327,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
field.Set(stmt.ReflectValue, curTime)
|
field.Set(stmt.ReflectValue, curTime)
|
||||||
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
||||||
}
|
}
|
||||||
} else if field.AutoUpdateTime > 0 {
|
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||||
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
|
field.Set(stmt.ReflectValue, curTime)
|
||||||
field.Set(stmt.ReflectValue, curTime)
|
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
||||||
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,8 @@ func RowQuery(db *gorm.DB) {
|
||||||
BuildQuerySQL(db)
|
BuildQuerySQL(db)
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun {
|
||||||
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
|
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
|
||||||
|
db.Statement.Settings.Delete("rows")
|
||||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
} else {
|
} else {
|
||||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
|
@ -79,7 +79,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
||||||
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
|
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
|
||||||
}
|
}
|
||||||
tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
|
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||||
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
||||||
|
@ -426,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Row() *sql.Row {
|
func (db *DB) Row() *sql.Row {
|
||||||
tx := db.getInstance().InstanceSet("rows", false)
|
tx := db.getInstance().Set("rows", false)
|
||||||
tx = tx.callbacks.Row().Execute(tx)
|
tx = tx.callbacks.Row().Execute(tx)
|
||||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||||
if !ok && tx.DryRun {
|
if !ok && tx.DryRun {
|
||||||
|
@ -436,7 +436,7 @@ func (db *DB) Row() *sql.Row {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Rows() (*sql.Rows, error) {
|
func (db *DB) Rows() (*sql.Rows, error) {
|
||||||
tx := db.getInstance().InstanceSet("rows", true)
|
tx := db.getInstance().Set("rows", true)
|
||||||
tx = tx.callbacks.Row().Execute(tx)
|
tx = tx.callbacks.Row().Execute(tx)
|
||||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||||
if !ok && tx.DryRun && tx.Error == nil {
|
if !ok && tx.DryRun && tx.Error == nil {
|
||||||
|
|
|
@ -124,7 +124,6 @@ func TestCount(t *testing.T) {
|
||||||
|
|
||||||
var count9 int64
|
var count9 int64
|
||||||
if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB {
|
if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB {
|
||||||
fmt.Println("kdkdkdkdk")
|
|
||||||
return tx.Table("users")
|
return tx.Table("users")
|
||||||
}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 {
|
}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 {
|
||||||
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -62,4 +63,12 @@ func TestScopes(t *testing.T) {
|
||||||
if result.RowsAffected != 2 {
|
if result.RowsAffected != 2 {
|
||||||
t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected)
|
t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var maxId int64
|
||||||
|
userTable := func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.WithContext(context.Background()).Table("users")
|
||||||
|
}
|
||||||
|
if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil {
|
||||||
|
t.Errorf("select max(id)")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue