From 6f4602af11c17d79610386df1112b2bf13fe509b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 29 May 2020 23:38:03 +0800 Subject: [PATCH] Fix mysql tests --- callbacks/preload.go | 6 +++++- logger/logger.go | 4 ++-- scan.go | 8 ++++++++ schema/field.go | 39 ++++++++++++++++++++++++++++++++------- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index f48777c2..cfea4f94 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -101,7 +101,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { - reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: rel.Field.Set(data, reflectResults.Index(i).Interface()) diff --git a/logger/logger.go b/logger/logger.go index 24cee821..7121b4fb 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -78,7 +78,7 @@ func New(writer Writer, config Config) Interface { traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" } - return logger{ + return &logger{ Writer: writer, Config: config, infoStr: infoStr, @@ -98,7 +98,7 @@ type logger struct { } // LogMode log mode -func (l logger) LogMode(level LogLevel) Interface { +func (l *logger) LogMode(level LogLevel) Interface { l.LogLevel = level return l } diff --git a/scan.go b/scan.go index d2169f87..c223f6eb 100644 --- a/scan.go +++ b/scan.go @@ -87,6 +87,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { values[idx] = field.ReflectValueOf(elem).Addr().Interface() } else if joinFields[idx][0] != nil { relValue := joinFields[idx][0].ReflectValueOf(elem) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + relValue.Set(reflect.New(relValue.Type().Elem())) + } + values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() } } @@ -110,6 +114,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + relValue.Set(reflect.New(relValue.Type().Elem())) + } + values[idx] = field.ReflectValueOf(relValue).Addr().Interface() continue } diff --git a/schema/field.go b/schema/field.go index 75ff71f6..f4fbad95 100644 --- a/schema/field.go +++ b/schema/field.go @@ -353,9 +353,6 @@ func (field *Field) setupValuerAndSetter() { if field.FieldType.Kind() == reflect.Ptr { field.ReflectValueOf = func(value reflect.Value) reflect.Value { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.FieldType.Elem())) - } return fieldValue } } else { @@ -406,7 +403,14 @@ func (field *Field) setupValuerAndSetter() { return setter(value, v) } } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == nil { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if reflectV.Kind() == reflect.Ptr { return field.Set(value, reflectV.Elem().Interface()) } else { @@ -607,12 +611,26 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v)) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == nil { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(v)) case *time.Time: field.ReflectValueOf(value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t)) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == "" { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } @@ -651,7 +669,14 @@ func (field *Field) setupValuerAndSetter() { if reflectV.Type().ConvertibleTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == nil { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)