diff --git a/callback.go b/callback.go index 6d0ff0f3..309078e4 100644 --- a/callback.go +++ b/callback.go @@ -148,75 +148,66 @@ func getRIndex(strs []string, str string) int { // sortProcessors sort callback processors based on its before, after, remove, replace func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var sortCallbackProcessor func(c *CallbackProcessor) - var names, sortedNames = []string{}, []string{} + var ( + allNames, sortedNames []string + sortCallbackProcessor func(c *CallbackProcessor) + ) for _, cp := range cps { - if index := getRIndex(names, cp.name); index > -1 { - if !cp.replace && !cp.remove { - fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) - } + // show warning message the callback name already exists + if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { + fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) } - names = append(names, cp.name) + allNames = append(allNames, cp.name) } sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) > -1 { - return - } - - if len(c.before) > 0 { - if index := getRIndex(sortedNames, c.before); index > -1 { - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(names, c.before); index > -1 { - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } else { - sortedNames = append(sortedNames, c.name) - } - } - - if len(c.after) > 0 { - if index := getRIndex(sortedNames, c.after); index > -1 { - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(names, c.after); index > -1 { - cp := cps[index] - if len(cp.before) == 0 { - cp.before = c.name + if getRIndex(sortedNames, c.name) == -1 { // if not sorted + if c.before != "" { // if defined before callback + if index := getRIndex(sortedNames, c.before); index != -1 { + // if before callback already sorted, append current callback just after it + sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) + } else if index := getRIndex(allNames, c.before); index != -1 { + // if before callback exists but haven't sorted, append current callback to last + sortedNames = append(sortedNames, c.name) + sortCallbackProcessor(cps[index]) } - sortCallbackProcessor(cp) - } else { + } + + if c.after != "" { // if defined after callback + if index := getRIndex(sortedNames, c.after); index != -1 { + // if after callback already sorted, append current callback just before it + sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) + } else if index := getRIndex(allNames, c.after); index != -1 { + // if after callback exists but haven't sorted + cp := cps[index] + // set after callback's before callback to current callback + if cp.before == "" { + cp.before = c.name + } + sortCallbackProcessor(cp) + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sortedNames, c.name) == -1 { sortedNames = append(sortedNames, c.name) } } - - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } } for _, cp := range cps { sortCallbackProcessor(cp) } - var funcs = []*func(scope *Scope){} - var sortedFuncs = []*func(scope *Scope){} + var sortedFuncs []*func(scope *Scope) for _, name := range sortedNames { - index := getRIndex(names, name) - if !cps[index].remove { + if index := getRIndex(allNames, name); !cps[index].remove { sortedFuncs = append(sortedFuncs, cps[index].processor) } } - for _, cp := range cps { - if sindex := getRIndex(sortedNames, cp.name); sindex == -1 { - if !cp.remove { - funcs = append(funcs, cp.processor) - } - } - } - - return append(sortedFuncs, funcs...) + return sortedFuncs } // reorder all registered processors, and reset CURD callbacks diff --git a/callback_delete.go b/callback_delete.go index 17b5cfb4..b3a77926 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -2,6 +2,7 @@ package gorm import "fmt" +// Define callbacks for deleting func init() { defaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) defaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) @@ -10,12 +11,14 @@ func init() { defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } +// beforeDeleteCallback will invoke `BeforeDelete` method before deleting func beforeDeleteCallback(scope *Scope) { if !scope.HasError() { scope.CallMethod("BeforeDelete") } } +// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) func deleteCallback(scope *Scope) { if !scope.HasError() { if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { @@ -24,15 +27,14 @@ func deleteCallback(scope *Scope) { scope.QuotedTableName(), scope.AddToVars(NowFunc()), scope.CombinedConditionSql(), - )) + )).Exec() } else { - scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql())).Exec() } - - scope.Exec() } } +// afterDeleteCallback will invoke `AfterDelete` method after deleting func afterDeleteCallback(scope *Scope) { if !scope.HasError() { scope.CallMethod("AfterDelete") diff --git a/callback_query.go b/callback_query.go index f6fa0aa1..5d9fd82d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -6,40 +6,42 @@ import ( "reflect" ) +// Define callbacks for querying func init() { defaultCallback.Query().Register("gorm:query", queryCallback) defaultCallback.Query().Register("gorm:after_query", afterQueryCallback) defaultCallback.Query().Register("gorm:preload", preloadCallback) } +// queryCallback used to query data from database func queryCallback(scope *Scope) { defer scope.trace(NowFunc()) var ( - isSlice bool - isPtr bool - destType reflect.Type + isSlice bool + isPtr bool + results = scope.IndirectValue() + resultType reflect.Type ) if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryKey := scope.PrimaryKey(); primaryKey != "" { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy)) + if primaryField := scope.PrimaryField(); primaryField != nil { + scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) } } - var dest = scope.IndirectValue() if value, ok := scope.Get("gorm:query_destination"); ok { - dest = reflect.Indirect(reflect.ValueOf(value)) + results = reflect.Indirect(reflect.ValueOf(value)) } - if kind := dest.Kind(); kind == reflect.Slice { + if kind := results.Kind(); kind == reflect.Slice { isSlice = true - destType = dest.Type().Elem() - dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) + resultType = results.Type().Elem() + results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - if destType.Kind() == reflect.Ptr { + if resultType.Kind() == reflect.Ptr { isPtr = true - destType = destType.Elem() + resultType = resultType.Elem() } } else if kind != reflect.Struct { scope.Err(errors.New("unsupported destination, should be slice or struct")) @@ -49,41 +51,38 @@ func queryCallback(scope *Scope) { scope.prepareQuerySql() if !scope.HasError() { - rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...) scope.db.RowsAffected = 0 + if rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + defer rows.Close() - if scope.Err(err) != nil { - return - } - defer rows.Close() + columns, _ := rows.Columns() + for rows.Next() { + scope.db.RowsAffected++ - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ + elem := results + if isSlice { + elem = reflect.New(resultType).Elem() + } - elem := dest - if isSlice { - elem = reflect.New(destType).Elem() - } + scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - fields := scope.New(elem.Addr().Interface()).Fields() - scope.scan(rows, columns, fields) - - if isSlice { - if isPtr { - dest.Set(reflect.Append(dest, elem.Addr())) - } else { - dest.Set(reflect.Append(dest, elem)) + if isSlice { + if isPtr { + results.Set(reflect.Append(results, elem.Addr())) + } else { + results.Set(reflect.Append(results, elem)) + } } } - } - if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(RecordNotFound) + if scope.db.RowsAffected == 0 && !isSlice { + scope.Err(RecordNotFound) + } } } } +// afterQueryCallback will invoke `AfterFind` method after querying func afterQueryCallback(scope *Scope) { if !scope.HasError() { scope.CallMethod("AfterFind") diff --git a/callback_query_preload.go b/callback_query_preload.go index 5dc91de9..ff99fea9 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -7,6 +7,7 @@ import ( "strings" ) +// preloadCallback used to preload associations func preloadCallback(scope *Scope) { if scope.Search.preload == nil || scope.HasError() { return @@ -72,6 +73,7 @@ func preloadCallback(scope *Scope) { } } +// handleHasOnePreload used to preload has one associations func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { relation := field.Relationship @@ -107,6 +109,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } } +// handleHasManyPreload used to preload has many associations func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { relation := field.Relationship @@ -144,6 +147,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) } } +// handleBelongsToPreload used to preload belongs to associations func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { relation := field.Relationship @@ -179,6 +183,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } +// handleManyToManyPreload used to preload many to many associations func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { var ( relation = field.Relationship