From 7a90496701f7b81e06daaa134a8f8853c1f935d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 16:27:59 +0800 Subject: [PATCH] Test create from sql expr with map --- callbacks/create.go | 4 ++++ callbacks/helper.go | 12 ++++++++---- tests/scanner_valuer_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index cc7e2671..c59b14b5 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -225,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) case []map[string]interface{}: values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) diff --git a/callbacks/helper.go b/callbacks/helper.go index 80fbc2a1..e0a66dd2 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -20,8 +20,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter for _, k := range keys { value := mapValue[k] - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { @@ -46,8 +48,10 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st for idx, mapValue := range mapValues { for k, v := range mapValue { - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if _, ok := result[k]; !ok { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ec16ccf6..dbf5adac 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -353,4 +353,30 @@ func TestGORMValuer(t *testing.T) { if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } + + stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } }