Improve test structure

This commit is contained in:
Jinzhu 2020-02-02 08:35:01 +08:00
parent 3cbd233758
commit 8cb15cadde
13 changed files with 304 additions and 126 deletions

12
callbacks/callbacks.go Normal file
View File

@ -0,0 +1,12 @@
package callbacks
import "github.com/jinzhu/gorm"
func RegisterDefaultCallbacks(db *gorm.DB) {
callback := db.Callback()
callback.Create().Register("gorm:before_create", BeforeCreate)
callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
callback.Create().Register("gorm:create", Create)
callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
callback.Create().Register("gorm:after_create", AfterCreate)
}

24
callbacks/create.go Normal file
View File

@ -0,0 +1,24 @@
package callbacks
import "github.com/jinzhu/gorm"
func BeforeCreate(db *gorm.DB) {
// before save
// before create
// assign timestamp
}
func SaveBeforeAssociations(db *gorm.DB) {
}
func Create(db *gorm.DB) {
}
func SaveAfterAssociations(db *gorm.DB) {
}
func AfterCreate(db *gorm.DB) {
// after save
// after create
}

11
callbacks/interface.go Normal file
View File

@ -0,0 +1,11 @@
package callbacks
import "github.com/jinzhu/gorm"
type beforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type beforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}

7
dialects/mysql/go.mod Normal file
View File

@ -0,0 +1,7 @@
module github.com/jinzhu/gorm/dialects/mysql
go 1.13
require (
github.com/go-sql-driver/mysql v1.5.0
)

29
dialects/mysql/mysql.go Normal file
View File

@ -0,0 +1,29 @@
package mysql
import (
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
)
type Dialector struct {
}
func Open(dsn string) gorm.Dialector {
return &Dialector{}
}
func (Dialector) Initialize(db *gorm.DB) error {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)
return nil
}
func (Dialector) Migrator() gorm.Migrator {
return nil
}
func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
return "?"
}

View File

@ -0,0 +1,12 @@
package mysql_test
import (
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/dialects/mysql"
)
func TestOpen(t *testing.T) {
gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil)
}

7
dialects/sqlite/go.mod Normal file
View File

@ -0,0 +1,7 @@
module github.com/jinzhu/gorm/dialects/mysql
go 1.13
require (
github.com/mattn/go-sqlite3 v2.0.3+incompatible
)

28
dialects/sqlite/sqlite.go Normal file
View File

@ -0,0 +1,28 @@
package sqlite
import (
"github.com/jinzhu/gorm/callbacks"
_ "github.com/mattn/go-sqlite3"
)
type Dialector struct {
}
func Open(dsn string) gorm.Dialector {
return &Dialector{}
}
func (Dialector) Initialize(db *gorm.DB) error {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)
return nil
}
func (Dialector) Migrator() gorm.Migrator {
return nil
}
func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
return "?"
}

View File

@ -0,0 +1,15 @@
package sqlite_test
import (
"os"
"path/filepath"
"testing"
"github.com/jinzhu/gorm"
)
var DB *gorm.DB
func TestOpen(t *testing.T) {
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
}

View File

@ -12,6 +12,7 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
// First find first record that match given conditions, order by primary key // First find first record that match given conditions, order by primary key
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.callbacks.Create().Execute(tx.Limit(1).Order("id"))
return return
} }

33
gorm.go
View File

