mirror of https://github.com/go-gorm/gorm.git
Allow Where select fields when searching with struct
This commit is contained in:
parent
79864af9ff
commit
1b8cb07cf2
26
statement.go
26
statement.go
|
@ -250,7 +250,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||||
|
|
||||||
conds := make([]clause.Expression, 0, 4)
|
conds := make([]clause.Expression, 0, 4)
|
||||||
args = append([]interface{}{query}, args...)
|
args = append([]interface{}{query}, args...)
|
||||||
for _, arg := range args {
|
for idx, arg := range args {
|
||||||
if valuer, ok := arg.(driver.Valuer); ok {
|
if valuer, ok := arg.(driver.Valuer); ok {
|
||||||
arg, _ = valuer.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
|
@ -310,11 +310,22 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||||
default:
|
default:
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
||||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||||
|
selectedColumns := map[string]bool{}
|
||||||
|
if idx == 0 {
|
||||||
|
for _, v := range args[1:] {
|
||||||
|
if vs, ok := v.(string); ok {
|
||||||
|
selectedColumns[vs] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
restricted := len(selectedColumns) != 0
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for _, field := range s.Fields {
|
for _, field := range s.Fields {
|
||||||
if field.Readable {
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
if selected || (!restricted && field.Readable) {
|
||||||
|
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
|
@ -326,8 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
for _, field := range s.Fields {
|
for _, field := range s.Fields {
|
||||||
if field.Readable {
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
|
if selected || (!restricted && field.Readable) {
|
||||||
|
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
|
@ -338,6 +350,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if restricted {
|
||||||
|
break
|
||||||
|
}
|
||||||
} else if len(conds) == 0 {
|
} else if len(conds) == 0 {
|
||||||
if len(args) == 1 {
|
if len(args) == 1 {
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
|
|
|
@ -921,6 +921,30 @@ func TestSearchWithMap(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSearchWithStruct(t *testing.T) {
|
||||||
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
|
||||||
|
result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{})
|
||||||
|
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
|
||||||
|
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
|
||||||
|
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{})
|
||||||
|
if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSubQuery(t *testing.T) {
|
func TestSubQuery(t *testing.T) {
|
||||||
users := []User{
|
users := []User{
|
||||||
{Name: "subquery_1", Age: 10},
|
{Name: "subquery_1", Age: 10},
|
||||||
|
|
Loading…
Reference in New Issue