diff --git a/callbacks/query.go b/callbacks/query.go index 0703b92e..8613e46d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -13,15 +13,7 @@ import ( func Query(db *gorm.DB) { if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.QueryClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + BuildQuerySQL(db) if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) @@ -37,131 +29,139 @@ func Query(db *gorm.DB) { } func BuildQuerySQL(db *gorm.DB) { - db.Statement.SQL.Grow(100) - clauseSelect := clause.Select{Distinct: db.Statement.Distinct} - - if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { - var conds []clause.Expression - for _, primaryField := range db.Statement.Schema.PrimaryFields { - if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { - conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) - } - } - - if len(conds) > 0 { - db.Statement.AddClause(clause.Where{Exprs: conds}) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) } } - if len(db.Statement.Selects) > 0 { - clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) - for idx, name := range db.Statement.Selects { - if db.Statement.Schema == nil { - clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} - } else if f := db.Statement.Schema.LookUpField(name); f != nil { - clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} - } else { - clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(100) + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} + + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) } } - } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { - selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) - clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) - for _, dbName := range db.Statement.Schema.DBNames { - if v, ok := selectColumns[dbName]; (ok && v) || !ok { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName}) + + if len(db.Statement.Selects) > 0 { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} + } else { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } + } + } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { + selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) + clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) + for _, dbName := range db.Statement.Schema.DBNames { + if v, ok := selectColumns[dbName]; (ok && v) || !ok { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName}) + } + } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + smallerStruct := false + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType } - } - } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { - smallerStruct := false - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType - case reflect.Slice: - smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType - } - if smallerStruct { - stmt := gorm.Statement{DB: db} - // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { - clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + if smallerStruct { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) - for idx, dbName := range stmt.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Name: dbName} + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Name: dbName} + } } } } - } - // inline joins - if len(db.Statement.Joins) != 0 { - if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { - clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) - for idx, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + // inline joins + if len(db.Statement.Joins) != 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } } - } - joins := []clause.Join{} - for _, join := range db.Statement.Joins { - if db.Statement.Schema == nil { - joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, - }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { - tableAliasName := relation.Name - - for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, + joins := []clause.Join{} + for _, join := range db.Statement.Joins { + if db.Statement.Schema == nil { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) - } + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { + tableAliasName := relation.Name - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - } - } else { - if ref.PrimaryValue == "" { + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, } } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } } } } - } - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) - } else { - joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, - }) + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, + }) + } } + + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) } - db.Statement.AddClause(clause.From{Joins: joins}) - } else { - db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.AddClauseIfNotExists(clauseSelect) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - - db.Statement.AddClauseIfNotExists(clauseSelect) - - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } func Preload(db *gorm.DB) { diff --git a/callbacks/row.go b/callbacks/row.go index 4f985d7b..10e880e1 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -6,9 +6,7 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + BuildQuerySQL(db) if !db.DryRun { if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index c77675f7..283a4c34 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -14,10 +14,16 @@ func TestSoftDelete(t *testing.T) { DB.Save(&user) var count int64 + var age uint + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) } + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + if err := DB.Delete(&user).Error; err != nil { t.Fatalf("No error should happen when soft delete user, but got %v", err) } @@ -26,18 +32,30 @@ func TestSoftDelete(t *testing.T) { t.Errorf("Can't find a soft deleted record") } + count = 0 if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) } + age = 0 + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != 0 { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) } + count = 0 if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) } + age = 0 + if DB.Unscoped().Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + DB.Unscoped().Delete(&user) if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("Can't find permanently deleted record")