From d002c70cf6ac6f35e4a2840606e65d84d33c5391 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Sep 2020 21:52:41 +0800 Subject: [PATCH] Support named argument for struct --- clause/expression.go | 12 ++++++++++++ clause/expression_test.go | 10 ++++++++++ tests/go.mod | 4 ++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index dde236d3..49924ef7 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -3,6 +3,7 @@ package clause import ( "database/sql" "database/sql/driver" + "go/ast" "reflect" ) @@ -89,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) { for k, v := range value { namedMap[k] = v } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + } + } + } } } diff --git a/clause/expression_test.go b/clause/expression_test.go index 17af737d..53d79c8f 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -37,6 +37,11 @@ func TestExpr(t *testing.T) { } func TestNamedExpr(t *testing.T) { + type NamedArgument struct { + Name1 string + Name2 string + } + results := []struct { SQL string Result string @@ -66,6 +71,11 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, }} for idx, result := range results { diff --git a/tests/go.mod b/tests/go.mod index 17a3b156..0db87934 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,9 +8,9 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.2 + gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 - gorm.io/gorm v1.20.0 + gorm.io/gorm v1.20.1 ) replace gorm.io/gorm => ../