Add field permission test

This commit is contained in:
Jinzhu 2020-06-05 19:18:22 +08:00
parent c8e7878b3e
commit d50879cc28
9 changed files with 300 additions and 121 deletions

View File

@ -10,7 +10,7 @@ import (
) )
func SetupUpdateReflectValue(db *gorm.DB) { func SetupUpdateReflectValue(db *gorm.DB) {
if db.Error == nil { if db.Error == nil && db.Statement.Schema != nil {
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr { for db.Statement.ReflectValue.Kind() == reflect.Ptr {
@ -172,26 +172,38 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
sort.Strings(keys) sort.Strings(keys)
for _, k := range keys { for _, k := range keys {
if field := stmt.Schema.LookUpField(k); field != nil { if stmt.Schema != nil {
if field.DBName != "" { if field := stmt.Schema.LookUpField(k); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if field.DBName != "" {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
assignValue(field, value[k]) assignValue(field, value[k])
} }
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { continue
assignValue(field, value[k])
} }
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { }
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
} }
} }
if !stmt.DisableUpdateTime { if !stmt.DisableUpdateTime && stmt.Schema != nil {
for _, field := range stmt.Schema.FieldsByDBName { for _, field := range stmt.Schema.FieldsByDBName {
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
now := stmt.DB.NowFunc() now := stmt.DB.NowFunc()
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
assignValue(field, now) assignValue(field, now)
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.DataType == schema.Time {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
}
} }
} }
} }
@ -205,7 +217,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
value, isZero := field.ValueOf(updatingValue) value, isZero := field.ValueOf(updatingValue)
if !stmt.DisableUpdateTime { if !stmt.DisableUpdateTime {
if field.AutoUpdateTime > 0 { if field.AutoUpdateTime > 0 {
value = stmt.DB.NowFunc() if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.DataType == schema.Time {
value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix()
}
isZero = false isZero = false
} }
} }

View File

@ -133,33 +133,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
// setup permission
if _, ok := field.TagSettings["-"]; ok {
field.Creatable = false
field.Updatable = false
field.Readable = false
}
if v, ok := field.TagSettings["<-"]; ok {
if v != "<-" {
if !strings.Contains(v, "create") {
field.Creatable = false
}
if !strings.Contains(v, "update") {
field.Updatable = false
}
}
field.Readable = false
}
if _, ok := field.TagSettings["->"]; ok {
field.Creatable = false
field.Updatable = false
field.Readable = true
}
if dbName, ok := field.TagSettings["COLUMN"]; ok { if dbName, ok := field.TagSettings["COLUMN"]; ok {
field.DBName = dbName field.DBName = dbName
} }
@ -276,6 +249,39 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
// setup permission
if _, ok := field.TagSettings["-"]; ok {
field.Creatable = false
field.Updatable = false
field.Readable = false
field.DataType = ""
}
if v, ok := field.TagSettings["->"]; ok {
field.Creatable = false
field.Updatable = false
if strings.ToLower(v) == "false" {
field.Readable = false
} else {
field.Readable = true
}
}
if v, ok := field.TagSettings["<-"]; ok {
field.Creatable = true
field.Updatable = true
if v != "<-" {
if !strings.Contains(v, "create") {
field.Creatable = false
}
if !strings.Contains(v, "update") {
field.Updatable = false
}
}
}
if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
var err error var err error
field.Creatable = false field.Creatable = false
@ -510,14 +516,14 @@ func (field *Field) setupValuerAndSetter() {
return err return err
} }
case time.Time: case time.Time:
if field.AutoCreateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano()) field.ReflectValueOf(value).SetInt(data.UnixNano())
} else { } else {
field.ReflectValueOf(value).SetInt(data.Unix()) field.ReflectValueOf(value).SetInt(data.Unix())
} }
case *time.Time: case *time.Time:
if data != nil { if data != nil {
if field.AutoCreateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano()) field.ReflectValueOf(value).SetInt(data.UnixNano())
} else { } else {
field.ReflectValueOf(value).SetInt(data.Unix()) field.ReflectValueOf(value).SetInt(data.Unix())

View File

@ -225,6 +225,7 @@ type UserWithPermissionControl struct {
Name4 string `gorm:"<-:create"` Name4 string `gorm:"<-:create"`
Name5 string `gorm:"<-:update"` Name5 string `gorm:"<-:update"`
Name6 string `gorm:"<-:create,update"` Name6 string `gorm:"<-:create,update"`
Name7 string `gorm:"->:false;<-:create,update"`
} }
func TestParseFieldWithPermission(t *testing.T) { func TestParseFieldWithPermission(t *testing.T) {
@ -235,12 +236,13 @@ func TestParseFieldWithPermission(t *testing.T) {
fields := []schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String, Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
{Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: false}, {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: false}, {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true},
{Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: false}, {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true},
{Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false},
} }
for _, f := range fields { for _, f := range fields {

View File

@ -54,13 +54,17 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
} else { } else {
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { if f.DBName != "" {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
}
} }
for _, name := range []string{f.DBName, f.Name} { for _, name := range []string{f.DBName, f.Name} {
if field := s.LookUpField(name); field == nil || parsedField != field { if name != "" {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) if field := s.LookUpField(name); field == nil || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
}
} }
} }

