Allow Where select fields when searching with struct

This commit is contained in:
Jinzhu 2020-12-30 17:42:27 +08:00
parent 79864af9ff
commit 1b8cb07cf2
2 changed files with 45 additions and 5 deletions

View File

@ -250,7 +250,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds := make([]clause.Expression, 0, 4)
args = append([]interface{}{query}, args...)
for _, arg := range args {
for idx, arg := range args {
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
@ -310,11 +310,22 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
default:
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
if vs, ok := v.(string); ok {
selectedColumns[vs] = true
}
}
}
restricted := len(selectedColumns) != 0
switch reflectValue.Kind() {
case reflect.Struct:
for _, field := range s.Fields {
if field.Readable {
if v, isZero := field.ValueOf(reflectValue); !isZero {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
@ -326,8 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.Fields {
if field.Readable {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
@ -338,6 +350,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
}
}
if restricted {
break
}
} else if len(conds) == 0 {
if len(args) == 1 {
switch reflectValue.Kind() {

View File

@ -921,6 +921,30 @@ func TestSearchWithMap(t *testing.T) {
}
}
func TestSearchWithStruct(t *testing.T) {
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{})
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{})
if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
}
func TestSubQuery(t *testing.T) {
users := []User{
{Name: "subquery_1", Age: 10},