forked from mirror/gorm
Allow Where select fields when searching with struct
This commit is contained in:
parent
79864af9ff
commit
1b8cb07cf2
26
statement.go
26
statement.go
|
@ -250,7 +250,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||
|
||||
conds := make([]clause.Expression, 0, 4)
|
||||
args = append([]interface{}{query}, args...)
|
||||
for _, arg := range args {
|
||||
for idx, arg := range args {
|
||||
if valuer, ok := arg.(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
@ -310,11 +310,22 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||
default:
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
selectedColumns := map[string]bool{}
|
||||
if idx == 0 {
|
||||
for _, v := range args[1:] {
|
||||
if vs, ok := v.(string); ok {
|
||||
selectedColumns[vs] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
restricted := len(selectedColumns) != 0
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, field := range s.Fields {
|
||||
if field.Readable {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
|
@ -326,8 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for _, field := range s.Fields {
|
||||
if field.Readable {
|
||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
|
@ -338,6 +350,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if restricted {
|
||||
break
|
||||
}
|
||||
} else if len(conds) == 0 {
|
||||
if len(args) == 1 {
|
||||
switch reflectValue.Kind() {
|
||||
|
|
|
@ -921,6 +921,30 @@ func TestSearchWithMap(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSearchWithStruct(t *testing.T) {
|
||||
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
|
||||
result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubQuery(t *testing.T) {
|
||||
users := []User{
|
||||
{Name: "subquery_1", Age: 10},
|
||||
|
|
Loading…
Reference in New Issue