forked from mirror/gorm
Save before associations
This commit is contained in:
parent
b4b249ddcb
commit
345ff7577c
|
@ -42,6 +42,29 @@ func BeforeCreate(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
if f.Kind() == reflect.Ptr {
|
||||
db.Session(&gorm.Session{}).Create(f.Interface())
|
||||
} else {
|
||||
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(f)
|
||||
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
|
|
|
@ -51,12 +51,13 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
|||
if v == nil {
|
||||
vars[idx] = "NULL"
|
||||
} else {
|
||||
rv := reflect.Indirect(reflect.ValueOf(v))
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
if !rv.IsValid() {
|
||||
vars[idx] = "NULL"
|
||||
return
|
||||
}
|
||||
|
||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||
} else {
|
||||
for _, t := range convertableTypes {
|
||||
if rv.Type().ConvertibleTo(t) {
|
||||
convertParams(rv.Convert(t).Interface(), idx)
|
||||
|
@ -68,6 +69,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for idx, v := range vars {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
|
|
|
@ -40,4 +40,45 @@ func TestCreate(t *testing.T, db *gorm.DB) {
|
|||
AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday")
|
||||
}
|
||||
})
|
||||
|
||||
TestCreateAssociations(t, db)
|
||||
}
|
||||
|
||||
func TestCreateAssociations(t *testing.T, db *gorm.DB) {
|
||||
db.Migrator().DropTable(&Company{})
|
||||
db.Migrator().AutoMigrate(&Company{})
|
||||
|
||||
t.Run("Create-BelongsToAssociation", func(t *testing.T) {
|
||||
var user = User{
|
||||
Name: "create",
|
||||
Age: 18,
|
||||
Birthday: Now(),
|
||||
Company: Company{Name: "company-belongs-to-association"},
|
||||
Manager: &User{Name: "manager-belongs-to-association"},
|
||||
}
|
||||
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
if user.CompanyID == nil {
|
||||
t.Errorf("Failed to create belongs to association - Company")
|
||||
} else {
|
||||
var company Company
|
||||
db.First(&company, "id = ?", *user.CompanyID)
|
||||
if company.Name != "company-belongs-to-association" {
|
||||
t.Errorf("Failed to query saved belongs to association - Company")
|
||||
}
|
||||
}
|
||||
|
||||
if user.ManagerID == nil {
|
||||
t.Errorf("Failed to create belongs to association - Manager")
|
||||
} else {
|
||||
var manager User
|
||||
db.First(&manager, "id = ?", *user.ManagerID)
|
||||
if manager.Name != "manager-belongs-to-association" {
|
||||
t.Errorf("Failed to query saved belongs to association - Manager")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue