mirror of https://github.com/go-gorm/gorm.git
Make inesrt into db works
This commit is contained in:
parent
868ae052a1
commit
fa22807e12
|
@ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) {
|
|||
|
||||
if stmt := db.Statement; stmt != nil {
|
||||
db.Logger.Trace(curTime, func() (string, int64) {
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
@ -11,8 +10,6 @@ import (
|
|||
func BeforeCreate(db *gorm.DB) {
|
||||
// before save
|
||||
// before create
|
||||
|
||||
// assign timestamp
|
||||
}
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
|
@ -22,16 +19,29 @@ func Create(db *gorm.DB) {
|
|||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||
Table: clause.Table{Name: db.Statement.Table},
|
||||
})
|
||||
values, _ := ConvertToCreateValues(db.Statement)
|
||||
db.Statement.AddClause(values)
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
fmt.Printf("%+v\n", values)
|
||||
fmt.Println(err)
|
||||
fmt.Println(result)
|
||||
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil {
|
||||
if insertID, err := result.LastInsertId(); err == nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||
insertID--
|
||||
}
|
||||
case reflect.Struct:
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
}
|
||||
}
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(db *gorm.DB) {
|
||||
|
@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) {
|
|||
}
|
||||
|
||||
// ConvertToCreateValues convert to create values
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) {
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
return ConvertMapToValues(stmt, value), nil
|
||||
return ConvertMapToValues(stmt, value)
|
||||
case []map[string]interface{}:
|
||||
return ConvertSliceOfMapToValues(stmt, value), nil
|
||||
return ConvertSliceOfMapToValues(stmt, value)
|
||||
default:
|
||||
var (
|
||||
values = clause.Values{}
|
||||
selectColumns, restricted = SelectAndOmitColumns(stmt)
|
||||
curTime = stmt.DB.NowFunc()
|
||||
isZero = false
|
||||
returnningValues []map[string]interface{}
|
||||
)
|
||||
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
|
@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest))
|
||||
switch reflectValue.Kind() {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
values.Values = make([][]interface{}, reflectValue.Len())
|
||||
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
||||
defaultValueFieldsHavingValue := map[string][]interface{}{}
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
rv := reflect.Indirect(reflectValue.Index(i))
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
|
@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||
if v, isZero := field.ValueOf(rv); !isZero {
|
||||
if len(defaultValueFieldsHavingValue[db]) == 0 {
|
||||
defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len())
|
||||
defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
|
||||
}
|
||||
defaultValueFieldsHavingValue[db][i] = v
|
||||
}
|
||||
|
@ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero {
|
||||
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[0][idx] = field.DefaultValueInterface
|
||||
field.Set(reflectValue, field.DefaultValueInterface)
|
||||
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
field.Set(reflectValue, curTime)
|
||||
values.Values[0][idx], _ = field.ValueOf(reflectValue)
|
||||
field.Set(stmt.ReflectValue, curTime)
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
||||
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
values.Values[0] = append(values.Values[0], v)
|
||||
}
|
||||
|
@ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
|||
}
|
||||
}
|
||||
|
||||
return values, returnningValues
|
||||
return values
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
)
|
||||
|
@ -15,10 +13,8 @@ func Query(db *gorm.DB) {
|
|||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
fmt.Println(err)
|
||||
fmt.Println(result)
|
||||
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
||||
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
db.AddError(err)
|
||||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
|
|
|
@ -66,9 +66,9 @@ func New(writer Writer, config Config) Interface {
|
|||
)
|
||||
|
||||
if config.Colorful {
|
||||
infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset
|
||||
warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset
|
||||
errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset
|
||||
infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||
warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||
errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||
tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s"
|
||||
traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s"
|
||||
}
|
||||
|
@ -93,29 +93,28 @@ type logger struct {
|
|||
|
||||
// LogMode log mode
|
||||
func (l logger) LogMode(level LogLevel) Interface {
|
||||
config := l.Config
|
||||
config.LogLevel = level
|
||||
return logger{Writer: l.Writer, Config: config}
|
||||
l.LogLevel = level
|
||||
return l
|
||||
}
|
||||
|
||||
// Info print info
|
||||
func (l logger) Info(msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Info {
|
||||
l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
|
||||
l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (l logger) Warn(msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Warn {
|
||||
l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
|
||||
l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (l logger) Error(msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Error {
|
||||
l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
|
||||
l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -123,7 +122,11 @@ func (l logger) Error(msg string, data ...interface{}) {
|
|||
func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) {
|
||||
if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold {
|
||||
sql, rows := fc()
|
||||
l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
fileline := utils.FileWithLineNum()
|
||||
if err != nil {
|
||||
fileline += " " + err.Error()
|
||||
}
|
||||
l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
} else if l.LogLevel >= Info {
|
||||
sql, rows := fc()
|
||||
l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
|
|
|
@ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
|||
case bool:
|
||||
vars[idx] = fmt.Sprint(v)
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
|
||||
}
|
||||
case []byte:
|
||||
if isPrintable(v) {
|
||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||
|
@ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
|||
vars[idx] = "NULL"
|
||||
} else {
|
||||
rv := reflect.Indirect(reflect.ValueOf(v))
|
||||
if !rv.IsValid() {
|
||||
vars[idx] = "NULL"
|
||||
return
|
||||
}
|
||||
|
||||
for _, t := range convertableTypes {
|
||||
if rv.Type().ConvertibleTo(t) {
|
||||
convertParams(rv.Convert(t).Interface(), idx)
|
||||
|
|
|
@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
var err error
|
||||
field.Creatable = false
|
||||
field.Updatable = false
|
||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
||||
if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
for _, ef := range field.EmbeddedSchema.Fields {
|
||||
|
|
|
@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
|
|||
}
|
||||
)
|
||||
|
||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||
if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
}
|
||||
}
|
||||
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
|
|
|
@ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field {
|
|||
}
|
||||
|
||||
// get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
|
||||
reflectValue := reflect.ValueOf(dest)
|
||||
modelType := reflectValue.Type()
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
if modelType.PkgPath() == "" {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
if v, ok := cacheStore.Load(modelType); ok {
|
||||
return v.(*Schema), nil
|
||||
return v.(*Schema), reflectValue, nil
|
||||
}
|
||||
|
||||
schema := &Schema{
|
||||
|
@ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && field.Creatable {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, schema.err
|
||||
return schema, reflectValue, schema.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schema, schema.err
|
||||
return schema, reflectValue, schema.err
|
||||
}
|
||||
|
|
12
statement.go
12
statement.go
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -32,7 +33,7 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
|
|||
func (inst *Instance) AddError(err error) {
|
||||
if inst.Error == nil {
|
||||
inst.Error = err
|
||||
} else {
|
||||
} else if err != nil {
|
||||
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
|
||||
}
|
||||
}
|
||||
|
@ -42,6 +43,7 @@ type Statement struct {
|
|||
Table string
|
||||
Model interface{}
|
||||
Dest interface{}
|
||||
ReflectValue reflect.Value
|
||||
Clauses map[string]clause.Clause
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
|
@ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
|||
// BuildCondtion build condition
|
||||
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
|
||||
if sql, ok := query.(string); ok {
|
||||
if i, err := strconv.Atoi(sql); err != nil {
|
||||
if i, err := strconv.Atoi(sql); err == nil {
|
||||
query = i
|
||||
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
||||
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
|
||||
|
@ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) {
|
|||
}
|
||||
|
||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
|
||||
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
|
||||
|
||||
if stmt.Table == "" {
|
||||
stmt.Table = stmt.Schema.Table
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) {
|
|||
}
|
||||
|
||||
func TestCreate(t *testing.T, db *gorm.DB) {
|
||||
db.AutoMigrate(&User{})
|
||||
db = db.Debug()
|
||||
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
var user = User{
|
||||
Name: "create",
|
||||
|
|
Loading…
Reference in New Issue