diff --git a/callbacks/associations.go b/callbacks/associations.go index a0c296e3..96d9ce22 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -234,7 +234,6 @@ func SaveAfterAssociations(db *gorm.DB) { ref.ForeignKey.Set(joinValue, fv) } } - joins = reflect.Append(joins, joinValue) } @@ -277,7 +276,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.Session(&gorm.Session{}).Create(joins.Interface()) + db.Session(&gorm.Session{}).Debug().Create(joins.Interface()) } } } diff --git a/callbacks/create.go b/callbacks/create.go index 9dc8dc67..ff88bc0e 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -56,25 +56,27 @@ func Create(config *Config) func(db *gorm.DB) { if err == nil { if db.Statement.Schema != nil { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ + if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + db.AddError(err) } - } else { - db.AddError(err) } } db.RowsAffected, _ = result.RowsAffected() diff --git a/tests/main_test.go b/tests/main_test.go deleted file mode 100644 index 3e329454..00000000 --- a/tests/main_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package tests - -import ( - "log" - "math/rand" - "os" - "path/filepath" - "testing" - "time" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/mssql" - "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/gorm/logger" -) - -var DB *gorm.DB - -func TestMain(m *testing.M) { - var err error - DB, err = OpenTestConnection() - if err == nil { - RunMigrations() - m.Run() - } else { - log.Printf("failed to connect database, got error %v\n", err) - os.Exit(1) - } -} - -func RunMigrations() { - var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} - rand.Seed(time.Now().UnixNano()) - rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - - DB.Migrator().DropTable("user_friends", "user_speak") - - if err = DB.Migrator().DropTable(allModels...); err != nil { - log.Printf("Failed to drop table, got error %v\n", err) - os.Exit(1) - } - - if err = DB.AutoMigrate(allModels...); err != nil { - log.Printf("Failed to auto migrate, but got error %v\n", err) - os.Exit(1) - } - - for _, m := range allModels { - if !DB.Migrator().HasTable(m) { - log.Printf("Failed to create table for %#v\n", m) - os.Exit(1) - } - } -} - -func OpenTestConnection() (db *gorm.DB, err error) { - dbDSN := os.Getenv("GORM_DSN") - switch os.Getenv("GORM_DIALECT") { - case "mysql": - log.Println("testing mysql...") - if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" - } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) - case "postgres": - log.Println("testing postgres...") - if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" - } - db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) - case "mssql": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; - // CREATE DATABASE gorm; - // USE gorm; - // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - log.Println("testing mssql...") - if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" - } - db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) - default: - log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) - } - - if debug := os.Getenv("DEBUG"); debug == "true" { - db.Logger.LogMode(logger.Info) - } else if debug == "false" { - db.Logger.LogMode(logger.Error) - } - - return -} diff --git a/tests/tests.go b/tests/tests.go index 809d2e39..1ff700c5 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -1,24 +1,99 @@ package tests import ( - "testing" + "log" + "math/rand" + "os" + "path/filepath" "time" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/mssql" + "github.com/jinzhu/gorm/dialects/mysql" + "github.com/jinzhu/gorm/dialects/postgres" + "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/gorm/logger" ) +var DB *gorm.DB + +func init() { + var err error + if DB, err = OpenTestConnection(); err == nil { + RunMigrations() + } else { + log.Printf("failed to connect database, got error %v\n", err) + os.Exit(1) + } +} + +func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") + switch os.Getenv("GORM_DIALECT") { + case "mysql": + log.Println("testing mysql...") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + case "postgres": + log.Println("testing postgres...") + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" + } + db, err = gorm.Open(postgres.Open(dbDSN), &gorm.Config{}) + case "mssql": + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // CREATE DATABASE gorm; + // USE gorm; + // CREATE USER gorm FROM LOGIN gorm; + // sp_changedbowner 'gorm'; + log.Println("testing mssql...") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open(mssql.Open(dbDSN), &gorm.Config{}) + default: + log.Println("testing sqlite3...") + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + } + + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger.LogMode(logger.Error) + } + + return +} + +func RunMigrations() { + var err error + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + + DB.Migrator().DropTable("user_friends", "user_speak") + + if err = DB.Migrator().DropTable(allModels...); err != nil { + log.Printf("Failed to drop table, got error %v\n", err) + os.Exit(1) + } + + if err = DB.AutoMigrate(allModels...); err != nil { + log.Printf("Failed to auto migrate, but got error %v\n", err) + os.Exit(1) + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + log.Printf("Failed to create table for %#v\n", m) + os.Exit(1) + } + } +} + func Now() *time.Time { now := time.Now() return &now } - -func RunTestsSuit(t *testing.T, db *gorm.DB) { - // TestCreate(t, db) - TestFind(t, db) - TestUpdate(t, db) - TestDelete(t, db) - - TestGroupBy(t, db) - TestJoins(t, db) - TestAssociations(t, db) -}