From d50879cc280520f944a965577ce3198cb1933161 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 5 Jun 2020 19:18:22 +0800 Subject: [PATCH] Add field permission test --- callbacks/update.go | 40 +++++--- schema/field.go | 64 ++++++------ schema/field_test.go | 12 ++- schema/schema_helper_test.go | 12 ++- tests/customize_column_test.go | 56 ----------- tests/customize_field_test.go | 172 +++++++++++++++++++++++++++++++++ tests/go.mod | 2 +- tests/query_test.go | 47 ++++++--- tests/sql_builder_test.go | 16 +++ 9 files changed, 300 insertions(+), 121 deletions(-) delete mode 100644 tests/customize_column_test.go create mode 100644 tests/customize_field_test.go diff --git a/callbacks/update.go b/callbacks/update.go index 9b2e924b..2589370f 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -10,7 +10,7 @@ import ( ) func SetupUpdateReflectValue(db *gorm.DB) { - if db.Error == nil { + if db.Error == nil && db.Statement.Schema != nil { if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) for db.Statement.ReflectValue.Kind() == reflect.Ptr { @@ -172,26 +172,38 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { sort.Strings(keys) for _, k := range keys { - if field := stmt.Schema.LookUpField(k); field != nil { - if field.DBName != "" { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { assignValue(field, value[k]) } - } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { - assignValue(field, value[k]) + continue } - } else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) } } - if !stmt.DisableUpdateTime { + if !stmt.DisableUpdateTime && stmt.Schema != nil { for _, field := range stmt.Schema.FieldsByDBName { if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { now := stmt.DB.NowFunc() - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) assignValue(field, now) + + if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.DataType == schema.Time { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } } } } @@ -205,7 +217,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value, isZero := field.ValueOf(updatingValue) if !stmt.DisableUpdateTime { if field.AutoUpdateTime > 0 { - value = stmt.DB.NowFunc() + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.DataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } isZero = false } } diff --git a/schema/field.go b/schema/field.go index a27fdd87..854ec520 100644 --- a/schema/field.go +++ b/schema/field.go @@ -133,33 +133,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - // setup permission - if _, ok := field.TagSettings["-"]; ok { - field.Creatable = false - field.Updatable = false - field.Readable = false - } - - if v, ok := field.TagSettings["<-"]; ok { - if v != "<-" { - if !strings.Contains(v, "create") { - field.Creatable = false - } - - if !strings.Contains(v, "update") { - field.Updatable = false - } - } - - field.Readable = false - } - - if _, ok := field.TagSettings["->"]; ok { - field.Creatable = false - field.Updatable = false - field.Readable = true - } - if dbName, ok := field.TagSettings["COLUMN"]; ok { field.DBName = dbName } @@ -276,6 +249,39 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + // setup permission + if _, ok := field.TagSettings["-"]; ok { + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + } + + if v, ok := field.TagSettings["->"]; ok { + field.Creatable = false + field.Updatable = false + if strings.ToLower(v) == "false" { + field.Readable = false + } else { + field.Readable = true + } + } + + if v, ok := field.TagSettings["<-"]; ok { + field.Creatable = true + field.Updatable = true + + if v != "<-" { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } + } + } + if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { var err error field.Creatable = false @@ -510,14 +516,14 @@ func (field *Field) setupValuerAndSetter() { return err } case time.Time: - if field.AutoCreateTime == UnixNanosecond { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) } else { field.ReflectValueOf(value).SetInt(data.Unix()) } case *time.Time: if data != nil { - if field.AutoCreateTime == UnixNanosecond { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) } else { field.ReflectValueOf(value).SetInt(data.Unix()) diff --git a/schema/field_test.go b/schema/field_test.go index fe88891f..cc4b53fc 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -225,6 +225,7 @@ type UserWithPermissionControl struct { Name4 string `gorm:"<-:create"` Name5 string `gorm:"<-:update"` Name6 string `gorm:"<-:create,update"` + Name7 string `gorm:"->:false;<-:create,update"` } func TestParseFieldWithPermission(t *testing.T) { @@ -235,12 +236,13 @@ func TestParseFieldWithPermission(t *testing.T) { fields := []schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, - {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String, Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, + {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, - {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: false}, - {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: false}, - {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: false}, - {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true}, + {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, + {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, } for _, f := range fields { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index f2ed4145..d2e68536 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -54,13 +54,17 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } else { tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") - if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + if f.DBName != "" { + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } } for _, name := range []string{f.DBName, f.Name} { - if field := s.LookUpField(name); field == nil || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + if name != "" { + if field := s.LookUpField(name); field == nil || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } } } diff --git a/tests/customize_column_test.go b/tests/customize_column_test.go deleted file mode 100644 index 98dea494..00000000 --- a/tests/customize_column_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package tests_test - -import ( - "testing" - "time" -) - -func TestCustomizeColumn(t *testing.T) { - type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date *time.Time `gorm:"column:mapped_time"` - } - - DB.Migrator().DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) - - expected := "foo" - now := time.Now() - cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} - - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - var cc1 CustomizeColumn - DB.First(&cc1, "mapped_name = ?", "foo") - - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } - - cc.Name = "bar" - DB.Save(&cc) - - var cc2 CustomizeColumn - DB.First(&cc2, "mapped_id = ?", 666) - if cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } -} - -func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - // Make sure an ignored field does not interfere with another field's custom - // column name that matches the ignored field. - type CustomColumnAndIgnoredFieldClash struct { - Body string `gorm:"-"` - RawBody string `gorm:"column:body"` - } - - DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) - - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { - t.Errorf("Should not raise error: %v", err) - } -} diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go new file mode 100644 index 00000000..910fa6ae --- /dev/null +++ b/tests/customize_field_test.go @@ -0,0 +1,172 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestCustomizeColumn(t *testing.T) { + type CustomizeColumn struct { + ID int64 `gorm:"column:mapped_id; primary_key:yes"` + Name string `gorm:"column:mapped_name"` + Date *time.Time `gorm:"column:mapped_time"` + } + + DB.Migrator().DropTable(&CustomizeColumn{}) + DB.AutoMigrate(&CustomizeColumn{}) + + expected := "foo" + now := time.Now() + cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} + + if count := DB.Create(&cc).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + var cc1 CustomizeColumn + DB.First(&cc1, "mapped_name = ?", "foo") + + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, "mapped_id = ?", 666) + if cc2.Name != "bar" { + t.Errorf("Failed to query CustomizeColumn") + } +} + +func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { + // Make sure an ignored field does not interfere with another field's custom + // column name that matches the ignored field. + type CustomColumnAndIgnoredFieldClash struct { + Body string `gorm:"-"` + RawBody string `gorm:"column:body"` + } + + DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) + + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { + t.Errorf("Should not raise error: %v", err) + } +} + +func TestCustomizeField(t *testing.T) { + type CustomizeFieldStruct struct { + gorm.Model + Name string + FieldAllowCreate string `gorm:"<-:create"` + FieldAllowUpdate string `gorm:"<-:update"` + FieldAllowSave string `gorm:"<-"` + FieldAllowSave2 string `gorm:"<-:create,update"` + FieldAllowSave3 string `gorm:"->:false;<-:create"` + FieldReadonly string `gorm:"->"` + FieldIgnore string `gorm:"-"` + AutoUnixCreateTime int64 `gorm:"autocreatetime"` + AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` + AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` + AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + } + + DB.Migrator().DropTable(&CustomizeFieldStruct{}) + + if err := DB.AutoMigrate(&CustomizeFieldStruct{}); err != nil { + t.Errorf("Failed to migrate, got error: %v", err) + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "FieldIgnore") { + t.Errorf("FieldIgnore should not be created") + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "field_ignore") { + t.Errorf("FieldIgnore should not be created") + } + + generateStruct := func(name string) *CustomizeFieldStruct { + return &CustomizeFieldStruct{ + Name: name, + FieldAllowCreate: name + "_allow_create", + FieldAllowUpdate: name + "_allow_update", + FieldAllowSave: name + "_allow_save", + FieldAllowSave2: name + "_allow_save2", + FieldAllowSave3: name + "_allow_save3", + FieldReadonly: name + "_allow_readonly", + FieldIgnore: name + "_allow_ignore", + } + } + + create := generateStruct("create") + DB.Create(&create) + + var result CustomizeFieldStruct + DB.Find(&result, "name = ?", "create") + + AssertObjEqual(t, result, create, "Name", "FieldAllowCreate", "FieldAllowSave", "FieldAllowSave2") + + if result.FieldAllowUpdate != "" || result.FieldReadonly != "" || result.FieldIgnore != "" || result.FieldAllowSave3 != "" { + t.Fatalf("invalid result: %#v", result) + } + + if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 { + t.Fatalf("invalid create/update unix time: %#v", result) + } + + if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { + t.Fatalf("invalid create/update unix nano time: %#v", result) + } + + result.FieldAllowUpdate = "field_allow_update_updated" + result.FieldReadonly = "field_readonly_updated" + result.FieldIgnore = "field_ignore_updated" + DB.Save(&result) + + var result2 CustomizeFieldStruct + DB.Find(&result2, "name = ?", "create") + + if result2.FieldAllowUpdate != result.FieldAllowUpdate || result2.FieldReadonly != "" || result2.FieldIgnore != "" { + t.Fatalf("invalid updated result: %#v", result2) + } + + if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil { + t.Fatalf("failed to update field_readonly column") + } + + var result3 CustomizeFieldStruct + DB.Find(&result3, "name = ?", "create") + + if result3.FieldReadonly != "readonly" { + t.Fatalf("invalid updated result: %#v", result3) + } + + var result4 CustomizeFieldStruct + if err := DB.First(&result4, "field_allow_save3 = ?", create.FieldAllowSave3).Error; err != nil { + t.Fatalf("failed to query with inserted field, got error %v", err) + } + + AssertEqual(t, result3, result4) + + createWithDefaultTime := generateStruct("create_with_default_time") + createWithDefaultTime.AutoUnixCreateTime = 100 + createWithDefaultTime.AutoUnixUpdateTime = 100 + createWithDefaultTime.AutoUnixNanoCreateTime = 100 + createWithDefaultTime.AutoUnixNanoUpdateTime = 100 + DB.Create(&createWithDefaultTime) + + var createWithDefaultTimeResult CustomizeFieldStruct + DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) + + if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) + } + + if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) + } +} diff --git a/tests/go.mod b/tests/go.mod index 3401b9b2..de58a0de 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/jinzhu/now v1.1.1 gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 - gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 + gorm.io/driver/sqlite v1.0.0 gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 gorm.io/gorm v0.0.0-00010101000000-000000000000 ) diff --git a/tests/query_test.go b/tests/query_test.go index 18ffb3fb..66413b3b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -67,22 +67,39 @@ func TestFind(t *testing.T) { } }) - var allMap = []map[string]interface{}{} - if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) - } else { - for idx, user := range users { - t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { - for _, name := range []string{"Name", "Age", "Birthday"} { - t.Run(name, func(t *testing.T) { - dbName := DB.NamingStrategy.ColumnName("", name) - reflectValue := reflect.Indirect(reflect.ValueOf(user)) - AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) - }) - } - }) + t.Run("FirstPtrMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } } - } + }) + + t.Run("FirstSliceOfMap", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) } func TestFillSmallerStruct(t *testing.T) { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 278a5b96..a60514c9 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -122,3 +122,19 @@ func TestQueryRaw(t *testing.T) { DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) CheckUser(t, user, *users[1]) } + +func TestDryRun(t *testing.T) { + user := *GetUser("dry-run", Config{}) + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&user).Statement + if stmt.SQL.String() == "" || len(stmt.Vars) != 9 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + stmt2 := dryRunDB.Find(&user, "id = ?", user.ID).Statement + if stmt2.SQL.String() == "" || len(stmt2.Vars) != 1 { + t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String()) + } +}