Make inesrt into db works

This commit is contained in:
Jinzhu 2020-02-23 19:41:29 +08:00
parent 868ae052a1
commit fa22807e12
10 changed files with 92 additions and 66 deletions

View File

@ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) {
if stmt := db.Statement; stmt != nil { if stmt := db.Statement; stmt != nil {
db.Logger.Trace(curTime, func() (string, int64) { db.Logger.Trace(curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error) }, db.Error)
} }
} }

View File

@ -1,7 +1,6 @@
package callbacks package callbacks
import ( import (
"fmt"
"reflect" "reflect"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -11,8 +10,6 @@ import (
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
// before save // before save
// before create // before create
// assign timestamp
} }
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
@ -22,16 +19,29 @@ func Create(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{ db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table}, Table: clause.Table{Name: db.Statement.Table},
}) })
values, _ := ConvertToCreateValues(db.Statement) db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.AddClause(values)
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Printf("%+v\n", values) if err == nil {
fmt.Println(err) if db.Statement.Schema != nil {
fmt.Println(result) if insertID, err := result.LastInsertId(); err == nil {
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
insertID--
}
case reflect.Struct:
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
}
}
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
} }
func SaveAfterAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) {
@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) {
} }
// ConvertToCreateValues convert to create values // ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) { func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
switch value := stmt.Dest.(type) { switch value := stmt.Dest.(type) {
case map[string]interface{}: case map[string]interface{}:
return ConvertMapToValues(stmt, value), nil return ConvertMapToValues(stmt, value)
case []map[string]interface{}: case []map[string]interface{}:
return ConvertSliceOfMapToValues(stmt, value), nil return ConvertSliceOfMapToValues(stmt, value)
default: default:
var ( var (
values = clause.Values{} values = clause.Values{}
selectColumns, restricted = SelectAndOmitColumns(stmt) selectColumns, restricted = SelectAndOmitColumns(stmt)
curTime = stmt.DB.NowFunc() curTime = stmt.DB.NowFunc()
isZero = false isZero = false
returnningValues []map[string]interface{}
) )
for _, db := range stmt.Schema.DBNames { for _, db := range stmt.Schema.DBNames {
@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
} }
} }
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest)) switch stmt.ReflectValue.Kind() {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
values.Values = make([][]interface{}, reflectValue.Len()) values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[string][]interface{}{} defaultValueFieldsHavingValue := map[string][]interface{}{}
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
rv := reflect.Indirect(reflectValue.Index(i)) rv := reflect.Indirect(stmt.ReflectValue.Index(i))
values.Values[i] = make([]interface{}, len(values.Columns)) values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero { if v, isZero := field.ValueOf(rv); !isZero {
if len(defaultValueFieldsHavingValue[db]) == 0 { if len(defaultValueFieldsHavingValue[db]) == 0 {
defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len()) defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
} }
defaultValueFieldsHavingValue[db][i] = v defaultValueFieldsHavingValue[db][i] = v
} }
@ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero { if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface values.Values[0][idx] = field.DefaultValueInterface
field.Set(reflectValue, field.DefaultValueInterface) field.Set(stmt.ReflectValue, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(reflectValue, curTime) field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(reflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
} }
} }
} }
for db, field := range stmt.Schema.FieldsWithDefaultDBValue { for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(reflectValue); !isZero { if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: db}) values.Columns = append(values.Columns, clause.Column{Name: db})
values.Values[0] = append(values.Values[0], v) values.Values[0] = append(values.Values[0], v)
} }
@ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
} }
} }
return values, returnningValues return values
} }
} }

View File

