diff --git a/command.go b/command.go index da1c01c0..b62d50dd 100644 --- a/command.go +++ b/command.go @@ -3840,3 +3840,155 @@ func (cmd *ZSliceWithKeyCmd) readReply(rd *proto.Reader) (err error) { return nil } + +type Function struct { + Name string + Description string + Flags []string +} + +type Library struct { + Name string + Engine string + Functions []Function + Code string +} + +type FunctionListCmd struct { + baseCmd + + val []Library +} + +var _ Cmder = (*FunctionListCmd)(nil) + +func NewFunctionListCmd(ctx context.Context, args ...interface{}) *FunctionListCmd { + return &FunctionListCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *FunctionListCmd) SetVal(val []Library) { + cmd.val = val +} + +func (cmd *FunctionListCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *FunctionListCmd) Val() []Library { + return cmd.val +} + +func (cmd *FunctionListCmd) Result() ([]Library, error) { + return cmd.val, cmd.err +} + +func (cmd *FunctionListCmd) First() (*Library, error) { + if cmd.err != nil { + return nil, cmd.err + } + if len(cmd.val) > 0 { + return &cmd.val[0], nil + } + return nil, Nil +} + +func (cmd *FunctionListCmd) readReply(rd *proto.Reader) (err error) { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + libraries := make([]Library, n) + for i := 0; i < n; i++ { + nn, err := rd.ReadMapLen() + if err != nil { + return err + } + + library := Library{} + for f := 0; f < nn; f++ { + key, err := rd.ReadString() + if err != nil { + return err + } + + switch key { + case "library_name": + library.Name, err = rd.ReadString() + case "engine": + library.Engine, err = rd.ReadString() + case "functions": + library.Functions, err = cmd.readFunctions(rd) + case "library_code": + library.Code, err = rd.ReadString() + default: + return fmt.Errorf("redis: function list unexpected key %s", key) + } + + if err != nil { + return err + } + } + + libraries[i] = library + } + cmd.val = libraries + return nil +} + +func (cmd *FunctionListCmd) readFunctions(rd *proto.Reader) ([]Function, error) { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + + functions := make([]Function, n) + for i := 0; i < n; i++ { + nn, err := rd.ReadMapLen() + if err != nil { + return nil, err + } + + function := Function{} + for f := 0; f < nn; f++ { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + switch key { + case "name": + if function.Name, err = rd.ReadString(); err != nil { + return nil, err + } + case "description": + if function.Description, err = rd.ReadString(); err != nil && err != Nil { + return nil, err + } + case "flags": + // resp set + nx, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + + function.Flags = make([]string, nx) + for j := 0; j < nx; j++ { + if function.Flags[j], err = rd.ReadString(); err != nil { + return nil, err + } + } + default: + return nil, fmt.Errorf("redis: function list unexpected key %s", key) + } + } + + functions[i] = function + } + return functions, nil +} diff --git a/commands.go b/commands.go index db486534..ac93f39c 100644 --- a/commands.go +++ b/commands.go @@ -395,6 +395,15 @@ type Cmdable interface { ScriptKill(ctx context.Context) *StatusCmd ScriptLoad(ctx context.Context, script string) *StringCmd + FunctionLoad(ctx context.Context, code string) *StringCmd + FunctionLoadReplace(ctx context.Context, code string) *StringCmd + FunctionDelete(ctx context.Context, libName string) *StringCmd + FunctionFlush(ctx context.Context) *StringCmd + FunctionFlushAsync(ctx context.Context) *StringCmd + FunctionList(ctx context.Context, q FunctionListQuery) *FunctionListCmd + FunctionDump(ctx context.Context) *StringCmd + FunctionRestore(ctx context.Context, libDump string) *StringCmd + Publish(ctx context.Context, channel string, message interface{}) *IntCmd SPublish(ctx context.Context, channel string, message interface{}) *IntCmd PubSubChannels(ctx context.Context, pattern string) *StringSliceCmd @@ -3270,6 +3279,76 @@ func (c cmdable) ScriptLoad(ctx context.Context, script string) *StringCmd { return cmd } +// ------------------------------------------------------------------------------ + +// FunctionListQuery is used with FunctionList to query for Redis libraries +// +// LibraryNamePattern - Use an empty string to get all libraries. +// - Use a glob-style pattern to match multiple libraries with a matching name +// - Use a library's full name to match a single library +// WithCode - If true, it will return the code of the library +type FunctionListQuery struct { + LibraryNamePattern string + WithCode bool +} + +func (c cmdable) FunctionLoad(ctx context.Context, code string) *StringCmd { + cmd := NewStringCmd(ctx, "function", "load", code) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) FunctionLoadReplace(ctx context.Context, code string) *StringCmd { + cmd := NewStringCmd(ctx, "function", "load", "replace", code) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) FunctionDelete(ctx context.Context, libName string) *StringCmd { + cmd := NewStringCmd(ctx, "function", "delete", libName) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) FunctionFlush(ctx context.Context) *StringCmd { + cmd := NewStringCmd(ctx, "function", "flush") + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) FunctionFlushAsync(ctx context.Context) *StringCmd { + cmd := NewStringCmd(ctx, "function", "flush", "async") + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) FunctionList(ctx context.Context, q FunctionListQuery) *FunctionListCmd { + args := make([]interface{}, 2, 5) + args[0] = "function" + args[1] = "list" + if q.LibraryNamePattern != "" { + args = append(args, "libraryname", q.LibraryNamePattern) + } + if q.WithCode { + args = append(args, "withcode") + } + cmd := NewFunctionListCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) FunctionDump(ctx context.Context) *StringCmd { + cmd := NewStringCmd(ctx, "function", "dump") + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) FunctionRestore(ctx context.Context, libDump string) *StringCmd { + cmd := NewStringCmd(ctx, "function", "restore", libDump) + _ = c(ctx, cmd) + return cmd +} + //------------------------------------------------------------------------------ // Publish posts the message to the channel. diff --git a/commands_test.go b/commands_test.go index 52fb5769..f806dc8f 100644 --- a/commands_test.go +++ b/commands_test.go @@ -6086,6 +6086,227 @@ var _ = Describe("Commands", func() { }) }) + Describe("Functions", func() { + var ( + q redis.FunctionListQuery + lib1Code string + lib2Code string + lib1 redis.Library + lib2 redis.Library + ) + + BeforeEach(func() { + flush := client.FunctionFlush(ctx) + Expect(flush.Err()).NotTo(HaveOccurred()) + + lib1 = redis.Library{ + Name: "mylib1", + Engine: "LUA", + Functions: []redis.Function{ + { + Name: "lib1_func1", + Description: "This is the func-1 of lib 1", + Flags: []string{"no-writes", "allow-stale"}, + }, + }, + Code: `#!lua name=%s + + local function f1(keys, args) + return 'Function 1' + end + + redis.register_function{ + function_name='%s', + description ='%s', + callback=f1, + flags={'%s', '%s'} + }`, + } + + lib2 = redis.Library{ + Name: "mylib2", + Engine: "LUA", + Functions: []redis.Function{ + { + Name: "lib2_func1", + Flags: []string{}, + }, + { + Name: "lib2_func2", + Description: "This is the func-2 of lib 2", + Flags: []string{"no-writes"}, + }, + }, + Code: `#!lua name=%s + + local function f1(keys, args) + return 'Function 1' + end + + local function f2(keys, args) + return 'Function 2' + end + + redis.register_function('%s', f1) + redis.register_function{ + function_name='%s', + description ='%s', + callback=f2, + flags={'%s'} + }`, + } + + lib1Code = fmt.Sprintf(lib1.Code, lib1.Name, lib1.Functions[0].Name, + lib1.Functions[0].Description, lib1.Functions[0].Flags[0], lib1.Functions[0].Flags[1]) + lib2Code = fmt.Sprintf(lib2.Code, lib2.Name, lib2.Functions[0].Name, + lib2.Functions[1].Name, lib2.Functions[1].Description, lib2.Functions[1].Flags[0]) + + q = redis.FunctionListQuery{} + }) + + It("Loads a new library", func() { + functionLoad := client.FunctionLoad(ctx, lib1Code) + Expect(functionLoad.Err()).NotTo(HaveOccurred()) + Expect(functionLoad.Val()).To(Equal(lib1.Name)) + + functionList := client.FunctionList(ctx, q) + Expect(functionList.Err()).NotTo(HaveOccurred()) + Expect(functionList.Val()).To(HaveLen(1)) + }) + + It("Loads and replaces a new library", func() { + // Load a library for the first time + err := client.FunctionLoad(ctx, lib1Code).Err() + Expect(err).NotTo(HaveOccurred()) + + newFuncName := "replaces_func_name" + newFuncDesc := "replaces_func_desc" + flag1, flag2 := "allow-stale", "no-cluster" + newCode := fmt.Sprintf(lib1.Code, lib1.Name, newFuncName, newFuncDesc, flag1, flag2) + + // And then replace it + functionLoadReplace := client.FunctionLoadReplace(ctx, newCode) + Expect(functionLoadReplace.Err()).NotTo(HaveOccurred()) + Expect(functionLoadReplace.Val()).To(Equal(lib1.Name)) + + lib, err := client.FunctionList(ctx, q).First() + Expect(err).NotTo(HaveOccurred()) + Expect(lib.Functions).To(Equal([]redis.Function{ + { + Name: newFuncName, + Description: newFuncDesc, + Flags: []string{flag1, flag2}, + }, + })) + }) + + It("Deletes a library", func() { + err := client.FunctionLoad(ctx, lib1Code).Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.FunctionDelete(ctx, lib1.Name).Err() + Expect(err).NotTo(HaveOccurred()) + + val, err := client.FunctionList(ctx, redis.FunctionListQuery{ + LibraryNamePattern: lib1.Name, + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(HaveLen(0)) + }) + + It("Flushes all libraries", func() { + err := client.FunctionLoad(ctx, lib1Code).Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.FunctionLoad(ctx, lib2Code).Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.FunctionFlush(ctx).Err() + Expect(err).NotTo(HaveOccurred()) + + val, err := client.FunctionList(ctx, q).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(HaveLen(0)) + }) + + It("Flushes all libraries asynchronously", func() { + functionLoad := client.FunctionLoad(ctx, lib1Code) + Expect(functionLoad.Err()).NotTo(HaveOccurred()) + + // we only verify the command result. + functionFlush := client.FunctionFlushAsync(ctx) + Expect(functionFlush.Err()).NotTo(HaveOccurred()) + }) + + It("Lists registered functions", func() { + err := client.FunctionLoad(ctx, lib1Code).Err() + Expect(err).NotTo(HaveOccurred()) + + val, err := client.FunctionList(ctx, redis.FunctionListQuery{ + LibraryNamePattern: "*", + WithCode: true, + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(HaveLen(1)) + Expect(val[0].Name).To(Equal(lib1.Name)) + Expect(val[0].Engine).To(Equal(lib1.Engine)) + Expect(val[0].Code).To(Equal(lib1Code)) + Expect(val[0].Functions).Should(ConsistOf(lib1.Functions)) + + err = client.FunctionLoad(ctx, lib2Code).Err() + Expect(err).NotTo(HaveOccurred()) + + val, err = client.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(HaveLen(2)) + + lib, err := client.FunctionList(ctx, redis.FunctionListQuery{ + LibraryNamePattern: lib2.Name, + WithCode: false, + }).First() + Expect(err).NotTo(HaveOccurred()) + Expect(lib.Name).To(Equal(lib2.Name)) + Expect(lib.Code).To(Equal("")) + + _, err = client.FunctionList(ctx, redis.FunctionListQuery{ + LibraryNamePattern: "non_lib", + WithCode: true, + }).First() + Expect(err).To(Equal(redis.Nil)) + }) + + It("Dump and restores all libraries", func() { + err := client.FunctionLoad(ctx, lib1Code).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.FunctionLoad(ctx, lib2Code).Err() + Expect(err).NotTo(HaveOccurred()) + + dump, err := client.FunctionDump(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(dump).NotTo(BeEmpty()) + + err = client.FunctionRestore(ctx, dump).Err() + Expect(err).To(HaveOccurred()) + + err = client.FunctionFlush(ctx).Err() + Expect(err).NotTo(HaveOccurred()) + + list, err := client.FunctionList(ctx, q).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(list).To(HaveLen(0)) + + err = client.FunctionRestore(ctx, dump).Err() + Expect(err).NotTo(HaveOccurred()) + + list, err = client.FunctionList(ctx, q).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(list).To(HaveLen(2)) + }) + + }) + Describe("SlowLogGet", func() { It("returns slow query result", func() { const key = "slowlog-log-slower-than"