mirror of https://github.com/go-gorm/gorm.git
Fix Scan with soft delete, close #3712
This commit is contained in:
parent
c915471169
commit
560d303e71
|
@ -13,15 +13,7 @@ import (
|
||||||
|
|
||||||
func Query(db *gorm.DB) {
|
func Query(db *gorm.DB) {
|
||||||
if db.Error == nil {
|
if db.Error == nil {
|
||||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
BuildQuerySQL(db)
|
||||||
for _, c := range db.Statement.Schema.QueryClauses {
|
|
||||||
db.Statement.AddClause(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.SQL.String() == "" {
|
|
||||||
BuildQuerySQL(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
@ -37,131 +29,139 @@ func Query(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildQuerySQL(db *gorm.DB) {
|
func BuildQuerySQL(db *gorm.DB) {
|
||||||
db.Statement.SQL.Grow(100)
|
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||||
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
for _, c := range db.Statement.Schema.QueryClauses {
|
||||||
|
db.Statement.AddClause(c)
|
||||||
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
|
||||||
var conds []clause.Expression
|
|
||||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
|
||||||
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
|
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(conds) > 0 {
|
|
||||||
db.Statement.AddClause(clause.Where{Exprs: conds})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(db.Statement.Selects) > 0 {
|
if db.Statement.SQL.String() == "" {
|
||||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
db.Statement.SQL.Grow(100)
|
||||||
for idx, name := range db.Statement.Selects {
|
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
||||||
if db.Statement.Schema == nil {
|
|
||||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
||||||
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
var conds []clause.Expression
|
||||||
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||||
} else {
|
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
|
||||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(conds) > 0 {
|
||||||
|
db.Statement.AddClause(clause.Where{Exprs: conds})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
|
|
||||||
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
|
if len(db.Statement.Selects) > 0 {
|
||||||
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
|
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
||||||
for _, dbName := range db.Statement.Schema.DBNames {
|
for idx, name := range db.Statement.Selects {
|
||||||
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
|
if db.Statement.Schema == nil {
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName})
|
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||||
|
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
||||||
|
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
||||||
|
} else {
|
||||||
|
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
|
||||||
|
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
|
||||||
|
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
|
||||||
|
for _, dbName := range db.Statement.Schema.DBNames {
|
||||||
|
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
|
||||||
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
|
||||||
|
smallerStruct := false
|
||||||
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
|
||||||
|
case reflect.Slice:
|
||||||
|
smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
|
|
||||||
smallerStruct := false
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
|
||||||
case reflect.Struct:
|
|
||||||
smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
|
|
||||||
case reflect.Slice:
|
|
||||||
smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
|
|
||||||
}
|
|
||||||
|
|
||||||
if smallerStruct {
|
if smallerStruct {
|
||||||
stmt := gorm.Statement{DB: db}
|
stmt := gorm.Statement{DB: db}
|
||||||
// smaller struct
|
// smaller struct
|
||||||
if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType {
|
if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType {
|
||||||
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
|
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
|
||||||
|
|
||||||
for idx, dbName := range stmt.Schema.DBNames {
|
for idx, dbName := range stmt.Schema.DBNames {
|
||||||
clauseSelect.Columns[idx] = clause.Column{Name: dbName}
|
clauseSelect.Columns[idx] = clause.Column{Name: dbName}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// inline joins
|
// inline joins
|
||||||
if len(db.Statement.Joins) != 0 {
|
if len(db.Statement.Joins) != 0 {
|
||||||
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
|
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
|
||||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
joins := []clause.Join{}
|
joins := []clause.Join{}
|
||||||
for _, join := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if db.Statement.Schema == nil {
|
if db.Statement.Schema == nil {
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
|
||||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
|
||||||
tableAliasName := relation.Name
|
|
||||||
|
|
||||||
for _, s := range relation.FieldSchema.DBNames {
|
|
||||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
|
||||||
Table: tableAliasName,
|
|
||||||
Name: s,
|
|
||||||
Alias: tableAliasName + "__" + s,
|
|
||||||
})
|
})
|
||||||
}
|
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||||
|
tableAliasName := relation.Name
|
||||||
|
|
||||||
exprs := make([]clause.Expression, len(relation.References))
|
for _, s := range relation.FieldSchema.DBNames {
|
||||||
for idx, ref := range relation.References {
|
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||||
if ref.OwnPrimaryKey {
|
Table: tableAliasName,
|
||||||
exprs[idx] = clause.Eq{
|
Name: s,
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
Alias: tableAliasName + "__" + s,
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
})
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if ref.PrimaryValue == "" {
|
exprs := make([]clause.Expression, len(relation.References))
|
||||||
|
for idx, ref := range relation.References {
|
||||||
|
if ref.OwnPrimaryKey {
|
||||||
exprs[idx] = clause.Eq{
|
exprs[idx] = clause.Eq{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
||||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
exprs[idx] = clause.Eq{
|
if ref.PrimaryValue == "" {
|
||||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
exprs[idx] = clause.Eq{
|
||||||
Value: ref.PrimaryValue,
|
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
||||||
|
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
exprs[idx] = clause.Eq{
|
||||||
|
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||||
|
Value: ref.PrimaryValue,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Type: clause.LeftJoin,
|
Type: clause.LeftJoin,
|
||||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||||
ON: clause.Where{Exprs: exprs},
|
ON: clause.Where{Exprs: exprs},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
joins = append(joins, clause.Join{
|
joins = append(joins, clause.Join{
|
||||||
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
Expression: clause.Expr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.Statement.AddClause(clause.From{Joins: joins})
|
||||||
|
} else {
|
||||||
|
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.AddClause(clause.From{Joins: joins})
|
db.Statement.AddClauseIfNotExists(clauseSelect)
|
||||||
} else {
|
|
||||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.AddClauseIfNotExists(clauseSelect)
|
|
||||||
|
|
||||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Preload(db *gorm.DB) {
|
func Preload(db *gorm.DB) {
|
||||||
|
|
|
@ -6,9 +6,7 @@ import (
|
||||||
|
|
||||||
func RowQuery(db *gorm.DB) {
|
func RowQuery(db *gorm.DB) {
|
||||||
if db.Error == nil {
|
if db.Error == nil {
|
||||||
if db.Statement.SQL.String() == "" {
|
BuildQuerySQL(db)
|
||||||
BuildQuerySQL(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !db.DryRun {
|
if !db.DryRun {
|
||||||
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
|
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
|
||||||
|
|
|
@ -14,10 +14,16 @@ func TestSoftDelete(t *testing.T) {
|
||||||
DB.Save(&user)
|
DB.Save(&user)
|
||||||
|
|
||||||
var count int64
|
var count int64
|
||||||
|
var age uint
|
||||||
|
|
||||||
if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 {
|
if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 {
|
||||||
t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count)
|
t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age {
|
||||||
|
t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age)
|
||||||
|
}
|
||||||
|
|
||||||
if err := DB.Delete(&user).Error; err != nil {
|
if err := DB.Delete(&user).Error; err != nil {
|
||||||
t.Fatalf("No error should happen when soft delete user, but got %v", err)
|
t.Fatalf("No error should happen when soft delete user, but got %v", err)
|
||||||
}
|
}
|
||||||
|
@ -26,18 +32,30 @@ func TestSoftDelete(t *testing.T) {
|
||||||
t.Errorf("Can't find a soft deleted record")
|
t.Errorf("Can't find a soft deleted record")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
count = 0
|
||||||
if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 {
|
if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 {
|
||||||
t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count)
|
t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
age = 0
|
||||||
|
if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != 0 {
|
||||||
|
t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age)
|
||||||
|
}
|
||||||
|
|
||||||
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
|
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
|
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
count = 0
|
||||||
if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 {
|
if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 {
|
||||||
t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count)
|
t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
age = 0
|
||||||
|
if DB.Unscoped().Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age {
|
||||||
|
t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age)
|
||||||
|
}
|
||||||
|
|
||||||
DB.Unscoped().Delete(&user)
|
DB.Unscoped().Delete(&user)
|
||||||
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
t.Errorf("Can't find permanently deleted record")
|
t.Errorf("Can't find permanently deleted record")
|
||||||
|
|
Loading…
Reference in New Issue