diff --git a/callback.go b/callback.go index d2f0cf04..a17fb2aa 100644 --- a/callback.go +++ b/callback.go @@ -5,7 +5,7 @@ import ( ) // defaultCallback hold default callbacks defined by gorm -var defaultCallback = &Callback{} +var DefaultCallback = &Callback{} // Callback contains callbacks that used when CURD objects // Field `creates` hold callbacks will be call when creating object diff --git a/callback_create.go b/callback_create.go index 6316f9ee..314f505a 100644 --- a/callback_create.go +++ b/callback_create.go @@ -7,15 +7,15 @@ import ( // Define callbacks for creating func init() { - defaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - defaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - defaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - defaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - defaultCallback.Create().Register("gorm:create", createCallback) - defaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - defaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - defaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) + DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) + DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) + DefaultCallback.Create().Register("gorm:create", createCallback) + DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) + DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) + DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating diff --git a/callback_delete.go b/callback_delete.go index 9db0666c..c8ffcc82 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -4,11 +4,11 @@ import "fmt" // Define callbacks for deleting func init() { - defaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - defaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - defaultCallback.Delete().Register("gorm:delete", deleteCallback) - defaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) + DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) + DefaultCallback.Delete().Register("gorm:delete", deleteCallback) + DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) + DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } // beforeDeleteCallback will invoke `BeforeDelete` method before deleting diff --git a/callback_query.go b/callback_query.go index 11f8b476..aa643557 100644 --- a/callback_query.go +++ b/callback_query.go @@ -8,9 +8,9 @@ import ( // Define callbacks for querying func init() { - defaultCallback.Query().Register("gorm:query", queryCallback) - defaultCallback.Query().Register("gorm:preload", preloadCallback) - defaultCallback.Query().Register("gorm:after_query", afterQueryCallback) + DefaultCallback.Query().Register("gorm:query", queryCallback) + DefaultCallback.Query().Register("gorm:preload", preloadCallback) + DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) } // queryCallback used to query data from database diff --git a/callback_update.go b/callback_update.go index b71a47b4..287b927f 100644 --- a/callback_update.go +++ b/callback_update.go @@ -7,15 +7,15 @@ import ( // Define callbacks for updating func init() { - defaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - defaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - defaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - defaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - defaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - defaultCallback.Update().Register("gorm:update", updateCallback) - defaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - defaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) + DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) + DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) + DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) + DefaultCallback.Update().Register("gorm:update", updateCallback) + DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) + DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } // assignUpdatingAttributesCallback assign updating attributes to model diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 26bb38eb..03ad60fb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,3 +1,17 @@ package mssql -import _ "github.com/denisenkom/go-mssqldb" +import ( + "fmt" + + _ "github.com/denisenkom/go-mssqldb" + "github.com/jinzhu/gorm" +) + +func setIdentityInsert(scope *gorm.Scope) { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) +} + +func init() { + gorm.DefaultCallback.Update().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) + gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) +} diff --git a/main.go b/main.go index cfa71b60..46f35d01 100644 --- a/main.go +++ b/main.go @@ -64,7 +64,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) { db = DB{ dialect: newDialect(dialect, dbSql.(*sql.DB)), logger: defaultLogger, - callbacks: defaultCallback, + callbacks: DefaultCallback, source: source, values: map[string]interface{}{}, db: dbSql,