@ -1,8 +1,6 @@
package callbacks package callbacks
import ( import (
"fmt"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
) )
@ -15,10 +13,8 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
} }
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err) db.AddError(err)
fmt.Println(result)
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
} }
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {

View File

@ -66,9 +66,9 @@ func New(writer Writer, config Config) Interface {
) )
if config.Colorful { if config.Colorful {
infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s" tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s"
traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s" traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s"
} }
@ -93,29 +93,28 @@ type logger struct {
// LogMode log mode // LogMode log mode
func (l logger) LogMode(level LogLevel) Interface { func (l logger) LogMode(level LogLevel) Interface {
config := l.Config l.LogLevel = level
config.LogLevel = level return l
return logger{Writer: l.Writer, Config: config}
} }
// Info print info // Info print info
func (l logger) Info(msg string, data ...interface{}) { func (l logger) Info(msg string, data ...interface{}) {
if l.LogLevel >= Info { if l.LogLevel >= Info {
l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Warn print warn messages // Warn print warn messages
func (l logger) Warn(msg string, data ...interface{}) { func (l logger) Warn(msg string, data ...interface{}) {
if l.LogLevel >= Warn { if l.LogLevel >= Warn {
l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Error print error messages // Error print error messages
func (l logger) Error(msg string, data ...interface{}) { func (l logger) Error(msg string, data ...interface{}) {
if l.LogLevel >= Error { if l.LogLevel >= Error {
l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
@ -123,7 +122,11 @@ func (l logger) Error(msg string, data ...interface{}) {
func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) {
if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold { if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold {
sql, rows := fc() sql, rows := fc()
l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) fileline := utils.FileWithLineNum()
if err != nil {
fileline += " " + err.Error()
}
l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql)
} else if l.LogLevel >= Info { } else if l.LogLevel >= Info {
sql, rows := fc() sql, rows := fc()
l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)

View File

@ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
case bool: case bool:
vars[idx] = fmt.Sprint(v) vars[idx] = fmt.Sprint(v)
case time.Time: case time.Time:
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
}
case []byte: case []byte:
if isPrintable(v) { if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
@ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
vars[idx] = "NULL" vars[idx] = "NULL"
} else { } else {
rv := reflect.Indirect(reflect.ValueOf(v)) rv := reflect.Indirect(reflect.ValueOf(v))
if !rv.IsValid() {
vars[idx] = "NULL"
return
}
for _, t := range convertableTypes { for _, t := range convertableTypes {
if rv.Type().ConvertibleTo(t) { if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx) convertParams(rv.Convert(t).Interface(), idx)

View File

@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error var err error
field.Creatable = false field.Creatable = false
field.Updatable = false field.Updatable = false
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
schema.err = err schema.err = err
} }
for _, ef := range field.EmbeddedSchema.Fields { for _, ef := range field.EmbeddedSchema.Fields {

View File

@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
} }
) )
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err schema.err = err
return return
} }
@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
} }
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err schema.err = err
} }
relation.JoinTable.Name = many2many relation.JoinTable.Name = many2many

View File

@ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field {
} }
// get data type from dialector // get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
modelType := reflect.ValueOf(dest).Type() reflectValue := reflect.ValueOf(dest)
modelType := reflectValue.Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" { if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
} }
if v, ok := cacheStore.Load(modelType); ok { if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil return v.(*Schema), reflectValue, nil
} }
schema := &Schema{ schema := &Schema{
@ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable { if field.DataType == "" && field.Creatable {
if schema.parseRelation(field); schema.err != nil { if schema.parseRelation(field); schema.err != nil {
return schema, schema.err return schema, reflectValue, schema.err
} }
} }
} }
return schema, schema.err return schema, reflectValue, schema.err
} }

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -32,22 +33,23 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
func (inst *Instance) AddError(err error) { func (inst *Instance) AddError(err error) {
if inst.Error == nil { if inst.Error == nil {
inst.Error = err inst.Error = err
} else { } else if err != nil {
inst.Error = fmt.Errorf("%v; %w", inst.Error, err) inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
} }
} }
// Statement statement // Statement statement
type Statement struct { type Statement struct {
Table string Table string
Model interface{} Model interface{}
Dest interface{} Dest interface{}
Clauses map[string]clause.Clause ReflectValue reflect.Value
Selects []string // selected columns Clauses map[string]clause.Clause
Omits []string // omit columns Selects []string // selected columns
Settings sync.Map Omits []string // omit columns
DB *DB Settings sync.Map
Schema *schema.Schema DB *DB
Schema *schema.Schema
// SQL Builder // SQL Builder
SQL strings.Builder SQL strings.Builder
@ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
// BuildCondtion build condition // BuildCondtion build condition
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
if sql, ok := query.(string); ok { if sql, ok := query.(string); ok {
if i, err := strconv.Atoi(sql); err != nil { if i, err := strconv.Atoi(sql); err == nil {
query = i query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
@ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) {
} }
func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
stmt.Table = stmt.Schema.Table stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
if stmt.Table == "" {
stmt.Table = stmt.Schema.Table
}
} }
return err return err
} }

View File

@ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) {
} }
func TestCreate(t *testing.T, db *gorm.DB) { func TestCreate(t *testing.T, db *gorm.DB) {
db.AutoMigrate(&User{})
db = db.Debug()
t.Run("Create", func(t *testing.T) { t.Run("Create", func(t *testing.T) {
var user = User{ var user = User{
Name: "create", Name: "create",