diff --git a/scope.go b/scope.go index b0c2432c..09da8e11 100644 --- a/scope.go +++ b/scope.go @@ -182,50 +182,47 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { return errors.New("could not convert column to field") } -func (scope *Scope) CallMethod(name string) { +func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { + if reflectValue.CanAddr() { + reflectValue = reflectValue.Addr() + } + + if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { + switch method := methodValue.Interface().(type) { + case func(): + method() + case func(*Scope): + method(scope) + case func(*DB): + newDB := scope.NewDB() + method(newDB) + scope.Err(newDB.Error) + case func() error: + scope.Err(method()) + case func(*Scope) error: + scope.Err(method(scope)) + case func(*DB) error: + newDB := scope.NewDB() + scope.Err(method(newDB)) + scope.Err(newDB.Error) + default: + scope.Err(fmt.Errorf("unsupported function %v", methodName)) + } + } +} + +// CallMethod call scope value's method, if it is a slice, will call value's method one by one +func (scope *Scope) CallMethod(methodName string) { if scope.Value == nil { return } - call := func(value interface{}) { - if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() { - switch f := fm.Interface().(type) { - case func(): - f() - case func(s *Scope): - f(scope) - case func(s *DB): - newDB := scope.NewDB() - f(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(f()) - case func(s *Scope) error: - scope.Err(f(scope)) - case func(s *DB) error: - newDB := scope.NewDB() - scope.Err(f(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", name)) - } - } - } - - if values := scope.IndirectValue(); values.Kind() == reflect.Slice { - for i := 0; i < values.Len(); i++ { - value := values.Index(i).Addr().Interface() - if values.Index(i).Kind() == reflect.Ptr { - value = values.Index(i).Interface() - } - call(value) + if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { + for i := 0; i < indirectScopeValue.Len(); i++ { + scope.callMethod(methodName, indirectScopeValue.Index(i)) } } else { - if scope.IndirectValue().CanAddr() { - call(scope.IndirectValue().Addr().Interface()) - } else { - call(scope.IndirectValue().Interface()) - } + scope.callMethod(methodName, indirectScopeValue) } }