diff --git a/clause/where.go b/clause/where.go index 00b1a40e..61aa73a8 100644 --- a/clause/where.go +++ b/clause/where.go @@ -60,6 +60,9 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { case Expr: sql := strings.ToLower(v.SQL) wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + case NamedExpr: + sql := strings.ToLower(v.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") } } diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go index d0a6f915..a3a25f7b 100644 --- a/tests/named_argument_test.go +++ b/tests/named_argument_test.go @@ -2,6 +2,7 @@ package tests_test import ( "database/sql" + "errors" "testing" "gorm.io/gorm" @@ -66,4 +67,16 @@ func TestNamedArg(t *testing.T) { } AssertEqual(t, result6, namedUser) + + var result7 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu-new")).Where("name3 = 'jinzhu-new3'").First(&result7).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } + + DB.Delete(&namedUser) + + var result8 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", map[string]interface{}{"name": "jinzhu-new"}).First(&result8).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } }