diff --git a/callbacks/create.go b/callbacks/create.go index a2944319..ebfc8426 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -159,6 +159,7 @@ func CreateWithReturning(db *gorm.DB) { break } + resetFields := map[int]*schema.Field{} for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) @@ -172,22 +173,47 @@ func CreateWithReturning(db *gorm.DB) { goto BEGIN } - values[idx] = fieldValue.Addr().Interface() + if field.FieldType.Kind() == reflect.Ptr { + values[idx] = fieldValue.Addr().Interface() + } else { + reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) + reflectValue.Elem().Set(fieldValue.Addr()) + values[idx] = reflectValue.Interface() + resetFields[idx] = field + } } db.RowsAffected++ if err := rows.Scan(values...); err != nil { db.AddError(err) } + + for idx, field := range resetFields { + if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { + field.ReflectValueOf(reflectValue).Set(v) + } + } } case reflect.Struct: + resetFields := map[int]*schema.Field{} for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + if field.FieldType.Kind() == reflect.Ptr { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } else { + reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) + reflectValue.Elem().Set(field.ReflectValueOf(db.Statement.ReflectValue).Addr()) + values[idx] = reflectValue.Interface() + resetFields[idx] = field + } } - if rows.Next() { db.RowsAffected++ db.AddError(rows.Scan(values...)) + for idx, field := range resetFields { + if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { + field.ReflectValueOf(db.Statement.ReflectValue).Set(v) + } + } } } } else { diff --git a/tests/gorm_test.go b/tests/gorm_test.go new file mode 100644 index 00000000..39741439 --- /dev/null +++ b/tests/gorm_test.go @@ -0,0 +1,98 @@ +package tests_test + +import ( + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "testing" +) + +func TestReturningWithNullToZeroValues(t *testing.T) { + dialect := DB.Dialector.Name() + switch dialect { + case "mysql", "sqlserver": + // these dialects do not support the "returning" clause + return + default: + // This user struct will leverage the existing users table, but override + // the Name field to default to null. + type user struct { + gorm.Model + Name string `gorm:"default:null"` + } + u1 := user{} + c := DB.Callback().Create().Get("gorm:create") + t.Cleanup(func() { + DB.Callback().Create().Replace("gorm:create", c) + }) + DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true})) + + if results := DB.Create(&u1); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } + + got := user{} + results := DB.First(&got, "id = ?", u1.ID) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("first expects: %v, got %v", u1, got) + } + + results = DB.Select("id, name").Find(&got) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1, got) + } + + u1.Name = "jinzhu" + if results := DB.Save(&u1); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + u1 = user{} // important to reinitialize this before creating it again + u2 := user{} + db := DB.Session(&gorm.Session{CreateBatchSize: 10}) + + if results := db.Create([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } else if u2.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u2.ID) + } + + var gotUsers []user + results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected) + } else if gotUsers[0].ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID) + } else if gotUsers[1].ID != u2.ID { + t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID) + } + + u1.Name = "Jinzhu" + u2.Name = "Zhang" + if results := DB.Save([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + } +}