diff --git a/statement.go b/statement.go index 5dd3a584..3617d7ed 100644 --- a/statement.go +++ b/statement.go @@ -308,24 +308,38 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } default: - if reflectValue := reflect.Indirect(reflect.ValueOf(arg)); reflectValue.IsValid() { - if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { - selectedColumns := map[string]bool{} - if idx == 0 { - for _, v := range args[1:] { - if vs, ok := v.(string); ok { - selectedColumns[vs] = true + reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } } } } - restricted := len(selectedColumns) != 0 - - switch reflectValue.Kind() { - case reflect.Struct: + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -334,44 +348,31 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } } - case reflect.Slice, reflect.Array: - for i := 0; i < reflectValue.Len(); i++ { - for _, field := range s.Fields { - selected := selectedColumns[field.DBName] || selectedColumns[field.Name] - if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { - if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) - } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) - } - } - } - } - } } - - if restricted { - break - } - } else if len(conds) == 0 { - if len(args) == 1 { - switch reflectValue.Kind() { - case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { - values[i] = reflectValue.Index(i).Interface() - } - - if len(values) > 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) - } - return conds - } - } - - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } + + if restricted { + break + } + } else if !reflectValue.IsValid() { + stmt.AddError(ErrInvalidData) + } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return conds + } + } + + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) } } }