diff --git a/clause/expression.go b/clause/expression.go index f76ce138..a0933ad2 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -233,11 +233,24 @@ type Eq struct { func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) - if eqNil(eq.Value) { - builder.WriteString(" IS NULL") - } else { - builder.WriteString(" = ") - builder.AddVar(builder, eq.Value) + switch eq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" IN (") + rv := reflect.ValueOf(eq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i)) + } + builder.WriteByte(')') + default: + if eqNil(eq.Value) { + builder.WriteString(" IS NULL") + } else { + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) + } } } @@ -251,11 +264,24 @@ type Neq Eq func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) - if eqNil(neq.Value) { - builder.WriteString(" IS NOT NULL") - } else { - builder.WriteString(" <> ") - builder.AddVar(builder, neq.Value) + switch neq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" NOT IN (") + rv := reflect.ValueOf(neq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i)) + } + builder.WriteByte(')') + default: + if eqNil(neq.Value) { + builder.WriteString(" IS NOT NULL") + } else { + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) + } } } diff --git a/clause/expression_test.go b/clause/expression_test.go index 4472bdb1..e0e192f7 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -136,6 +136,16 @@ func TestExpression(t *testing.T) { clause.Neq{Column: column, Value: (interface{})(nil)}, }, Result: "`column-name` IS NOT NULL", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{"a", "b"}}, + }, + Result: "`column-name` IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: []string{"a", "b"}}, + }, + Result: "`column-name` NOT IN (?,?)", }} for idx, result := range results {