mirror of https://github.com/go-gorm/gorm.git
Add permission check when create associations
This commit is contained in:
parent
345ff7577c
commit
56ca9a87e0
|
@ -0,0 +1,72 @@
|
||||||
|
package callbacks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
|
"github.com/jinzhu/gorm/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
|
if db.Statement.Schema != nil {
|
||||||
|
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||||
|
creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field)
|
||||||
|
|
||||||
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||||
|
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||||
|
|
||||||
|
_, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f)
|
||||||
|
|
||||||
|
if isZero && creatable {
|
||||||
|
if f.Kind() == reflect.Ptr {
|
||||||
|
db.Session(&gorm.Session{}).Create(f.Interface())
|
||||||
|
} else {
|
||||||
|
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
|
||||||
|
}
|
||||||
|
} else if !isZero && updatable {
|
||||||
|
if f.Kind() == reflect.Ptr {
|
||||||
|
db.Session(&gorm.Session{}).Save(f.Interface())
|
||||||
|
} else {
|
||||||
|
db.Session(&gorm.Session{}).Save(f.Addr().Interface())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if saveRef {
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
if !ref.OwnPrimaryKey {
|
||||||
|
fv, _ := ref.PrimaryKey.ValueOf(f)
|
||||||
|
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) {
|
||||||
|
creatable := field.Creatable
|
||||||
|
updatable := field.Updatable
|
||||||
|
saveRef := true
|
||||||
|
|
||||||
|
if value, ok := db.Get("gorm:association_autocreate"); creatable && ok {
|
||||||
|
creatable = utils.CheckTruth(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok {
|
||||||
|
updatable = utils.CheckTruth(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := db.Get("gorm:association_save_reference"); ok {
|
||||||
|
saveRef = utils.CheckTruth(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return creatable, updatable, saveRef
|
||||||
|
}
|
|
@ -41,32 +41,6 @@ func BeforeCreate(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveBeforeAssociations(db *gorm.DB) {
|
|
||||||
if db.Statement.Schema != nil {
|
|
||||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
|
||||||
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
|
||||||
if f.Kind() == reflect.Ptr {
|
|
||||||
db.Session(&gorm.Session{}).Create(f.Interface())
|
|
||||||
} else {
|
|
||||||
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
|
||||||
if !ref.OwnPrimaryKey {
|
|
||||||
fv, _ := ref.PrimaryKey.ValueOf(f)
|
|
||||||
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Create(config *Config) func(db *gorm.DB) {
|
func Create(config *Config) func(db *gorm.DB) {
|
||||||
if config.WithReturning {
|
if config.WithReturning {
|
||||||
return CreateWithReturning
|
return CreateWithReturning
|
||||||
|
|
|
@ -21,7 +21,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.Dest = value
|
tx.Statement.Dest = value
|
||||||
|
|
||||||
if err := tx.Statement.Parse(value); err != nil && tx.Statement.Schema != nil {
|
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||||
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
|
||||||
reflectValue := reflect.ValueOf(value)
|
reflectValue := reflect.ValueOf(value)
|
||||||
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
for idx, pf := range tx.Statement.Schema.PrimaryFields {
|
||||||
|
@ -35,9 +35,6 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
tx.Statement.AddClause(where)
|
tx.Statement.AddClause(where)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tx.Statement.Selects) == 0 {
|
|
||||||
tx.Statement.Selects = []string{"*"}
|
|
||||||
}
|
|
||||||
tx.callbacks.Update().Execute(tx)
|
tx.callbacks.Update().Execute(tx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm/utils"
|
||||||
"github.com/jinzhu/now"
|
"github.com/jinzhu/now"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -146,13 +147,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
field.DBName = dbName
|
field.DBName = dbName
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) {
|
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
|
||||||
field.PrimaryKey = true
|
field.PrimaryKey = true
|
||||||
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) {
|
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
|
||||||
field.PrimaryKey = true
|
field.PrimaryKey = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) {
|
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
|
||||||
field.AutoIncrement = true
|
field.AutoIncrement = true
|
||||||
field.HasDefaultValue = true
|
field.HasDefaultValue = true
|
||||||
}
|
}
|
||||||
|
@ -173,11 +174,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
field.Precision, _ = strconv.Atoi(p)
|
field.Precision, _ = strconv.Atoi(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) {
|
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
||||||
field.NotNull = true
|
field.NotNull = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) {
|
if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
|
||||||
field.Unique = true
|
field.Unique = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,13 +37,6 @@ func ParseTagSetting(str string, sep string) map[string]string {
|
||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkTruth(val string) bool {
|
|
||||||
if strings.ToLower(val) == "false" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func toColumns(val string) (results []string) {
|
func toColumns(val string) (results []string) {
|
||||||
if val != "" {
|
if val != "" {
|
||||||
for _, v := range strings.Split(val, ",") {
|
for _, v := range strings.Split(val, ",") {
|
||||||
|
|
|
@ -2,8 +2,10 @@ package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,3 +25,16 @@ func FileWithLineNum() string {
|
||||||
func IsChar(c rune) bool {
|
func IsChar(c rune) bool {
|
||||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
|
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CheckTruth(val interface{}) bool {
|
||||||
|
if v, ok := val.(bool); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := val.(string); ok {
|
||||||
|
v = strings.ToLower(v)
|
||||||
|
return v != "false"
|
||||||
|
}
|
||||||
|
|
||||||
|
return !reflect.ValueOf(val).IsZero()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue