diff --git a/callbacks/create.go b/callbacks/create.go index ebfc8426..c889caf6 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -149,8 +149,11 @@ func CreateWithReturning(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - c := db.Statement.Clauses["ON CONFLICT"] - onConflict, _ := c.Expression.(clause.OnConflict) + var ( + c = db.Statement.Clauses["ON CONFLICT"] + onConflict, _ = c.Expression.(clause.OnConflict) + resetFieldValues = map[int]reflect.Value{} + ) for rows.Next() { BEGIN: @@ -159,7 +162,6 @@ func CreateWithReturning(db *gorm.DB) { break } - resetFields := map[int]*schema.Field{} for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) @@ -179,7 +181,7 @@ func CreateWithReturning(db *gorm.DB) { reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) reflectValue.Elem().Set(fieldValue.Addr()) values[idx] = reflectValue.Interface() - resetFields[idx] = field + resetFieldValues[idx] = fieldValue } } @@ -188,30 +190,31 @@ func CreateWithReturning(db *gorm.DB) { db.AddError(err) } - for idx, field := range resetFields { - if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { - field.ReflectValueOf(reflectValue).Set(v) + for idx, fv := range resetFieldValues { + if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { + fv.Set(v.Elem()) } } } case reflect.Struct: - resetFields := map[int]*schema.Field{} + resetFieldValues := map[int]reflect.Value{} for idx, field := range fields { if field.FieldType.Kind() == reflect.Ptr { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } else { reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) - reflectValue.Elem().Set(field.ReflectValueOf(db.Statement.ReflectValue).Addr()) + fieldValue := field.ReflectValueOf(db.Statement.ReflectValue) + reflectValue.Elem().Set(fieldValue.Addr()) values[idx] = reflectValue.Interface() - resetFields[idx] = field + resetFieldValues[idx] = fieldValue } } if rows.Next() { db.RowsAffected++ db.AddError(rows.Scan(values...)) - for idx, field := range resetFields { - if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { - field.ReflectValueOf(db.Statement.ReflectValue).Set(v) + for idx, fv := range resetFieldValues { + if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { + fv.Set(v.Elem()) } } }