diff --git a/callbacks/create.go b/callbacks/create.go index 97a2832c..e21e04c2 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -42,6 +42,29 @@ func BeforeCreate(db *gorm.DB) { } func SaveBeforeAssociations(db *gorm.DB) { + if db.Statement.Schema != nil { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + } + + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(f) + ref.ForeignKey.Set(db.Statement.ReflectValue, fv) + } + } + } + } + } + } } func Create(config *Config) func(db *gorm.DB) { diff --git a/logger/sql.go b/logger/sql.go index 41c514fd..9c0f54d7 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -51,20 +51,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v if v == nil { vars[idx] = "NULL" } else { - rv := reflect.Indirect(reflect.ValueOf(v)) + rv := reflect.ValueOf(v) + if !rv.IsValid() { vars[idx] = "NULL" - return - } - - for _, t := range convertableTypes { - if rv.Type().ConvertibleTo(t) { - convertParams(rv.Convert(t).Interface(), idx) - return + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) + } else { + for _, t := range convertableTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } } - } - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + } } } } diff --git a/tests/create.go b/tests/create.go index dfd73bd3..74a010dc 100644 --- a/tests/create.go +++ b/tests/create.go @@ -40,4 +40,45 @@ func TestCreate(t *testing.T, db *gorm.DB) { AssertObjEqual(t, newUser, user, "Name", "Age", "Birthday") } }) + + TestCreateAssociations(t, db) +} + +func TestCreateAssociations(t *testing.T, db *gorm.DB) { + db.Migrator().DropTable(&Company{}) + db.Migrator().AutoMigrate(&Company{}) + + t.Run("Create-BelongsToAssociation", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + Company: Company{Name: "company-belongs-to-association"}, + Manager: &User{Name: "manager-belongs-to-association"}, + } + + if err := db.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if user.CompanyID == nil { + t.Errorf("Failed to create belongs to association - Company") + } else { + var company Company + db.First(&company, "id = ?", *user.CompanyID) + if company.Name != "company-belongs-to-association" { + t.Errorf("Failed to query saved belongs to association - Company") + } + } + + if user.ManagerID == nil { + t.Errorf("Failed to create belongs to association - Manager") + } else { + var manager User + db.First(&manager, "id = ?", *user.ManagerID) + if manager.Name != "manager-belongs-to-association" { + t.Errorf("Failed to query saved belongs to association - Manager") + } + } + }) }