View File

@ -1,56 +0,0 @@
package tests_test
import (
"testing"
"time"
)
func TestCustomizeColumn(t *testing.T) {
type CustomizeColumn struct {
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date *time.Time `gorm:"column:mapped_time"`
}
DB.Migrator().DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
expected := "foo"
now := time.Now()
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
var cc1 CustomizeColumn
DB.First(&cc1, "mapped_name = ?", "foo")
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}
cc.Name = "bar"
DB.Save(&cc)
var cc2 CustomizeColumn
DB.First(&cc2, "mapped_id = ?", 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
}
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
// Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct {
Body string `gorm:"-"`
RawBody string `gorm:"column:body"`
}
DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil {
t.Errorf("Should not raise error: %v", err)
}
}

View File

@ -0,0 +1,172 @@
package tests_test
import (
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
func TestCustomizeColumn(t *testing.T) {
type CustomizeColumn struct {
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date *time.Time `gorm:"column:mapped_time"`
}
DB.Migrator().DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
expected := "foo"
now := time.Now()
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
var cc1 CustomizeColumn
DB.First(&cc1, "mapped_name = ?", "foo")
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}
cc.Name = "bar"
DB.Save(&cc)
var cc2 CustomizeColumn
DB.First(&cc2, "mapped_id = ?", 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
}
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
// Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct {
Body string `gorm:"-"`
RawBody string `gorm:"column:body"`
}
DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil {
t.Errorf("Should not raise error: %v", err)
}
}
func TestCustomizeField(t *testing.T) {
type CustomizeFieldStruct struct {
gorm.Model
Name string
FieldAllowCreate string `gorm:"<-:create"`
FieldAllowUpdate string `gorm:"<-:update"`
FieldAllowSave string `gorm:"<-"`
FieldAllowSave2 string `gorm:"<-:create,update"`
FieldAllowSave3 string `gorm:"->:false;<-:create"`
FieldReadonly string `gorm:"->"`
FieldIgnore string `gorm:"-"`
AutoUnixCreateTime int64 `gorm:"autocreatetime"`
AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"`
AutoUnixUpdateTime int64 `gorm:"autoupdatetime"`
AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"`
}
DB.Migrator().DropTable(&CustomizeFieldStruct{})
if err := DB.AutoMigrate(&CustomizeFieldStruct{}); err != nil {
t.Errorf("Failed to migrate, got error: %v", err)
}
if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "FieldIgnore") {
t.Errorf("FieldIgnore should not be created")
}
if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "field_ignore") {
t.Errorf("FieldIgnore should not be created")
}
generateStruct := func(name string) *CustomizeFieldStruct {
return &CustomizeFieldStruct{
Name: name,
FieldAllowCreate: name + "_allow_create",
FieldAllowUpdate: name + "_allow_update",
FieldAllowSave: name + "_allow_save",
FieldAllowSave2: name + "_allow_save2",
FieldAllowSave3: name + "_allow_save3",
FieldReadonly: name + "_allow_readonly",
FieldIgnore: name + "_allow_ignore",
}
}
create := generateStruct("create")
DB.Create(&create)
var result CustomizeFieldStruct
DB.Find(&result, "name = ?", "create")
AssertObjEqual(t, result, create, "Name", "FieldAllowCreate", "FieldAllowSave", "FieldAllowSave2")
if result.FieldAllowUpdate != "" || result.FieldReadonly != "" || result.FieldIgnore != "" || result.FieldAllowSave3 != "" {
t.Fatalf("invalid result: %#v", result)
}
if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 {
t.Fatalf("invalid create/update unix time: %#v", result)
}
if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 {
t.Fatalf("invalid create/update unix nano time: %#v", result)
}
result.FieldAllowUpdate = "field_allow_update_updated"
result.FieldReadonly = "field_readonly_updated"
result.FieldIgnore = "field_ignore_updated"
DB.Save(&result)
var result2 CustomizeFieldStruct
DB.Find(&result2, "name = ?", "create")
if result2.FieldAllowUpdate != result.FieldAllowUpdate || result2.FieldReadonly != "" || result2.FieldIgnore != "" {
t.Fatalf("invalid updated result: %#v", result2)
}
if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil {
t.Fatalf("failed to update field_readonly column")
}
var result3 CustomizeFieldStruct
DB.Find(&result3, "name = ?", "create")
if result3.FieldReadonly != "readonly" {
t.Fatalf("invalid updated result: %#v", result3)
}
var result4 CustomizeFieldStruct
if err := DB.First(&result4, "field_allow_save3 = ?", create.FieldAllowSave3).Error; err != nil {
t.Fatalf("failed to query with inserted field, got error %v", err)
}
AssertEqual(t, result3, result4)
createWithDefaultTime := generateStruct("create_with_default_time")
createWithDefaultTime.AutoUnixCreateTime = 100
createWithDefaultTime.AutoUnixUpdateTime = 100
createWithDefaultTime.AutoUnixNanoCreateTime = 100
createWithDefaultTime.AutoUnixNanoUpdateTime = 100
DB.Create(&createWithDefaultTime)
var createWithDefaultTimeResult CustomizeFieldStruct
DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name)
if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 {
t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult)
}
if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 {
t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult)
}
}

