forked from mirror/gorm
Test create from sql expr with map
This commit is contained in:
parent
ce8853e7a6
commit
7a90496701
|
@ -225,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||||
switch value := stmt.Dest.(type) {
|
switch value := stmt.Dest.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
values = ConvertMapToValuesForCreate(stmt, value)
|
values = ConvertMapToValuesForCreate(stmt, value)
|
||||||
|
case *map[string]interface{}:
|
||||||
|
values = ConvertMapToValuesForCreate(stmt, *value)
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||||
|
case *[]map[string]interface{}:
|
||||||
|
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
|
||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||||
|
|
|
@ -20,9 +20,11 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
|
||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
value := mapValue[k]
|
value := mapValue[k]
|
||||||
|
if stmt.Schema != nil {
|
||||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||||
k = field.DBName
|
k = field.DBName
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
values.Columns = append(values.Columns, clause.Column{Name: k})
|
values.Columns = append(values.Columns, clause.Column{Name: k})
|
||||||
|
@ -46,9 +48,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
|
||||||
|
|
||||||
for idx, mapValue := range mapValues {
|
for idx, mapValue := range mapValues {
|
||||||
for k, v := range mapValue {
|
for k, v := range mapValue {
|
||||||
|
if stmt.Schema != nil {
|
||||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||||
k = field.DBName
|
k = field.DBName
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := result[k]; !ok {
|
if _, ok := result[k]; !ok {
|
||||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||||
|
|
|
@ -353,4 +353,30 @@ func TestGORMValuer(t *testing.T) {
|
||||||
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{
|
||||||
|
"Name": "jinzhu",
|
||||||
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
||||||
|
}).Statement
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{
|
||||||
|
"Name": "jinzhu",
|
||||||
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
||||||
|
}).Statement
|
||||||
|
|
||||||
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
||||||
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
||||||
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue