diff --git a/callbacks.go b/callbacks.go index 573d7a8e..3aed2d37 100644 --- a/callbacks.go +++ b/callbacks.go @@ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) { if stmt := db.Statement; stmt != nil { db.Logger.Trace(curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) } } diff --git a/callbacks/create.go b/callbacks/create.go index 95afc854..3866ddb0 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,7 +1,6 @@ package callbacks import ( - "fmt" "reflect" "github.com/jinzhu/gorm" @@ -11,8 +10,6 @@ import ( func BeforeCreate(db *gorm.DB) { // before save // before create - - // assign timestamp } func SaveBeforeAssociations(db *gorm.DB) { @@ -22,16 +19,29 @@ func Create(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Insert{ Table: clause.Table{Name: db.Statement.Table}, }) - values, _ := ConvertToCreateValues(db.Statement) - db.Statement.AddClause(values) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) - fmt.Printf("%+v\n", values) - fmt.Println(err) - fmt.Println(result) - fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) + if err == nil { + if db.Statement.Schema != nil { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } + } + } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } func SaveAfterAssociations(db *gorm.DB) { @@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) { } // ConvertToCreateValues convert to create values -func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) { +func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { switch value := stmt.Dest.(type) { case map[string]interface{}: - return ConvertMapToValues(stmt, value), nil + return ConvertMapToValues(stmt, value) case []map[string]interface{}: - return ConvertSliceOfMapToValues(stmt, value), nil + return ConvertSliceOfMapToValues(stmt, value) default: var ( values = clause.Values{} selectColumns, restricted = SelectAndOmitColumns(stmt) curTime = stmt.DB.NowFunc() isZero = false - returnningValues []map[string]interface{} ) for _, db := range stmt.Schema.DBNames { @@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in } } - reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest)) - switch reflectValue.Kind() { + switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - values.Values = make([][]interface{}, reflectValue.Len()) + values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[string][]interface{}{} - for i := 0; i < reflectValue.Len(); i++ { - rv := reflect.Indirect(reflectValue.Index(i)) + for i := 0; i < stmt.ReflectValue.Len(); i++ { + rv := reflect.Indirect(stmt.ReflectValue.Index(i)) values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] @@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(rv); !isZero { if len(defaultValueFieldsHavingValue[db]) == 0 { - defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len()) + defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len()) } defaultValueFieldsHavingValue[db][i] = v } @@ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(reflectValue, field.DefaultValueInterface) + field.Set(stmt.ReflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(reflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(reflectValue) + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } } } for db, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(reflectValue); !isZero { + if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: db}) values.Values[0] = append(values.Values[0], v) } @@ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in } } - return values, returnningValues + return values } } diff --git a/callbacks/query.go b/callbacks/query.go index a4ed3adb..195709fe 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,8 +1,6 @@ package callbacks import ( - "fmt" - "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" ) @@ -15,10 +13,8 @@ func Query(db *gorm.DB) { db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) - fmt.Println(err) - fmt.Println(result) - fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) + rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.AddError(err) } func Preload(db *gorm.DB) { diff --git a/logger/logger.go b/logger/logger.go index 5656a86f..568ddd57 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -66,9 +66,9 @@ func New(writer Writer, config Config) Interface { ) if config.Colorful { - infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset - warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset - errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset + infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s" traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s" } @@ -93,29 +93,28 @@ type logger struct { // LogMode log mode func (l logger) LogMode(level LogLevel) Interface { - config := l.Config - config.LogLevel = level - return logger{Writer: l.Writer, Config: config} + l.LogLevel = level + return l } // Info print info func (l logger) Info(msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) + l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages func (l logger) Warn(msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) + l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages func (l logger) Error(msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)) + l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } @@ -123,7 +122,11 @@ func (l logger) Error(msg string, data ...interface{}) { func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold { sql, rows := fc() - l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + fileline := utils.FileWithLineNum() + if err != nil { + fileline += " " + err.Error() + } + l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql) } else if l.LogLevel >= Info { sql, rows := fc() l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) diff --git a/logger/sql.go b/logger/sql.go index f63dc160..eec72d47 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v case bool: vars[idx] = fmt.Sprint(v) case time.Time: - vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + if v.IsZero() { + vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + } else { + vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + } case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper @@ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v vars[idx] = "NULL" } else { rv := reflect.Indirect(reflect.ValueOf(v)) + if !rv.IsValid() { + vars[idx] = "NULL" + return + } + for _, t := range convertableTypes { if rv.Type().ConvertibleTo(t) { convertParams(rv.Convert(t).Interface(), idx) diff --git a/schema/field.go b/schema/field.go index ea4e6a40..f640ec3b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var err error field.Creatable = false field.Updatable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { diff --git a/schema/relationship.go b/schema/relationship.go index 4ffea8b3..3b9d692a 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) { } ) - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = err return } @@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many diff --git a/schema/schema.go b/schema/schema.go index acf6ff52..c3ac2bd9 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field { } // get data type from dialector -func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { - modelType := reflect.ValueOf(dest).Type() +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) { + reflectValue := reflect.ValueOf(dest) + modelType := reflectValue.Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema), nil + return v.(*Schema), reflectValue, nil } schema := &Schema{ @@ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) for _, field := range schema.Fields { if field.DataType == "" && field.Creatable { if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + return schema, reflectValue, schema.err } } } - return schema, schema.err + return schema, reflectValue, schema.err } diff --git a/statement.go b/statement.go index d486a1c7..91f45b2b 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "reflect" "strconv" "strings" "sync" @@ -32,22 +33,23 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { func (inst *Instance) AddError(err error) { if inst.Error == nil { inst.Error = err - } else { + } else if err != nil { inst.Error = fmt.Errorf("%v; %w", inst.Error, err) } } // Statement statement type Statement struct { - Table string - Model interface{} - Dest interface{} - Clauses map[string]clause.Clause - Selects []string // selected columns - Omits []string // omit columns - Settings sync.Map - DB *DB - Schema *schema.Schema + Table string + Model interface{} + Dest interface{} + ReflectValue reflect.Value + Clauses map[string]clause.Clause + Selects []string // selected columns + Omits []string // omit columns + Settings sync.Map + DB *DB + Schema *schema.Schema // SQL Builder SQL strings.Builder @@ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { // BuildCondtion build condition func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { if sql, ok := query.(string); ok { - if i, err := strconv.Atoi(sql); err != nil { + if i, err := strconv.Atoi(sql); err == nil { query = i } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} @@ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { - stmt.Table = stmt.Schema.Table + if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue) + + if stmt.Table == "" { + stmt.Table = stmt.Schema.Table + } } return err } diff --git a/tests/tests.go b/tests/tests.go index b3246a79..53700710 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) { } func TestCreate(t *testing.T, db *gorm.DB) { + db.AutoMigrate(&User{}) + db = db.Debug() + t.Run("Create", func(t *testing.T) { var user = User{ Name: "create",