From e8c07b55316b12d028eecac5e9a49f1b16918e44 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Thu, 5 Dec 2019 23:57:15 +0900 Subject: [PATCH] Set nopLogger to DefaultCallback for avoid nil pointer dereference (#2742) --- callback.go | 9 ++------- callbacks_test.go | 30 ++++++++++++++++++++++++++++++ logger.go | 4 ++++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/callback.go b/callback.go index 56b2064a..1f0e3c79 100644 --- a/callback.go +++ b/callback.go @@ -3,7 +3,7 @@ package gorm import "fmt" // DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} +var DefaultCallback = &Callback{logger: nopLogger{}} // Callback is a struct that contains all CRUD callbacks // Field `creates` contains callbacks will be call when creating object @@ -101,12 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * } } - if cp.logger != nil { - // note cp.logger will be nil during the default gorm callback registrations - // as they occur within init() blocks. However, any user-registered callbacks - // will happen after cp.logger exists (as the default logger or user-specified). - cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) - } + cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.parent.processors = append(cp.parent.processors, cp) diff --git a/callbacks_test.go b/callbacks_test.go index c1a1d5e4..bebd0e38 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) { t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) } } + +func TestUseDefaultCallback(t *testing.T) { + createCallbackName := "gorm:test_use_default_callback_for_create" + gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { + // nop + }) + if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { + t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) + } + gorm.DefaultCallback.Create().Remove(createCallbackName) + if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { + t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) + } + + updateCallbackName := "gorm:test_use_default_callback_for_update" + scopeValueName := "gorm:test_use_default_callback_for_update_value" + gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 1) + }) + gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 2) + }) + + scope := DB.NewScope(nil) + callback := gorm.DefaultCallback.Update().Get(updateCallbackName) + callback(scope) + if v, ok := scope.Get(scopeValueName); !ok || v != 2 { + t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) + } +} diff --git a/logger.go b/logger.go index b4a362ce..88e167dd 100644 --- a/logger.go +++ b/logger.go @@ -135,3 +135,7 @@ type Logger struct { func (logger Logger) Print(values ...interface{}) { logger.Println(LogFormatter(values...)...) } + +type nopLogger struct{} + +func (nopLogger) Print(values ...interface{}) {}