mirror of https://github.com/go-gorm/gorm.git
Make inesrt into db works
This commit is contained in:
parent
868ae052a1
commit
fa22807e12
|
@ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) {
|
||||||
|
|
||||||
if stmt := db.Statement; stmt != nil {
|
if stmt := db.Statement; stmt != nil {
|
||||||
db.Logger.Trace(curTime, func() (string, int64) {
|
db.Logger.Trace(curTime, func() (string, int64) {
|
||||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected
|
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||||
}, db.Error)
|
}, db.Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
@ -11,8 +10,6 @@ import (
|
||||||
func BeforeCreate(db *gorm.DB) {
|
func BeforeCreate(db *gorm.DB) {
|
||||||
// before save
|
// before save
|
||||||
// before create
|
// before create
|
||||||
|
|
||||||
// assign timestamp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveBeforeAssociations(db *gorm.DB) {
|
func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
|
@ -22,16 +19,29 @@ func Create(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Insert{
|
db.Statement.AddClauseIfNotExists(clause.Insert{
|
||||||
Table: clause.Table{Name: db.Statement.Table},
|
Table: clause.Table{Name: db.Statement.Table},
|
||||||
})
|
})
|
||||||
values, _ := ConvertToCreateValues(db.Statement)
|
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||||
db.Statement.AddClause(values)
|
|
||||||
|
|
||||||
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
|
||||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
fmt.Printf("%+v\n", values)
|
if err == nil {
|
||||||
fmt.Println(err)
|
if db.Statement.Schema != nil {
|
||||||
fmt.Println(result)
|
if insertID, err := result.LastInsertId(); err == nil {
|
||||||
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||||
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
|
||||||
|
insertID--
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
} else {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveAfterAssociations(db *gorm.DB) {
|
func SaveAfterAssociations(db *gorm.DB) {
|
||||||
|
@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertToCreateValues convert to create values
|
// ConvertToCreateValues convert to create values
|
||||||
func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) {
|
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
||||||
switch value := stmt.Dest.(type) {
|
switch value := stmt.Dest.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
return ConvertMapToValues(stmt, value), nil
|
return ConvertMapToValues(stmt, value)
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
return ConvertSliceOfMapToValues(stmt, value), nil
|
return ConvertSliceOfMapToValues(stmt, value)
|
||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
values = clause.Values{}
|
values = clause.Values{}
|
||||||
selectColumns, restricted = SelectAndOmitColumns(stmt)
|
selectColumns, restricted = SelectAndOmitColumns(stmt)
|
||||||
curTime = stmt.DB.NowFunc()
|
curTime = stmt.DB.NowFunc()
|
||||||
isZero = false
|
isZero = false
|
||||||
returnningValues []map[string]interface{}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, db := range stmt.Schema.DBNames {
|
for _, db := range stmt.Schema.DBNames {
|
||||||
|
@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest))
|
switch stmt.ReflectValue.Kind() {
|
||||||
switch reflectValue.Kind() {
|
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
values.Values = make([][]interface{}, reflectValue.Len())
|
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
||||||
defaultValueFieldsHavingValue := map[string][]interface{}{}
|
defaultValueFieldsHavingValue := map[string][]interface{}{}
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||||
rv := reflect.Indirect(reflectValue.Index(i))
|
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||||
for idx, column := range values.Columns {
|
for idx, column := range values.Columns {
|
||||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||||
|
@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
||||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||||
if v, isZero := field.ValueOf(rv); !isZero {
|
if v, isZero := field.ValueOf(rv); !isZero {
|
||||||
if len(defaultValueFieldsHavingValue[db]) == 0 {
|
if len(defaultValueFieldsHavingValue[db]) == 0 {
|
||||||
defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len())
|
defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
|
||||||
}
|
}
|
||||||
defaultValueFieldsHavingValue[db][i] = v
|
defaultValueFieldsHavingValue[db][i] = v
|
||||||
}
|
}
|
||||||
|
@ -113,20 +121,20 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
||||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||||
for idx, column := range values.Columns {
|
for idx, column := range values.Columns {
|
||||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||||
if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero {
|
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
|
||||||
if field.DefaultValueInterface != nil {
|
if field.DefaultValueInterface != nil {
|
||||||
values.Values[0][idx] = field.DefaultValueInterface
|
values.Values[0][idx] = field.DefaultValueInterface
|
||||||
field.Set(reflectValue, field.DefaultValueInterface)
|
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
|
||||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||||
field.Set(reflectValue, curTime)
|
field.Set(stmt.ReflectValue, curTime)
|
||||||
values.Values[0][idx], _ = field.ValueOf(reflectValue)
|
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||||
if v, isZero := field.ValueOf(reflectValue); !isZero {
|
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||||
values.Values[0] = append(values.Values[0], v)
|
values.Values[0] = append(values.Values[0], v)
|
||||||
}
|
}
|
||||||
|
@ -134,6 +142,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return values, returnningValues
|
return values
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
)
|
)
|
||||||
|
@ -15,10 +13,8 @@ func Query(db *gorm.DB) {
|
||||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
fmt.Println(err)
|
db.AddError(err)
|
||||||
fmt.Println(result)
|
|
||||||
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Preload(db *gorm.DB) {
|
func Preload(db *gorm.DB) {
|
||||||
|
|
|
@ -66,9 +66,9 @@ func New(writer Writer, config Config) Interface {
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.Colorful {
|
if config.Colorful {
|
||||||
infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset
|
infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||||
warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset
|
warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||||
errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset
|
errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||||
tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s"
|
tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s"
|
||||||
traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s"
|
traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s"
|
||||||
}
|
}
|
||||||
|
@ -93,29 +93,28 @@ type logger struct {
|
||||||
|
|
||||||
// LogMode log mode
|
// LogMode log mode
|
||||||
func (l logger) LogMode(level LogLevel) Interface {
|
func (l logger) LogMode(level LogLevel) Interface {
|
||||||
config := l.Config
|
l.LogLevel = level
|
||||||
config.LogLevel = level
|
return l
|
||||||
return logger{Writer: l.Writer, Config: config}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Info print info
|
// Info print info
|
||||||
func (l logger) Info(msg string, data ...interface{}) {
|
func (l logger) Info(msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Info {
|
if l.LogLevel >= Info {
|
||||||
l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
|
l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn print warn messages
|
// Warn print warn messages
|
||||||
func (l logger) Warn(msg string, data ...interface{}) {
|
func (l logger) Warn(msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Warn {
|
if l.LogLevel >= Warn {
|
||||||
l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
|
l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error print error messages
|
// Error print error messages
|
||||||
func (l logger) Error(msg string, data ...interface{}) {
|
func (l logger) Error(msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Error {
|
if l.LogLevel >= Error {
|
||||||
l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
|
l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,7 +122,11 @@ func (l logger) Error(msg string, data ...interface{}) {
|
||||||
func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) {
|
func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) {
|
||||||
if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold {
|
if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold {
|
||||||
sql, rows := fc()
|
sql, rows := fc()
|
||||||
l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
fileline := utils.FileWithLineNum()
|
||||||
|
if err != nil {
|
||||||
|
fileline += " " + err.Error()
|
||||||
|
}
|
||||||
|
l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||||
} else if l.LogLevel >= Info {
|
} else if l.LogLevel >= Info {
|
||||||
sql, rows := fc()
|
sql, rows := fc()
|
||||||
l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||||
|
|
|
@ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
||||||
case bool:
|
case bool:
|
||||||
vars[idx] = fmt.Sprint(v)
|
vars[idx] = fmt.Sprint(v)
|
||||||
case time.Time:
|
case time.Time:
|
||||||
|
if v.IsZero() {
|
||||||
|
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
|
||||||
|
} else {
|
||||||
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
|
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
|
||||||
|
}
|
||||||
case []byte:
|
case []byte:
|
||||||
if isPrintable(v) {
|
if isPrintable(v) {
|
||||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||||
|
@ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
||||||
vars[idx] = "NULL"
|
vars[idx] = "NULL"
|
||||||
} else {
|
} else {
|
||||||
rv := reflect.Indirect(reflect.ValueOf(v))
|
rv := reflect.Indirect(reflect.ValueOf(v))
|
||||||
|
if !rv.IsValid() {
|
||||||
|
vars[idx] = "NULL"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for _, t := range convertableTypes {
|
for _, t := range convertableTypes {
|
||||||
if rv.Type().ConvertibleTo(t) {
|
if rv.Type().ConvertibleTo(t) {
|
||||||
convertParams(rv.Convert(t).Interface(), idx)
|
convertParams(rv.Convert(t).Interface(), idx)
|
||||||
|
|
|
@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
var err error
|
var err error
|
||||||
field.Creatable = false
|
field.Creatable = false
|
||||||
field.Updatable = false
|
field.Updatable = false
|
||||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
}
|
}
|
||||||
for _, ef := range field.EmbeddedSchema.Fields {
|
for _, ef := range field.EmbeddedSchema.Fields {
|
||||||
|
|
|
@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
}
|
}
|
||||||
relation.JoinTable.Name = many2many
|
relation.JoinTable.Name = many2many
|
||||||
|
|
|
@ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
// get data type from dialector
|
// get data type from dialector
|
||||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
|
||||||
modelType := reflect.ValueOf(dest).Type()
|
reflectValue := reflect.ValueOf(dest)
|
||||||
|
modelType := reflectValue.Type()
|
||||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
|
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
if modelType.PkgPath() == "" {
|
if modelType.PkgPath() == "" {
|
||||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := cacheStore.Load(modelType); ok {
|
if v, ok := cacheStore.Load(modelType); ok {
|
||||||
return v.(*Schema), nil
|
return v.(*Schema), reflectValue, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
schema := &Schema{
|
schema := &Schema{
|
||||||
|
@ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if field.DataType == "" && field.Creatable {
|
if field.DataType == "" && field.Creatable {
|
||||||
if schema.parseRelation(field); schema.err != nil {
|
if schema.parseRelation(field); schema.err != nil {
|
||||||
return schema, schema.err
|
return schema, reflectValue, schema.err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema, schema.err
|
return schema, reflectValue, schema.err
|
||||||
}
|
}
|
||||||
|
|
12
statement.go
12
statement.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -32,7 +33,7 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
|
||||||
func (inst *Instance) AddError(err error) {
|
func (inst *Instance) AddError(err error) {
|
||||||
if inst.Error == nil {
|
if inst.Error == nil {
|
||||||
inst.Error = err
|
inst.Error = err
|
||||||
} else {
|
} else if err != nil {
|
||||||
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
|
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -42,6 +43,7 @@ type Statement struct {
|
||||||
Table string
|
Table string
|
||||||
Model interface{}
|
Model interface{}
|
||||||
Dest interface{}
|
Dest interface{}
|
||||||
|
ReflectValue reflect.Value
|
||||||
Clauses map[string]clause.Clause
|
Clauses map[string]clause.Clause
|
||||||
Selects []string // selected columns
|
Selects []string // selected columns
|
||||||
Omits []string // omit columns
|
Omits []string // omit columns
|
||||||
|
@ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
||||||
// BuildCondtion build condition
|
// BuildCondtion build condition
|
||||||
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
|
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
|
||||||
if sql, ok := query.(string); ok {
|
if sql, ok := query.(string); ok {
|
||||||
if i, err := strconv.Atoi(sql); err != nil {
|
if i, err := strconv.Atoi(sql); err == nil {
|
||||||
query = i
|
query = i
|
||||||
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
||||||
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
|
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
|
||||||
|
@ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||||
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
|
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||||
|
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
|
||||||
|
|
||||||
|
if stmt.Table == "" {
|
||||||
stmt.Table = stmt.Schema.Table
|
stmt.Table = stmt.Schema.Table
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreate(t *testing.T, db *gorm.DB) {
|
func TestCreate(t *testing.T, db *gorm.DB) {
|
||||||
|
db.AutoMigrate(&User{})
|
||||||
|
db = db.Debug()
|
||||||
|
|
||||||
t.Run("Create", func(t *testing.T) {
|
t.Run("Create", func(t *testing.T) {
|
||||||
var user = User{
|
var user = User{
|
||||||
Name: "create",
|
Name: "create",
|
||||||
|
|
Loading…
Reference in New Issue