Save before associations

This commit is contained in:
Jinzhu 2020-04-15 23:58:26 +08:00
parent b4b249ddcb
commit 345ff7577c
3 changed files with 76 additions and 10 deletions

View File

@ -42,6 +42,29 @@ func BeforeCreate(db *gorm.DB) {
} }
func SaveBeforeAssociations(db *gorm.DB) { func SaveBeforeAssociations(db *gorm.DB) {
if db.Statement.Schema != nil {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Create(f.Interface())
} else {
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
}
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(f)
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
}
}
}
}
}
}
} }
func Create(config *Config) func(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) {

View File

@ -51,20 +51,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
if v == nil { if v == nil {
vars[idx] = "NULL" vars[idx] = "NULL"
} else { } else {
rv := reflect.Indirect(reflect.ValueOf(v)) rv := reflect.ValueOf(v)
if !rv.IsValid() { if !rv.IsValid() {
vars[idx] = "NULL" vars[idx] = "NULL"
return } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
} convertParams(reflect.Indirect(rv).Interface(), idx)
} else {
for _, t := range convertableTypes { for _, t := range convertableTypes {
if rv.Type().ConvertibleTo(t) { if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx) convertParams(rv.Convert(t).Interface(), idx)
return return
}
} }
}
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
}
} }
} }
} }

View File

@ -40,4 +40,45 @@ func TestCreate(t *testing.T, db *gorm.DB) {
AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday")
} }
}) })
TestCreateAssociations(t, db)
}
func TestCreateAssociations(t *testing.T, db *gorm.DB) {
db.Migrator().DropTable(&Company{})
db.Migrator().AutoMigrate(&Company{})
t.Run("Create-BelongsToAssociation", func(t *testing.T) {
var user = User{
Name: "create",
Age: 18,
Birthday: Now(),
Company: Company{Name: "company-belongs-to-association"},
Manager: &User{Name: "manager-belongs-to-association"},
}
if err := db.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
if user.CompanyID == nil {
t.Errorf("Failed to create belongs to association - Company")
} else {
var company Company
db.First(&company, "id = ?", *user.CompanyID)
if company.Name != "company-belongs-to-association" {
t.Errorf("Failed to query saved belongs to association - Company")
}
}
if user.ManagerID == nil {
t.Errorf("Failed to create belongs to association - Manager")
} else {
var manager User
db.First(&manager, "id = ?", *user.ManagerID)
if manager.Name != "manager-belongs-to-association" {
t.Errorf("Failed to query saved belongs to association - Manager")
}
}
})
} }