From 1b8cb07cf29e1154778bcf063ddbeb095d4f93e5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 30 Dec 2020 17:42:27 +0800 Subject: [PATCH] Allow Where select fields when searching with struct --- statement.go | 26 +++++++++++++++++++++----- tests/query_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index 707e4aef..9433f4a7 100644 --- a/statement.go +++ b/statement.go @@ -250,7 +250,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds := make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) - for _, arg := range args { + for idx, arg := range args { if valuer, ok := arg.(driver.Valuer); ok { arg, _ = valuer.Value() } @@ -310,11 +310,22 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { - if field.Readable { - if v, isZero := field.ValueOf(reflectValue); !isZero { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -326,8 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { - if field.Readable { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -338,6 +350,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } } + + if restricted { + break + } } else if len(conds) == 0 { if len(args) == 1 { switch reflectValue.Kind() { diff --git a/tests/query_test.go b/tests/query_test.go index f1234d0a..50522f71 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -921,6 +921,30 @@ func TestSearchWithMap(t *testing.T) { } } +func TestSearchWithStruct(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + func TestSubQuery(t *testing.T) { users := []User{ {Name: "subquery_1", Age: 10},