Add permission check when create associations

This commit is contained in:
Jinzhu 2020-04-16 10:29:18 +08:00
parent 345ff7577c
commit 56ca9a87e0
6 changed files with 94 additions and 42 deletions

72
callbacks/associations.go Normal file
View File

@ -0,0 +1,72 @@
package callbacks
import (
"reflect"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/utils"
)
func SaveBeforeAssociations(db *gorm.DB) {
if db.Statement.Schema != nil {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
creatable, updatable, saveRef := saveAssociationCheck(db, rel.Field)
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
_, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f)
if isZero && creatable {
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Create(f.Interface())
} else {
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
}
} else if !isZero && updatable {
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Save(f.Interface())
} else {
db.Session(&gorm.Session{}).Save(f.Addr().Interface())
}
} else {
continue
}
if saveRef {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(f)
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
}
}
}
}
}
}
}
}
func saveAssociationCheck(db *gorm.DB, field *schema.Field) (bool, bool, bool) {
creatable := field.Creatable
updatable := field.Updatable
saveRef := true
if value, ok := db.Get("gorm:association_autocreate"); creatable && ok {
creatable = utils.CheckTruth(value)
}
if value, ok := db.Get("gorm:association_autoupdate"); updatable && ok {
updatable = utils.CheckTruth(value)
}
if value, ok := db.Get("gorm:association_save_reference"); ok {
saveRef = utils.CheckTruth(value)
}
return creatable, updatable, saveRef
}

View File

@ -41,32 +41,6 @@ func BeforeCreate(db *gorm.DB) {
} }
} }
func SaveBeforeAssociations(db *gorm.DB) {
if db.Statement.Schema != nil {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice:
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() == reflect.Ptr {
db.Session(&gorm.Session{}).Create(f.Interface())
} else {
db.Session(&gorm.Session{}).Create(f.Addr().Interface())
}
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(f)
ref.ForeignKey.Set(db.Statement.ReflectValue, fv)
}
}
}
}
}
}
}
func Create(config *Config) func(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning { if config.WithReturning {
return CreateWithReturning return CreateWithReturning

View File

@ -21,7 +21,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
if err := tx.Statement.Parse(value); err != nil && tx.Statement.Schema != nil { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
reflectValue := reflect.ValueOf(value) reflectValue := reflect.ValueOf(value)
for idx, pf := range tx.Statement.Schema.PrimaryFields { for idx, pf := range tx.Statement.Schema.PrimaryFields {
@ -35,9 +35,6 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.AddClause(where) tx.Statement.AddClause(where)
} }
if len(tx.Statement.Selects) == 0 {
tx.Statement.Selects = []string{"*"}
}
tx.callbacks.Update().Execute(tx) tx.callbacks.Update().Execute(tx)
return return
} }

View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/jinzhu/gorm/utils"
"github.com/jinzhu/now" "github.com/jinzhu/now"
) )
@ -146,13 +147,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DBName = dbName field.DBName = dbName
} }
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) { if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
field.PrimaryKey = true field.PrimaryKey = true
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
field.PrimaryKey = true field.PrimaryKey = true
} }
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) { if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
field.AutoIncrement = true field.AutoIncrement = true
field.HasDefaultValue = true field.HasDefaultValue = true
} }
@ -173,11 +174,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Precision, _ = strconv.Atoi(p) field.Precision, _ = strconv.Atoi(p)
} }
if val, ok := field.TagSettings["NOT NULL"]; ok && checkTruth(val) { if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true field.NotNull = true
} }
if val, ok := field.TagSettings["UNIQUE"]; ok && checkTruth(val) { if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
field.Unique = true field.Unique = true
} }

View File

@ -37,13 +37,6 @@ func ParseTagSetting(str string, sep string) map[string]string {
return settings return settings
} }
func checkTruth(val string) bool {
if strings.ToLower(val) == "false" {
return false
}
return true
}
func toColumns(val string) (results []string) { func toColumns(val string) (results []string) {
if val != "" { if val != "" {
for _, v := range strings.Split(val, ",") { for _, v := range strings.Split(val, ",") {

View File

@ -2,8 +2,10 @@ package utils
import ( import (
"fmt" "fmt"
"reflect"
"regexp" "regexp"
"runtime" "runtime"
"strings"
"unicode" "unicode"
) )
@ -23,3 +25,16 @@ func FileWithLineNum() string {
func IsChar(c rune) bool { func IsChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) return !unicode.IsLetter(c) && !unicode.IsNumber(c)
} }
func CheckTruth(val interface{}) bool {
if v, ok := val.(bool); ok {
return v
}
if v, ok := val.(string); ok {
v = strings.ToLower(v)
return v != "false"
}
return !reflect.ValueOf(val).IsZero()
}