diff --git a/clause/expression.go b/clause/expression.go index b30c46b0..3844d66b 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -203,7 +203,7 @@ type Eq struct { func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) - if eq.Value == nil { + if eqNil(eq.Value) { builder.WriteString(" IS NULL") } else { builder.WriteString(" = ") @@ -221,7 +221,7 @@ type Neq Eq func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) - if neq.Value == nil { + if eqNil(neq.Value) { builder.WriteString(" IS NOT NULL") } else { builder.WriteString(" <> ") @@ -299,3 +299,12 @@ func (like Like) NegationBuild(builder Builder) { builder.WriteString(" NOT LIKE ") builder.AddVar(builder, like.Value) } + +func eqNil(value interface{}) bool { + return value == nil || eqNilReflect(value) +} + +func eqNilReflect(value interface{}) bool { + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} diff --git a/clause/expression_test.go b/clause/expression_test.go index 83082486..9e3d7bad 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -101,3 +101,52 @@ func TestNamedExpr(t *testing.T) { }) } } + +func TestExpression(t *testing.T) { + column := "column-name" + results := []struct { + Expressions []clause.Expression + Result string + }{{ + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: "column-value"}, + }, + Result: "`column-name` = ?", + },{ + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: nil}, + clause.Eq{Column: column, Value: (*string)(nil)}, + clause.Eq{Column: column, Value: (*int)(nil)}, + clause.Eq{Column: column, Value: (*bool)(nil)}, + clause.Eq{Column: column, Value: (interface{})(nil)}, + }, + Result: "`column-name` IS NULL", + },{ + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: "column-value"}, + }, + Result: "`column-name` <> ?", + },{ + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: nil}, + clause.Neq{Column: column, Value: (*string)(nil)}, + clause.Neq{Column: column, Value: (*int)(nil)}, + clause.Neq{Column: column, Value: (*bool)(nil)}, + clause.Neq{Column: column, Value: (interface{})(nil)}, + }, + Result: "`column-name` IS NOT NULL", + }} + + for idx, result := range results { + for idy, expression := range result.Expressions { + t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + expression.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + }) + } + } +}