@ -13,7 +13,7 @@ import (
type Config struct { type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can cancel it by setting `SkipDefaultTransaction` to true // You can cancel it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool SkipDefaultTransaction bool // TODO
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer NamingStrategy schema.Namer
@ -27,6 +27,7 @@ type Config struct {
// Dialector GORM database dialector // Dialector GORM database dialector
type Dialector interface { type Dialector interface {
Initialize(*DB) error
Migrator() Migrator Migrator() Migrator
BindVar(stmt Statement, v interface{}) string BindVar(stmt Statement, v interface{}) string
} }
@ -36,7 +37,8 @@ type DB struct {
*Config *Config
Dialector Dialector
Instance Instance
clone bool clone bool
callbacks *callbacks
} }
// Session session config when create new session // Session session config when create new session
@ -48,15 +50,33 @@ type Session struct {
// Open initialize db session based on dialector // Open initialize db session based on dialector
func Open(dialector Dialector, config *Config) (db *DB, err error) { func Open(dialector Dialector, config *Config) (db *DB, err error) {
if config == nil {
config = &Config{}
}
if config.NamingStrategy == nil { if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{} config.NamingStrategy = schema.NamingStrategy{}
} }
return &DB{ if config.Logger == nil {
config.Logger = logger.Default
}
if config.NowFunc == nil {
config.NowFunc = func() time.Time { return time.Now().Local() }
}
db = &DB{
Config: config, Config: config,
Dialector: dialector, Dialector: dialector,
clone: true, clone: true,
}, nil callbacks: InitializeCallbacks(),
}
if dialector != nil {
err = dialector.Initialize(db)
}
return
} }
// Session create new db session // Session create new db session
@ -112,6 +132,11 @@ func (db *DB) Get(key string) (interface{}, bool) {
return nil, false return nil, false
} }
// Callback returns callback manager
func (db *DB) Callback() *callbacks {
return db.callbacks
}
func (db *DB) getInstance() *DB { func (db *DB) getInstance() *DB {
if db.clone { if db.clone {
ctx := db.Instance.Context ctx := db.Instance.Context

View File

@ -10,85 +10,89 @@ import (
) )
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
equalFieldNames := []string{"Name", "Table"} t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
equalFieldNames := []string{"Name", "Table"}
for _, name := range equalFieldNames {
got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
expects := reflect.ValueOf(v).FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
}
}
for idx, field := range primaryFields {
var found bool
for _, f := range s.PrimaryFields {
if f.Name == field {
found = true
}
}
if idx == 0 {
if field != s.PrioritizedPrimaryField.Name {
t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name)
}
}
if !found {
t.Errorf("schema %v failed to found priamry key: %v", s, field)
}
}
}
func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) {
if fc != nil {
fc(f)
}
if f.TagSettings == nil {
if f.Tag != "" {
f.TagSettings = schema.ParseTagSetting(f.Tag)
} else {
f.TagSettings = map[string]string{}
}
}
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
} else {
equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
for _, name := range equalFieldNames { for _, name := range equalFieldNames {
got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() expects := reflect.ValueOf(v).FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) { if !reflect.DeepEqual(got, expects) {
t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
} }
} }
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { for idx, field := range primaryFields {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
}
for _, name := range []string{f.DBName, f.Name} {
if field := s.LookUpField(name); field == nil || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
}
}
if f.PrimaryKey {
var found bool var found bool
for _, primaryField := range s.PrimaryFields { for _, f := range s.PrimaryFields {
if primaryField == parsedField { if f.Name == field {
found = true found = true
} }
} }
if idx == 0 {
if field != s.PrioritizedPrimaryField.Name {
t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name)
}
}
if !found { if !found {
t.Errorf("schema %v doesn't include field %v", s, f.Name) t.Errorf("schema %v failed to found priamry key: %v", s, field)
} }
} }
} })
}
func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) {
t.Run("CheckField/"+f.Name, func(t *testing.T) {
if fc != nil {
fc(f)
}
if f.TagSettings == nil {
if f.Tag != "" {
f.TagSettings = schema.ParseTagSetting(f.Tag)
} else {
f.TagSettings = map[string]string{}
}
}
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
} else {
equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
for _, name := range equalFieldNames {
got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
}
}
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
}
for _, name := range []string{f.DBName, f.Name} {
if field := s.LookUpField(name); field == nil || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
}
}
if f.PrimaryKey {
var found bool
for _, primaryField := range s.PrimaryFields {
if primaryField == parsedField {
found = true
}
}
if !found {
t.Errorf("schema %v doesn't include field %v", s, f.Name)
}
}
}
})
} }
type Relation struct { type Relation struct {
@ -123,79 +127,81 @@ type Reference struct {
} }
func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
if r, ok := s.Relationships.Relations[relation.Name]; ok { t.Run("CheckRelation/"+relation.Name, func(t *testing.T) {
if r.Name != relation.Name { if r, ok := s.Relationships.Relations[relation.Name]; ok {
t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) if r.Name != relation.Name {
} t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name)
if r.Type != relation.Type {
t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type)
}
if r.Schema.Name != relation.Schema {
t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
}
if r.FieldSchema.Name != relation.FieldSchema {
t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
}
if r.Polymorphic != nil {
if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID {
t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name)
} }
if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { if r.Type != relation.Type {
t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type)
} }
if r.Polymorphic.Value != relation.Polymorphic.Value { if r.Schema.Name != relation.Schema {
t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
}
}
if r.JoinTable != nil {
if r.JoinTable.Name != relation.JoinTable.Name {
t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name)
} }
if r.JoinTable.Table != relation.JoinTable.Table { if r.FieldSchema.Name != relation.FieldSchema {
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
} }
for _, f := range relation.JoinTable.Fields { if r.Polymorphic != nil {
checkSchemaField(t, r.JoinTable, &f, nil) if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID {
} t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name)
} }
if len(relation.References) != len(r.References) { if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type {
t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name)
} }
for _, ref := range relation.References { if r.Polymorphic.Value != relation.Polymorphic.Value {
var found bool t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value)
for _, rf := range r.References {
if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) {
found = true
} }
} }
if !found { if r.JoinTable != nil {
var refs []string if r.JoinTable.Name != relation.JoinTable.Name {
t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name)
}
if r.JoinTable.Table != relation.JoinTable.Table {
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
}
for _, f := range relation.JoinTable.Fields {
checkSchemaField(t, r.JoinTable, &f, nil)
}
}
if len(relation.References) != len(r.References) {
t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References))
}
for _, ref := range relation.References {
var found bool
for _, rf := range r.References { for _, rf := range r.References {
var primaryKey, primaryKeySchema string if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) {
if rf.PrimaryKey != nil { found = true
primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name
} }
refs = append(refs, fmt.Sprintf(
"{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}",
primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey,
))
} }
t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", "))
if !found {
var refs []string
for _, rf := range r.References {
var primaryKey, primaryKeySchema string
if rf.PrimaryKey != nil {
primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name
}
refs = append(refs, fmt.Sprintf(
"{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}",
primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey,
))
}
t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", "))
}
} }
} else {
t.Errorf("schema %v failed to find relations by name %v", s, relation.Name)
} }
} else { })
t.Errorf("schema %v failed to find relations by name %v", s, relation.Name)
}
} }

1
tests/create_test.go Normal file
View File

@ -0,0 +1 @@
package tests