View File

@ -6,7 +6,7 @@ require (
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286
gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 gorm.io/driver/sqlite v1.0.0
gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2
gorm.io/gorm v0.0.0-00010101000000-000000000000 gorm.io/gorm v0.0.0-00010101000000-000000000000
) )

View File

@ -67,22 +67,39 @@ func TestFind(t *testing.T) {
} }
}) })
var allMap = []map[string]interface{}{} t.Run("FirstPtrMap", func(t *testing.T) {
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { var first = map[string]interface{}{}
t.Errorf("errors happened when query first: %v", err) if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil {
} else { t.Errorf("errors happened when query first: %v", err)
for idx, user := range users { } else {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} {
for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) {
t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name)
dbName := DB.NamingStrategy.ColumnName("", name) reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
reflectValue := reflect.Indirect(reflect.ValueOf(user)) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) })
}) }
}
})
} }
} })
t.Run("FirstSliceOfMap", func(t *testing.T) {
var allMap = []map[string]interface{}{}
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
}
})
}
}
})
} }
func TestFillSmallerStruct(t *testing.T) { func TestFillSmallerStruct(t *testing.T) {

View File

@ -122,3 +122,19 @@ func TestQueryRaw(t *testing.T) {
DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user)
CheckUser(t, user, *users[1]) CheckUser(t, user, *users[1])
} }
func TestDryRun(t *testing.T) {
user := *GetUser("dry-run", Config{})
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Create(&user).Statement
if stmt.SQL.String() == "" || len(stmt.Vars) != 9 {
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
}
stmt2 := dryRunDB.Find(&user, "id = ?", user.ID).Statement
if stmt2.SQL.String() == "" || len(stmt2.Vars) != 1 {
t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String())
}
}