From 8cb15cadde6e2c3ff1cc19e1182ce98b734ea7d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 2 Feb 2020 08:35:01 +0800 Subject: [PATCH] Improve test structure --- callbacks/callbacks.go | 12 ++ callbacks/create.go | 24 ++++ callbacks/interface.go | 11 ++ dialects/mysql/go.mod | 7 + dialects/mysql/mysql.go | 29 ++++ dialects/mysql/mysql_test.go | 12 ++ dialects/sqlite/go.mod | 7 + dialects/sqlite/sqlite.go | 28 ++++ dialects/sqlite/sqlite_test.go | 15 ++ finisher_api.go | 1 + gorm.go | 33 ++++- schema/schema_helper_test.go | 250 +++++++++++++++++---------------- tests/create_test.go | 1 + 13 files changed, 304 insertions(+), 126 deletions(-) create mode 100644 callbacks/callbacks.go create mode 100644 callbacks/create.go create mode 100644 callbacks/interface.go create mode 100644 dialects/mysql/go.mod create mode 100644 dialects/mysql/mysql.go create mode 100644 dialects/mysql/mysql_test.go create mode 100644 dialects/sqlite/go.mod create mode 100644 dialects/sqlite/sqlite.go create mode 100644 dialects/sqlite/sqlite_test.go create mode 100644 tests/create_test.go diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go new file mode 100644 index 00000000..7fd12cb7 --- /dev/null +++ b/callbacks/callbacks.go @@ -0,0 +1,12 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func RegisterDefaultCallbacks(db *gorm.DB) { + callback := db.Callback() + callback.Create().Register("gorm:before_create", BeforeCreate) + callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) + callback.Create().Register("gorm:create", Create) + callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) + callback.Create().Register("gorm:after_create", AfterCreate) +} diff --git a/callbacks/create.go b/callbacks/create.go new file mode 100644 index 00000000..2fe27140 --- /dev/null +++ b/callbacks/create.go @@ -0,0 +1,24 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func BeforeCreate(db *gorm.DB) { + // before save + // before create + + // assign timestamp +} + +func SaveBeforeAssociations(db *gorm.DB) { +} + +func Create(db *gorm.DB) { +} + +func SaveAfterAssociations(db *gorm.DB) { +} + +func AfterCreate(db *gorm.DB) { + // after save + // after create +} diff --git a/callbacks/interface.go b/callbacks/interface.go new file mode 100644 index 00000000..0ef64fcd --- /dev/null +++ b/callbacks/interface.go @@ -0,0 +1,11 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +type beforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type beforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} diff --git a/dialects/mysql/go.mod b/dialects/mysql/go.mod new file mode 100644 index 00000000..a1f29122 --- /dev/null +++ b/dialects/mysql/go.mod @@ -0,0 +1,7 @@ +module github.com/jinzhu/gorm/dialects/mysql + +go 1.13 + +require ( + github.com/go-sql-driver/mysql v1.5.0 +) diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go new file mode 100644 index 00000000..ba306889 --- /dev/null +++ b/dialects/mysql/mysql.go @@ -0,0 +1,29 @@ +package mysql + +import ( + _ "github.com/go-sql-driver/mysql" + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" +) + +type Dialector struct { +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{} +} + +func (Dialector) Initialize(db *gorm.DB) error { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + return nil +} + +func (Dialector) Migrator() gorm.Migrator { + return nil +} + +func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { + return "?" +} diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go new file mode 100644 index 00000000..49c26915 --- /dev/null +++ b/dialects/mysql/mysql_test.go @@ -0,0 +1,12 @@ +package mysql_test + +import ( + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/mysql" +) + +func TestOpen(t *testing.T) { + gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil) +} diff --git a/dialects/sqlite/go.mod b/dialects/sqlite/go.mod new file mode 100644 index 00000000..db3370e9 --- /dev/null +++ b/dialects/sqlite/go.mod @@ -0,0 +1,7 @@ +module github.com/jinzhu/gorm/dialects/mysql + +go 1.13 + +require ( + github.com/mattn/go-sqlite3 v2.0.3+incompatible +) diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go new file mode 100644 index 00000000..f3c3f0c7 --- /dev/null +++ b/dialects/sqlite/sqlite.go @@ -0,0 +1,28 @@ +package sqlite + +import ( + "github.com/jinzhu/gorm/callbacks" + _ "github.com/mattn/go-sqlite3" +) + +type Dialector struct { +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{} +} + +func (Dialector) Initialize(db *gorm.DB) error { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + return nil +} + +func (Dialector) Migrator() gorm.Migrator { + return nil +} + +func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { + return "?" +} diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go new file mode 100644 index 00000000..f0429a12 --- /dev/null +++ b/dialects/sqlite/sqlite_test.go @@ -0,0 +1,15 @@ +package sqlite_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/jinzhu/gorm" +) + +var DB *gorm.DB + +func TestOpen(t *testing.T) { + db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) +} diff --git a/finisher_api.go b/finisher_api.go index 2668e1fe..b155e90d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -12,6 +12,7 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() + tx.callbacks.Create().Execute(tx.Limit(1).Order("id")) return } diff --git a/gorm.go b/gorm.go index 6ceac412..896d07f9 100644 --- a/gorm.go +++ b/gorm.go @@ -13,7 +13,7 @@ import ( type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // You can cancel it by setting `SkipDefaultTransaction` to true - SkipDefaultTransaction bool + SkipDefaultTransaction bool // TODO // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -27,6 +27,7 @@ type Config struct { // Dialector GORM database dialector type Dialector interface { + Initialize(*DB) error Migrator() Migrator BindVar(stmt Statement, v interface{}) string } @@ -36,7 +37,8 @@ type DB struct { *Config Dialector Instance - clone bool + clone bool + callbacks *callbacks } // Session session config when create new session @@ -48,15 +50,33 @@ type Session struct { // Open initialize db session based on dialector func Open(dialector Dialector, config *Config) (db *DB, err error) { + if config == nil { + config = &Config{} + } + if config.NamingStrategy == nil { config.NamingStrategy = schema.NamingStrategy{} } - return &DB{ + if config.Logger == nil { + config.Logger = logger.Default + } + + if config.NowFunc == nil { + config.NowFunc = func() time.Time { return time.Now().Local() } + } + + db = &DB{ Config: config, Dialector: dialector, clone: true, - }, nil + callbacks: InitializeCallbacks(), + } + + if dialector != nil { + err = dialector.Initialize(db) + } + return } // Session create new db session @@ -112,6 +132,11 @@ func (db *DB) Get(key string) (interface{}, bool) { return nil, false } +// Callback returns callback manager +func (db *DB) Callback() *callbacks { + return db.callbacks +} + func (db *DB) getInstance() *DB { if db.clone { ctx := db.Instance.Context diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index ce91d8d1..05f41131 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -10,85 +10,89 @@ import ( ) func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { - equalFieldNames := []string{"Name", "Table"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(v).FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) - } - } - - for idx, field := range primaryFields { - var found bool - for _, f := range s.PrimaryFields { - if f.Name == field { - found = true - } - } - - if idx == 0 { - if field != s.PrioritizedPrimaryField.Name { - t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) - } - } - - if !found { - t.Errorf("schema %v failed to found priamry key: %v", s, field) - } - } -} - -func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { - if fc != nil { - fc(f) - } - - if f.TagSettings == nil { - if f.Tag != "" { - f.TagSettings = schema.ParseTagSetting(f.Tag) - } else { - f.TagSettings = map[string]string{} - } - } - - if parsedField, ok := s.FieldsByName[f.Name]; !ok { - t.Errorf("schema %v failed to look up field with name %v", s, f.Name) - } else { - equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} + t.Run("CheckSchema/"+s.Name, func(t *testing.T) { + equalFieldNames := []string{"Name", "Table"} for _, name := range equalFieldNames { - got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() + got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(v).FieldByName(name).Interface() if !reflect.DeepEqual(got, expects) { - t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) + t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) } } - if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - - for _, name := range []string{f.DBName, f.Name} { - if field := s.LookUpField(name); field == nil || parsedField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - } - - if f.PrimaryKey { + for idx, field := range primaryFields { var found bool - for _, primaryField := range s.PrimaryFields { - if primaryField == parsedField { + for _, f := range s.PrimaryFields { + if f.Name == field { found = true } } + if idx == 0 { + if field != s.PrioritizedPrimaryField.Name { + t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) + } + } + if !found { - t.Errorf("schema %v doesn't include field %v", s, f.Name) + t.Errorf("schema %v failed to found priamry key: %v", s, field) } } - } + }) +} + +func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { + t.Run("CheckField/"+f.Name, func(t *testing.T) { + if fc != nil { + fc(f) + } + + if f.TagSettings == nil { + if f.Tag != "" { + f.TagSettings = schema.ParseTagSetting(f.Tag) + } else { + f.TagSettings = map[string]string{} + } + } + + if parsedField, ok := s.FieldsByName[f.Name]; !ok { + t.Errorf("schema %v failed to look up field with name %v", s, f.Name) + } else { + equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} + + for _, name := range equalFieldNames { + got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() + if !reflect.DeepEqual(got, expects) { + t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) + } + } + + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + + for _, name := range []string{f.DBName, f.Name} { + if field := s.LookUpField(name); field == nil || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + + if f.PrimaryKey { + var found bool + for _, primaryField := range s.PrimaryFields { + if primaryField == parsedField { + found = true + } + } + + if !found { + t.Errorf("schema %v doesn't include field %v", s, f.Name) + } + } + } + }) } type Relation struct { @@ -123,79 +127,81 @@ type Reference struct { } func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { - if r, ok := s.Relationships.Relations[relation.Name]; ok { - if r.Name != relation.Name { - t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) - } - - if r.Type != relation.Type { - t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) - } - - if r.Schema.Name != relation.Schema { - t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) - } - - if r.FieldSchema.Name != relation.FieldSchema { - t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) - } - - if r.Polymorphic != nil { - if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { - t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + t.Run("CheckRelation/"+relation.Name, func(t *testing.T) { + if r, ok := s.Relationships.Relations[relation.Name]; ok { + if r.Name != relation.Name { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) } - if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { - t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + if r.Type != relation.Type { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) } - if r.Polymorphic.Value != relation.Polymorphic.Value { - t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) - } - } - - if r.JoinTable != nil { - if r.JoinTable.Name != relation.JoinTable.Name { - t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + if r.Schema.Name != relation.Schema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) } - if r.JoinTable.Table != relation.JoinTable.Table { - t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + if r.FieldSchema.Name != relation.FieldSchema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) } - for _, f := range relation.JoinTable.Fields { - checkSchemaField(t, r.JoinTable, &f, nil) - } - } + if r.Polymorphic != nil { + if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { + t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + } - if len(relation.References) != len(r.References) { - t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) - } + if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { + t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + } - for _, ref := range relation.References { - var found bool - for _, rf := range r.References { - if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { - found = true + if r.Polymorphic.Value != relation.Polymorphic.Value { + t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) } } - if !found { - var refs []string + if r.JoinTable != nil { + if r.JoinTable.Name != relation.JoinTable.Name { + t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + } + + if r.JoinTable.Table != relation.JoinTable.Table { + t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + } + + for _, f := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &f, nil) + } + } + + if len(relation.References) != len(r.References) { + t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) + } + + for _, ref := range relation.References { + var found bool for _, rf := range r.References { - var primaryKey, primaryKeySchema string - if rf.PrimaryKey != nil { - primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { + found = true } - refs = append(refs, fmt.Sprintf( - "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", - primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, - )) } - t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + + if !found { + var refs []string + for _, rf := range r.References { + var primaryKey, primaryKeySchema string + if rf.PrimaryKey != nil { + primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + } + refs = append(refs, fmt.Sprintf( + "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", + primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, + )) + } + t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + } } + } else { + t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) } - } else { - t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) - } + }) } diff --git a/tests/create_test.go b/tests/create_test.go new file mode 100644 index 00000000..ca8701d2 --- /dev/null +++ b/tests/create_test.go @@ -0,0 +1 @@ +package tests