diff --git a/callbacks/associations.go b/callbacks/associations.go new file mode 100644 index 00000000..1df0103a --- /dev/null +++ b/callbacks/associations.go @@ -0,0 +1,72 @@ +package callbacks + +import ( + "reflect" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/utils" +) + +func SaveBeforeAssociations(db *gorm.DB) { + if db.Statement.Schema != nil { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field) + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice: + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + + _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f) + + if isZero && creatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Create(f.Interface()) + } else { + db.Session(&gorm.Session{}).Create(f.Addr().Interface()) + } + } else if !isZero && updatable { + if f.Kind() == reflect.Ptr { + db.Session(&gorm.Session{}).Save(f.Interface()) + } else { + db.Session(&gorm.Session{}).Save(f.Addr().Interface()) + } + } else { + continue + } + + if saveRef { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(f) + ref.ForeignKey.Set(db.Statement.ReflectValue, fv) + } + } + } + } + } + } + } +} + +func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) { + creatable := field.Creatable + updatable := field.Updatable + saveRef := true + + if value, ok := db.Get("gorm:association_autocreate"); creatable && ok { + creatable = utils.CheckTruth(value) + } + + if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok { + updatable = utils.CheckTruth(value) + } + + if value, ok := db.Get("gorm:association_save_reference"); ok { + saveRef = utils.CheckTruth(value) + } + + return creatable, updatable, saveRef +} diff --git a/callbacks/create.go b/callbacks/create.go index e21e04c2..829c9c4c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -41,32 +41,6 @@ func BeforeCreate(db *gorm.DB) { } } -func SaveBeforeAssociations(db *gorm.DB) { - if db.Statement.Schema != nil { - for _, rel := range db.Statement.Schema.Relationships.BelongsTo { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: - case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if f.Kind() == reflect.Ptr { - db.Session(&gorm.Session{}).Create(f.Interface()) - } else { - db.Session(&gorm.Session{}).Create(f.Addr().Interface()) - } - - for _, ref := range rel.References { - if !ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(f) - ref.ForeignKey.Set(db.Statement.ReflectValue, fv) - } - } - } - } - } - } -} - func Create(config *Config) func(db *gorm.DB) { if config.WithReturning { return CreateWithReturning diff --git a/finisher_api.go b/finisher_api.go index 62c1af30..9e29e327 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -21,7 +21,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value - if err := tx.Statement.Parse(value); err != nil && tx.Statement.Schema != nil { + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} reflectValue := reflect.ValueOf(value) for idx, pf := range tx.Statement.Schema.PrimaryFields { @@ -35,9 +35,6 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.AddClause(where) } - if len(tx.Statement.Selects) == 0 { - tx.Statement.Selects = []string{"*"} - } tx.callbacks.Update().Execute(tx) return } diff --git a/schema/field.go b/schema/field.go index ec419383..7b37733b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/jinzhu/gorm/utils" "github.com/jinzhu/now" ) @@ -146,13 +147,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBName = dbName } - if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { field.PrimaryKey = true - } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { + } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { field.PrimaryKey = true } - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { field.AutoIncrement = true field.HasDefaultValue = true } @@ -173,11 +174,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Precision, _ = strconv.Atoi(p) } - if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true } - if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { field.Unique = true } diff --git a/schema/utils.go b/schema/utils.go index d7572d3d..7be78bc5 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -37,13 +37,6 @@ func ParseTagSetting(str string, sep string) map[string]string { return settings } -func checkTruth(val string) bool { - if strings.ToLower(val) == "false" { - return false - } - return true -} - func toColumns(val string) (results []string) { if val != "" { for _, v := range strings.Split(val, ",") { diff --git a/utils/utils.go b/utils/utils.go index 8521d09b..8dd500a5 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,8 +2,10 @@ package utils import ( "fmt" + "reflect" "regexp" "runtime" + "strings" "unicode" ) @@ -23,3 +25,16 @@ func FileWithLineNum() string { func IsChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) } + +func CheckTruth(val interface{}) bool { + if v, ok := val.(bool); ok { + return v + } + + if v, ok := val.(string); ok { + v = strings.ToLower(v) + return v != "false" + } + + return !reflect.ValueOf(val).IsZero() +}