From 15ce5b3cdd8b256ce070245b3a41a1ca7d4ca0fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Feb 2020 12:53:46 +0800 Subject: [PATCH] Add create value converter --- callbacks/create.go | 87 +++++++++++++++++++++++++++++++++++++++- callbacks/helper.go | 97 +++++++++++++++++++++++++++++++++++++++++++++ chainable_api.go | 2 +- clause/values.go | 3 +- 4 files changed, 186 insertions(+), 3 deletions(-) create mode 100644 callbacks/helper.go diff --git a/callbacks/create.go b/callbacks/create.go index 58256085..8dba8a5f 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -2,6 +2,7 @@ package callbacks import ( "fmt" + "reflect" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" @@ -19,11 +20,15 @@ func SaveBeforeAssociations(db *gorm.DB) { func Create(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Table: db.Statement.Table}, + Table: clause.Table{Name: db.Statement.Table}, }) + values, _ := ConvertToCreateValues(db.Statement) + db.Statement.AddClause(values) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + fmt.Printf("%+v\n", values) fmt.Println(err) fmt.Println(result) fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) @@ -36,3 +41,83 @@ func AfterCreate(db *gorm.DB) { // after save // after create } + +// ConvertToCreateValues convert to create values +func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) { + switch value := stmt.Dest.(type) { + case map[string]interface{}: + return ConvertMapToValues(stmt, value), nil + case []map[string]interface{}: + return ConvertSliceOfMapToValues(stmt, value), nil + default: + var ( + values = clause.Values{} + selectColumns, restricted = SelectAndOmitColumns(stmt) + curTime = stmt.DB.NowFunc() + isZero = false + returnningValues []map[string]interface{} + ) + + for _, db := range stmt.Schema.DBNames { + if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: db}) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + values.Values = make([][]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + rv := reflect.Indirect(reflectValue.Index(i)) + values.Values[i] = make([]interface{}, len(values.Columns)) + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if field.DefaultValueInterface != nil { + values.Values[i][idx] = field.DefaultValueInterface + field.Set(rv, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(rv, curTime) + values.Values[i][idx], _ = field.ValueOf(rv) + } else if field.HasDefaultValue { + if len(returnningValues) == 0 { + returnningValues = make([]map[string]interface{}, reflectValue.Len()) + } + + if returnningValues[i] == nil { + returnningValues[i] = map[string]interface{}{} + } + + // FIXME + returnningValues[i][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() + } + } + } + } + case reflect.Struct: + values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[0][idx], _ = field.ValueOf(reflectValue); isZero { + if field.DefaultValueInterface != nil { + values.Values[0][idx] = field.DefaultValueInterface + field.Set(reflectValue, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(reflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(reflectValue) + } else if field.HasDefaultValue { + if len(returnningValues) == 0 { + returnningValues = make([]map[string]interface{}, 1) + } + + values.Values[0][idx] = clause.Expr{SQL: "DEFAULT"} + returnningValues[0][column.Name] = field.ReflectValueOf(reflectValue).Addr().Interface() + } else if field.PrimaryKey { + } + } + } + } + return values, returnningValues + } +} diff --git a/callbacks/helper.go b/callbacks/helper.go new file mode 100644 index 00000000..56c0767d --- /dev/null +++ b/callbacks/helper.go @@ -0,0 +1,97 @@ +package callbacks + +import ( + "sort" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" +) + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { + results := map[string]bool{} + + // select columns + for _, column := range stmt.Selects { + if field := stmt.Schema.LookUpField(column); field != nil { + results[field.DBName] = true + } else { + results[column] = true + } + } + + // omit columns + for _, omit := range stmt.Omits { + if field := stmt.Schema.LookUpField(omit); field != nil { + results[field.DBName] = false + } else { + results[omit] = false + } + } + + return results, len(stmt.Selects) > 0 +} + +// ConvertMapToValues convert map to values +func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { + columns := make([]string, 0, len(mapValue)) + selectColumns, restricted := SelectAndOmitColumns(stmt) + + var keys []string + for k, _ := range mapValue { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + columns = append(columns, k) + values.Values[0] = append(values.Values[0], mapValue[k]) + } + } + return +} + +// ConvertSliceOfMapToValues convert slice of map to values +func ConvertSliceOfMapToValues(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { + var ( + columns = []string{} + result = map[string][]interface{}{} + selectColumns, restricted = SelectAndOmitColumns(stmt) + ) + + for idx, mapValue := range mapValues { + for k, v := range mapValue { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + + if _, ok := result[k]; !ok { + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + result[k] = make([]interface{}, len(mapValues)) + columns = append(columns, k) + } else { + continue + } + } + + result[k][idx] = v + } + } + + sort.Strings(columns) + values.Values = make([][]interface{}, len(mapValues)) + for idx, column := range columns { + for i, v := range result[column] { + if i == 0 { + values.Values[i] = make([]interface{}, len(columns)) + } + values.Values[i][idx] = v + } + } + return +} diff --git a/chainable_api.go b/chainable_api.go index 9aa08b54..a57deb63 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -99,7 +99,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() - if len(columns) == 1 && strings.Contains(columns[0], ",") { + if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { tx.Statement.Omits = strings.FieldsFunc(columns[0], isChar) } else { tx.Statement.Omits = columns diff --git a/clause/values.go b/clause/values.go index 594b92e2..2c8dcf89 100644 --- a/clause/values.go +++ b/clause/values.go @@ -7,7 +7,7 @@ type Values struct { // Name from clause name func (Values) Name() string { - return "" + return "VALUES" } // Build build from clause @@ -40,6 +40,7 @@ func (values Values) Build(builder Builder) { // MergeClause merge values clauses func (values Values) MergeClause(clause *Clause) { + clause.Name = "" if v, ok := clause.Expression.(Values); ok { values.Values = append(v.Values, values.Values...) }