mirror of https://github.com/go-gorm/gorm.git
Add permission check when create associations
This commit is contained in:
parent
345ff7577c
commit
56ca9a87e0
|
@ -0,0 +1,72 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/gorm/utils"
|
||||
)
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field)
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
|
||||
_, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f)
|
||||
|
||||
if isZero && creatable {
|
||||
if f.Kind() == reflect.Ptr {
|
||||
db.Session(&gorm.Session{}).Create(f.Interface())
|
||||
} else {
|
||||
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
|
||||
}
|
||||
} else if !isZero && updatable {
|
||||
if f.Kind() == reflect.Ptr {
|
||||
db.Session(&gorm.Session{}).Save(f.Interface())
|
||||
} else {
|
||||
db.Session(&gorm.Session{}).Save(f.Addr().Interface())
|
||||
}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
if saveRef {
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(f)
|
||||
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) {
|
||||
creatable := field.Creatable
|
||||
updatable := field.Updatable
|
||||
saveRef := true
|
||||
|
||||
if value, ok := db.Get("gorm:association_autocreate"); creatable && ok {
|
||||
creatable = utils.CheckTruth(value)
|
||||
}
|
||||
|
||||
if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok {
|
||||
updatable = utils.CheckTruth(value)
|
||||
}
|
||||
|
||||
if value, ok := db.Get("gorm:association_save_reference"); ok {
|
||||
saveRef = utils.CheckTruth(value)
|
||||
}
|
||||
|
||||
return creatable, updatable, saveRef
|
||||
}
|
|
@ -41,32 +41,6 @@ func BeforeCreate(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
|
||||
func SaveBeforeAssociations(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
if f.Kind() == reflect.Ptr {
|
||||
db.Session(&gorm.Session{}).Create(f.Interface())
|
||||
} else {
|
||||
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(f)
|
||||
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
if config.WithReturning {
|
||||
return CreateWithReturning
|
||||
|
|
|
@ -21,7 +21,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
|
||||
if err := tx.Statement.Parse(value); err != nil && tx.Statement.Schema != nil {
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
|
@ -35,9 +35,6 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||
tx.Statement.AddClause(where)
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.Selects = []string{"*"}
|
||||
}
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm/utils"
|
||||
"github.com/jinzhu/now"
|
||||
)
|
||||
|
||||
|
@ -146,13 +147,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
field.DBName = dbName
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) {
|
||||
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
|
||||
field.PrimaryKey = true
|
||||
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) {
|
||||
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
|
||||
field.PrimaryKey = true
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) {
|
||||
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
|
||||
field.AutoIncrement = true
|
||||
field.HasDefaultValue = true
|
||||
}
|
||||
|
@ -173,11 +174,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
field.Precision, _ = strconv.Atoi(p)
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) {
|
||||
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
||||
field.NotNull = true
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) {
|
||||
if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
|
||||
field.Unique = true
|
||||
}
|
||||
|
||||
|
|
|
@ -37,13 +37,6 @@ func ParseTagSetting(str string, sep string) map[string]string {
|
|||
return settings
|
||||
}
|
||||
|
||||
func checkTruth(val string) bool {
|
||||
if strings.ToLower(val) == "false" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func toColumns(val string) (results []string) {
|
||||
if val != "" {
|
||||
for _, v := range strings.Split(val, ",") {
|
||||
|
|
|
@ -2,8 +2,10 @@ package utils
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
|
@ -23,3 +25,16 @@ func FileWithLineNum() string {
|
|||
func IsChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
|
||||
}
|
||||
|
||||
func CheckTruth(val interface{}) bool {
|
||||
if v, ok := val.(bool); ok {
|
||||
return v
|
||||
}
|
||||
|
||||
if v, ok := val.(string); ok {
|
||||
v = strings.ToLower(v)
|
||||
return v != "false"
|
||||
}
|
||||
|
||||
return !reflect.ValueOf(val).IsZero()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue