Fix mysql tests

This commit is contained in:
Jinzhu 2020-05-29 23:38:03 +08:00
parent d05128be78
commit 6f4602af11
4 changed files with 47 additions and 10 deletions

View File

@ -101,7 +101,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) reflectFieldValue := rel.Field.ReflectValueOf(data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() { switch reflectFieldValue.Kind() {
case reflect.Struct: case reflect.Struct:
rel.Field.Set(data, reflectResults.Index(i).Interface()) rel.Field.Set(data, reflectResults.Index(i).Interface())

View File

@ -78,7 +78,7 @@ func New(writer Writer, config Config) Interface {
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
} }
return logger{ return &logger{
Writer: writer, Writer: writer,
Config: config, Config: config,
infoStr: infoStr, infoStr: infoStr,
@ -98,7 +98,7 @@ type logger struct {
} }
// LogMode log mode // LogMode log mode
func (l logger) LogMode(level LogLevel) Interface { func (l *logger) LogMode(level LogLevel) Interface {
l.LogLevel = level l.LogLevel = level
return l return l
} }

View File

@ -87,6 +87,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
values[idx] = field.ReflectValueOf(elem).Addr().Interface() values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if joinFields[idx][0] != nil { } else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem) relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
} }
} }
@ -110,6 +114,10 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}
values[idx] = field.ReflectValueOf(relValue).Addr().Interface() values[idx] = field.ReflectValueOf(relValue).Addr().Interface()
continue continue
} }

View File

@ -353,9 +353,6 @@ func (field *Field) setupValuerAndSetter() {
if field.FieldType.Kind() == reflect.Ptr { if field.FieldType.Kind() == reflect.Ptr {
field.ReflectValueOf = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
return fieldValue return fieldValue
} }
} else { } else {
@ -406,7 +403,14 @@ func (field *Field) setupValuerAndSetter() {
return setter(value, v) return setter(value, v)
} }
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
return field.Set(value, reflectV.Elem().Interface()) return field.Set(value, reflectV.Elem().Interface())
} else { } else {
@ -607,12 +611,26 @@ func (field *Field) setupValuerAndSetter() {
field.Set = func(value reflect.Value, v interface{}) error { field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case time.Time: case time.Time:
field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v)) fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t)) fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
if v == "" {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(t))
} else { } else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
@ -651,7 +669,14 @@ func (field *Field) setupValuerAndSetter() {
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)