diff --git a/command.go b/command.go index c6cd9db6..f24255f3 100644 --- a/command.go +++ b/command.go @@ -3,7 +3,9 @@ package redis import ( "bufio" "context" + "errors" "fmt" + "math/big" "net" "regexp" "strconv" @@ -5484,3 +5486,283 @@ func (cmd *MonitorCmd) Stop() { defer cmd.mu.Unlock() cmd.status = monitorStatusStop } + +// -------------------------------------------------------------------------------------- + +type GraphCmd struct { + baseCmd + + val *GraphResult +} + +var _ Cmder = (*GraphCmd)(nil) + +func NewGraphCmd(ctx context.Context, args ...any) *GraphCmd { + return &GraphCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *GraphCmd) SetVal(val *GraphResult) { + cmd.val = val +} + +func (cmd *GraphCmd) Val() *GraphResult { + return cmd.val +} + +func (cmd *GraphCmd) Result() (*GraphResult, error) { + return cmd.val, cmd.err +} + +func (cmd *GraphCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *GraphCmd) readReply(rd *proto.Reader) error { + cmd.val = &GraphResult{} + cmd.val.err = cmd.readGraph(rd) + return cmd.val.err +} + +func (cmd *GraphCmd) readGraph(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + if n != 1 && n != 3 { + return fmt.Errorf("redis: invalid number of elements in graph result: %d", n) + } + + if n == 1 { + // create? + cmd.val.noResult = true + if cmd.val.text, err = cmd.readStringArray(rd); err != nil { + return err + } + return nil + } + + // n = 3, read graph result + if cmd.val.field, err = cmd.readStringArray(rd); err != nil { + return err + } + fieldLen := len(cmd.val.field) + if fieldLen == 0 { + return nil + } + + // read response + rows, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val.rows = make([][]*graphRow, rows) + for i := 0; i < rows; i++ { + // field == row + if err = rd.ReadFixedArrayLen(fieldLen); err != nil { + return err + } + + cmd.val.rows[i] = make([]*graphRow, fieldLen) + for f := 0; f < fieldLen; f++ { + next, err := rd.PeekReplyType() + if err != nil { + return err + } + var nn int + switch next { + case proto.RespArray, proto.RespSet, proto.RespPush: + nn, err = rd.ReadArrayLen() + if err != nil { + return err + } + default: + nn = 1 + } + + // 1: int/string/nil + // 3: node, id + labels + properties + // 5: edge, id + type + src_node + dest_node + properties + switch nn { + case 1: + data, err := cmd.readData(rd) + if err != nil { + return err + } + cmd.val.rows[i][f] = &graphRow{ + typ: graphResultBasic, + basic: data, + } + case 3: + node, err := cmd.readNode(rd) + if err != nil { + return err + } + cmd.val.rows[i][f] = &graphRow{ + typ: graphResultNode, + node: node, + } + case 5: + edge, err := cmd.readEdge(rd) + if err != nil { + return err + } + cmd.val.rows[i][f] = &graphRow{ + typ: graphResultEdge, + edge: edge, + } + default: + return fmt.Errorf("redis: graph-row-field-len, got %d elements in the array, wanted %v", nn, "1/3/5") + } + } + } + if cmd.val.text, err = cmd.readStringArray(rd); err != nil { + return err + } + + return err +} + +// node = 3, id +labels +properties +func (cmd *GraphCmd) readNode(rd *proto.Reader) (GraphNode, error) { + node := GraphNode{} + for j := 0; j < 3; j++ { + if err := rd.ReadFixedArrayLen(2); err != nil { + return node, err + } + key, err := rd.ReadString() + if err != nil { + return node, err + } + switch key { + case "id": + if node.ID, err = rd.ReadInt(); err != nil { + return node, err + } + case "labels": + if node.Labels, err = cmd.readStringArray(rd); err != nil { + return node, err + } + case "properties": + if node.Properties, err = cmd.readProperties(rd); err != nil { + return node, err + } + default: + return node, fmt.Errorf("redis: invalid graph node key - %s", key) + } + } + return node, nil +} + +// edge = 5, id + type + src_node + dest_node + properties +func (cmd *GraphCmd) readEdge(rd *proto.Reader) (GraphEdge, error) { + edge := GraphEdge{} + for j := 0; j < 5; j++ { + if err := rd.ReadFixedArrayLen(2); err != nil { + return edge, err + } + key, err := rd.ReadString() + if err != nil { + return edge, err + } + switch key { + case "id": + if edge.ID, err = rd.ReadInt(); err != nil { + return edge, err + } + case "type": + if edge.Typ, err = rd.ReadString(); err != nil { + return edge, err + } + case "src_node": + if edge.SrcNode, err = rd.ReadInt(); err != nil { + return edge, err + } + case "dest_node": + if edge.DstNode, err = rd.ReadInt(); err != nil { + return edge, err + } + case "properties": + if edge.Properties, err = cmd.readProperties(rd); err != nil { + return edge, err + } + default: + return edge, fmt.Errorf("redis: invalid graph edge key - %s", key) + } + } + return edge, nil +} + +func (cmd *GraphCmd) readProperties(rd *proto.Reader) (map[string]GraphData, error) { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + m := make(map[string]GraphData, n) + for i := 0; i < n; i++ { + if err = rd.ReadFixedArrayLen(2); err != nil { + return nil, err + } + key, err := rd.ReadString() + if err != nil { + return nil, err + } + val, err := cmd.readData(rd) + if err != nil { + return nil, err + } + m[key] = val + } + return m, nil +} + +func (cmd *GraphCmd) readData(rd *proto.Reader) (GraphData, error) { + var data GraphData + reply, err := rd.ReadReply() + if err != nil { + if errors.Is(err, Nil) { + data.typ = graphNil + return data, nil + } + return data, err + } + + switch v := reply.(type) { + case string: + data.typ = graphString + data.stringVal = v + case int64: + data.typ = graphInteger + data.integerVal = v + case *big.Int: + data.typ = graphInteger + if !v.IsInt64() { + return data, fmt.Errorf("redis: bigInt(%s) value out of range", v.String()) + } + data.integerVal = v.Int64() + default: + return data, fmt.Errorf("redis: invalid reply - %q", reply) + } + return data, nil +} + +func (cmd *GraphCmd) readStringArray(rd *proto.Reader) ([]string, error) { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + if n == 0 { + return nil, nil + } + ss := make([]string, n) + for i := 0; i < n; i++ { + if ss[i], err = rd.ReadString(); err != nil { + return ss, err + } + } + return ss, nil +} diff --git a/commands.go b/commands.go index db595944..9988a0bf 100644 --- a/commands.go +++ b/commands.go @@ -226,6 +226,7 @@ type Cmdable interface { StreamCmdable TimeseriesCmdable JSONCmdable + GraphCmdable } type StatefulCmdable interface { diff --git a/graph.go b/graph.go new file mode 100644 index 00000000..92ff1fab --- /dev/null +++ b/graph.go @@ -0,0 +1,204 @@ +package redis + +import ( + "context" + "strconv" +) + +type GraphCmdable interface { + GraphQuery(ctx context.Context, key, query string) *GraphCmd +} + +// GraphQuery executes a query on a graph, exec: GRAPH.QUERY key query +func (c cmdable) GraphQuery(ctx context.Context, key, query string) *GraphCmd { + cmd := NewGraphCmd(ctx, "GRAPH.QUERY", key, query) + _ = c(ctx, cmd) + return cmd +} + +// ---------------------------------------------------------------------------- + +type GraphValue interface { + IsNil() bool + String() string + Int() int + Bool() bool + Float64() float64 + Node() (*GraphNode, bool) + Edge() (*GraphEdge, bool) +} + +type ( + graphDataType int + graphRowType int +) + +const ( + graphInteger graphDataType = iota + 1 // int (graph int) + graphNil // nil (graph nil/null) + graphString // string (graph string/boolean/double) +) + +const ( + graphResultBasic graphRowType = iota + 1 // int/nil/string + graphResultNode // node(3), id +labels +properties + graphResultEdge // edge(5), id + type + src_node + dest_node + properties +) + +type GraphResult struct { + noResult bool + text []string + field []string + rows [][]*graphRow + err error +} + +// Message is used to obtain statistical information about the execution of graph statements. +func (g *GraphResult) Message() []string { + return g.text +} + +// IsResult indicates whether the statement has a result set response. +// For example, a MATCH statement always has a result set (regardless of the number of rows in the response), +// while a CREATE statement does not. +func (g *GraphResult) IsResult() bool { + return !g.noResult +} + +// Error return exec error +func (g *GraphResult) Error() error { + return g.err +} + +// Field returns the fields in the query result set. +// If the statement has no result set (IsResult() == false), it returns nil. +func (g *GraphResult) Field() []string { + return g.field +} + +// Len return the number of rows in the result set. +func (g *GraphResult) Len() int { + return len(g.rows) +} + +// Row read the first row of data in the result set. If there is no data response, return redis.Nil. +func (g *GraphResult) Row() (map[string]GraphValue, error) { + if g.err != nil { + return nil, g.err + } + if g.noResult || len(g.field) == 0 || len(g.rows) == 0 || len(g.rows[0]) == 0 { + return nil, Nil + } + row := make(map[string]GraphValue, len(g.field)) + for i, f := range g.field { + row[f] = g.rows[0][i] + } + return row, nil +} + +// Rows return all result set responses, If there is no data response, return redis.Nil. +func (g *GraphResult) Rows() ([]map[string]GraphValue, error) { + if g.err != nil { + return nil, g.err + } + if g.noResult || len(g.field) == 0 || len(g.rows) == 0 || len(g.rows[0]) == 0 { + return nil, Nil + } + rows := make([]map[string]GraphValue, 0, len(g.rows)) + for i := 0; i < len(g.rows); i++ { + row := make(map[string]GraphValue, len(g.field)) + for x, f := range g.field { + row[f] = g.rows[i][x] + } + rows = append(rows, row) + } + return rows, nil +} + +type graphRow struct { + typ graphRowType + basic GraphData + node GraphNode + edge GraphEdge +} + +func (g *graphRow) IsNil() bool { return g.basic.IsNil() } +func (g *graphRow) String() string { return g.basic.String() } +func (g *graphRow) Int() int { return g.basic.Int() } +func (g *graphRow) Bool() bool { return g.basic.Bool() } +func (g *graphRow) Float64() float64 { return g.basic.Float64() } +func (g *graphRow) Node() (*GraphNode, bool) { return &g.node, g.typ == graphResultNode } +func (g *graphRow) Edge() (*GraphEdge, bool) { return &g.edge, g.typ == graphResultEdge } + +type GraphData struct { + typ graphDataType + integerVal int64 + stringVal string +} + +func (d GraphData) IsNil() bool { + return d.typ == graphNil +} + +func (d GraphData) String() string { + switch d.typ { + case graphInteger: + return strconv.FormatInt(d.integerVal, 10) + case graphNil: + return "" + case graphString: + return d.stringVal + default: + return "" + } +} + +func (d GraphData) Int() int { + switch d.typ { + case graphInteger: + return int(d.integerVal) + case graphString: + n, _ := strconv.Atoi(d.stringVal) + return n + default: + return 0 + } +} + +func (d GraphData) Bool() bool { + switch d.typ { + case graphInteger: + return d.integerVal != 0 + case graphNil: + return false + case graphString: + return d.stringVal == "true" + default: + return false + } +} + +func (d GraphData) Float64() float64 { + if d.typ == graphInteger { + return float64(d.integerVal) + } + if d.typ == graphString { + v, _ := strconv.ParseFloat(d.stringVal, 64) + return v + } + return 0 +} + +type GraphNode struct { + ID int64 + Labels []string + Properties map[string]GraphData +} + +type GraphEdge struct { + ID int64 + Typ string + SrcNode int64 + DstNode int64 + Properties map[string]GraphData +} diff --git a/graph_test.go b/graph_test.go new file mode 100644 index 00000000..c7702203 --- /dev/null +++ b/graph_test.go @@ -0,0 +1,134 @@ +package redis_test + +import ( + . "github.com/bsm/ginkgo/v2" + . "github.com/bsm/gomega" + + "github.com/redis/go-redis/v9" +) + +var _ = Describe("Client", func() { + var client *redis.Client + var graph = "test-graph" + + BeforeEach(func() { + client = redis.NewClient(redisOptions()) + Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("graph no-result", func() { + res, err := client.GraphQuery(ctx, graph, "CREATE ()").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res.IsResult()).To(BeFalse()) + }) + + It("graph query result-basic", func() { + query := "CREATE (:per {id: 1024, name: 'foo', pr: 3.14, success: true})" + _, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + + query = "MATCH (p:per {id: 1024}) RETURN p.id as id, p.name as name, p.pr as pr, p.success as success, p.non as non" + res, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res.IsResult()).To(BeTrue()) + Expect(res.Len()).To(Equal(1)) + + row, err := res.Row() + Expect(err).NotTo(HaveOccurred()) + Expect(row).To(HaveLen(5)) + Expect(row["id"].Int()).To(Equal(1024)) + Expect(row["name"].Int()).To(Equal("foo")) + Expect(row["pr"].Float64()).To(Equal(3.14)) + Expect(row["success"].Bool()).To(BeTrue()) + Expect(row["non"].IsNil()).To(BeTrue()) + }) + + It("graph query result-node", func() { + query := "CREATE (:per {id: 1024, name: 'foo', pr: 3.14, success: true})" + _, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + + query = "MATCH (p:per {id: 1024}) RETURN p" + res, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res.IsResult()).To(BeTrue()) + Expect(res.Len()).To(Equal(1)) + + row, err := res.Row() + Expect(err).NotTo(HaveOccurred()) + Expect(row).To(HaveLen(1)) + + node, ok := row["p"].Node() + Expect(ok).To(BeTrue()) + + Expect(node.Labels).To(Equal([]string{"per"})) + Expect(node.Properties["id"].Int()).To(Equal(1024)) + Expect(node.Properties["name"].String()).To(Equal("foo")) + Expect(node.Properties["pr"].Float64()).To(Equal(3.14)) + Expect(node.Properties["success"].Bool()).To(BeTrue()) + }) + + It("graph query result-edge", func() { + query := "CREATE (:per {id: 1024}) - [:FRIENDS {ts: 100, msg: 'txt-msg'}] -> (:per {id: 2048})" + _, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + + query = "MATCH (:per {id: 1024}) - [r:FRIENDS] -> (:per {id: 2048}) RETURN r" + res, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res.IsResult()).To(BeTrue()) + Expect(res.Len()).To(Equal(1)) + + row, err := res.Row() + Expect(err).NotTo(HaveOccurred()) + Expect(row).To(HaveLen(1)) + + edge, ok := row["r"].Edge() + Expect(ok).To(BeTrue()) + + Expect(edge.Typ).To(Equal("FRIENDS")) + Expect(edge.Properties["ts"].Int()).To(Equal(100)) + Expect(edge.Properties["msg"].String()).To(Equal("txt-msg")) + }) + + It("graph query no-row", func() { + query := "MATCH (p:per {id: 999}) RETURN p.name" + res, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res.IsResult()).To(BeTrue()) + Expect(res.Len()).To(Equal(0)) + + row, err := res.Row() + Expect(err).To(Equal(redis.Nil)) + Expect(row).To(HaveLen(0)) + + rows, err := res.Rows() + Expect(err).To(Equal(redis.Nil)) + Expect(rows).To(HaveLen(0)) + }) + + It("graph query rows", func() { + query := "CREATE (:per {id: 1024}),(:per {id: 2048}),(:per {id: 4096})" + _, err := client.GraphQuery(ctx, graph, query).Result() + Expect(err).NotTo(HaveOccurred()) + + query = "MATCH (p:per) return p.id as id" + res, err := client.GraphQuery(ctx, graph, query).Result() + + row, err := res.Row() + Expect(err).NotTo(HaveOccurred()) + Expect(row).To(HaveLen(1)) + Expect(row["id"].Int()).To(Equal(1024)) + + rows, err := res.Rows() + Expect(err).NotTo(HaveOccurred()) + Expect(rows).To(HaveLen(3)) + Expect(rows[0]["id"].Int()).To(Equal(1024)) + Expect(rows[1]["id"].Int()).To(Equal(2048)) + Expect(rows[2]["id"].Int()).To(Equal(4096)) + }) +})