feat(graph): support RedisGraph query

Signed-off-by: monkey92t <golang@88.com>
This commit is contained in:
monkey92t 2024-04-13 00:44:26 +08:00
parent 6833d2f8e1
commit caa75ac3a7
4 changed files with 621 additions and 0 deletions

View File

@ -3,7 +3,9 @@ package redis
import (
"bufio"
"context"
"errors"
"fmt"
"math/big"
"net"
"regexp"
"strconv"
@ -5484,3 +5486,283 @@ func (cmd *MonitorCmd) Stop() {
defer cmd.mu.Unlock()
cmd.status = monitorStatusStop
}
// --------------------------------------------------------------------------------------
type GraphCmd struct {
baseCmd
val *GraphResult
}
var _ Cmder = (*GraphCmd)(nil)
func NewGraphCmd(ctx context.Context, args ...any) *GraphCmd {
return &GraphCmd{
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
func (cmd *GraphCmd) SetVal(val *GraphResult) {
cmd.val = val
}
func (cmd *GraphCmd) Val() *GraphResult {
return cmd.val
}
func (cmd *GraphCmd) Result() (*GraphResult, error) {
return cmd.val, cmd.err
}
func (cmd *GraphCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *GraphCmd) readReply(rd *proto.Reader) error {
cmd.val = &GraphResult{}
cmd.val.err = cmd.readGraph(rd)
return cmd.val.err
}
func (cmd *GraphCmd) readGraph(rd *proto.Reader) error {
n, err := rd.ReadArrayLen()
if err != nil {
return err
}
if n != 1 && n != 3 {
return fmt.Errorf("redis: invalid number of elements in graph result: %d", n)
}
if n == 1 {
// create?
cmd.val.noResult = true
if cmd.val.text, err = cmd.readStringArray(rd); err != nil {
return err
}
return nil
}
// n = 3, read graph result
if cmd.val.field, err = cmd.readStringArray(rd); err != nil {
return err
}
fieldLen := len(cmd.val.field)
if fieldLen == 0 {
return nil
}
// read response
rows, err := rd.ReadArrayLen()
if err != nil {
return err
}
cmd.val.rows = make([][]*graphRow, rows)
for i := 0; i < rows; i++ {
// field == row
if err = rd.ReadFixedArrayLen(fieldLen); err != nil {
return err
}
cmd.val.rows[i] = make([]*graphRow, fieldLen)
for f := 0; f < fieldLen; f++ {
next, err := rd.PeekReplyType()
if err != nil {
return err
}
var nn int
switch next {
case proto.RespArray, proto.RespSet, proto.RespPush:
nn, err = rd.ReadArrayLen()
if err != nil {
return err
}
default:
nn = 1
}
// 1: int/string/nil
// 3: node, id + labels + properties
// 5: edge, id + type + src_node + dest_node + properties
switch nn {
case 1:
data, err := cmd.readData(rd)
if err != nil {
return err
}
cmd.val.rows[i][f] = &graphRow{
typ: graphResultBasic,
basic: data,
}
case 3:
node, err := cmd.readNode(rd)
if err != nil {
return err
}
cmd.val.rows[i][f] = &graphRow{
typ: graphResultNode,
node: node,
}
case 5:
edge, err := cmd.readEdge(rd)
if err != nil {
return err
}
cmd.val.rows[i][f] = &graphRow{
typ: graphResultEdge,
edge: edge,
}
default:
return fmt.Errorf("redis: graph-row-field-len, got %d elements in the array, wanted %v", nn, "1/3/5")
}
}
}
if cmd.val.text, err = cmd.readStringArray(rd); err != nil {
return err
}
return err
}
// node = 3, id +labels +properties
func (cmd *GraphCmd) readNode(rd *proto.Reader) (GraphNode, error) {
node := GraphNode{}
for j := 0; j < 3; j++ {
if err := rd.ReadFixedArrayLen(2); err != nil {
return node, err
}
key, err := rd.ReadString()
if err != nil {
return node, err
}
switch key {
case "id":
if node.ID, err = rd.ReadInt(); err != nil {
return node, err
}
case "labels":
if node.Labels, err = cmd.readStringArray(rd); err != nil {
return node, err
}
case "properties":
if node.Properties, err = cmd.readProperties(rd); err != nil {
return node, err
}
default:
return node, fmt.Errorf("redis: invalid graph node key - %s", key)
}
}
return node, nil
}
// edge = 5, id + type + src_node + dest_node + properties
func (cmd *GraphCmd) readEdge(rd *proto.Reader) (GraphEdge, error) {
edge := GraphEdge{}
for j := 0; j < 5; j++ {
if err := rd.ReadFixedArrayLen(2); err != nil {
return edge, err
}
key, err := rd.ReadString()
if err != nil {
return edge, err
}
switch key {
case "id":
if edge.ID, err = rd.ReadInt(); err != nil {
return edge, err
}
case "type":
if edge.Typ, err = rd.ReadString(); err != nil {
return edge, err
}
case "src_node":
if edge.SrcNode, err = rd.ReadInt(); err != nil {
return edge, err
}
case "dest_node":
if edge.DstNode, err = rd.ReadInt(); err != nil {
return edge, err
}
case "properties":
if edge.Properties, err = cmd.readProperties(rd); err != nil {
return edge, err
}
default:
return edge, fmt.Errorf("redis: invalid graph edge key - %s", key)
}
}
return edge, nil
}
func (cmd *GraphCmd) readProperties(rd *proto.Reader) (map[string]GraphData, error) {
n, err := rd.ReadArrayLen()
if err != nil {
return nil, err
}
m := make(map[string]GraphData, n)
for i := 0; i < n; i++ {
if err = rd.ReadFixedArrayLen(2); err != nil {
return nil, err
}
key, err := rd.ReadString()
if err != nil {
return nil, err
}
val, err := cmd.readData(rd)
if err != nil {
return nil, err
}
m[key] = val
}
return m, nil
}
func (cmd *GraphCmd) readData(rd *proto.Reader) (GraphData, error) {
var data GraphData
reply, err := rd.ReadReply()
if err != nil {
if errors.Is(err, Nil) {
data.typ = graphNil
return data, nil
}
return data, err
}
switch v := reply.(type) {
case string:
data.typ = graphString
data.stringVal = v
case int64:
data.typ = graphInteger
data.integerVal = v
case *big.Int:
data.typ = graphInteger
if !v.IsInt64() {
return data, fmt.Errorf("redis: bigInt(%s) value out of range", v.String())
}
data.integerVal = v.Int64()
default:
return data, fmt.Errorf("redis: invalid reply - %q", reply)
}
return data, nil
}
func (cmd *GraphCmd) readStringArray(rd *proto.Reader) ([]string, error) {
n, err := rd.ReadArrayLen()
if err != nil {
return nil, err
}
if n == 0 {
return nil, nil
}
ss := make([]string, n)
for i := 0; i < n; i++ {
if ss[i], err = rd.ReadString(); err != nil {
return ss, err
}
}
return ss, nil
}

View File

@ -226,6 +226,7 @@ type Cmdable interface {
StreamCmdable
TimeseriesCmdable
JSONCmdable
GraphCmdable
}
type StatefulCmdable interface {

204
graph.go Normal file
View File

@ -0,0 +1,204 @@
package redis
import (
"context"
"strconv"
)
type GraphCmdable interface {
GraphQuery(ctx context.Context, key, query string) *GraphCmd
}
// GraphQuery executes a query on a graph, exec: GRAPH.QUERY key query
func (c cmdable) GraphQuery(ctx context.Context, key, query string) *GraphCmd {
cmd := NewGraphCmd(ctx, "GRAPH.QUERY", key, query)
_ = c(ctx, cmd)
return cmd
}
// ----------------------------------------------------------------------------
type GraphValue interface {
IsNil() bool
String() string
Int() int
Bool() bool
Float64() float64
Node() (*GraphNode, bool)
Edge() (*GraphEdge, bool)
}
type (
graphDataType int
graphRowType int
)
const (
graphInteger graphDataType = iota + 1 // int (graph int)
graphNil // nil (graph nil/null)
graphString // string (graph string/boolean/double)
)
const (
graphResultBasic graphRowType = iota + 1 // int/nil/string
graphResultNode // node(3), id +labels +properties
graphResultEdge // edge(5), id + type + src_node + dest_node + properties
)
type GraphResult struct {
noResult bool
text []string
field []string
rows [][]*graphRow
err error
}
// Message is used to obtain statistical information about the execution of graph statements.
func (g *GraphResult) Message() []string {
return g.text
}
// IsResult indicates whether the statement has a result set response.
// For example, a MATCH statement always has a result set (regardless of the number of rows in the response),
// while a CREATE statement does not.
func (g *GraphResult) IsResult() bool {
return !g.noResult
}
// Error return exec error
func (g *GraphResult) Error() error {
return g.err
}
// Field returns the fields in the query result set.
// If the statement has no result set (IsResult() == false), it returns nil.
func (g *GraphResult) Field() []string {
return g.field
}
// Len return the number of rows in the result set.
func (g *GraphResult) Len() int {
return len(g.rows)
}
// Row read the first row of data in the result set. If there is no data response, return redis.Nil.
func (g *GraphResult) Row() (map[string]GraphValue, error) {
if g.err != nil {
return nil, g.err
}
if g.noResult || len(g.field) == 0 || len(g.rows) == 0 || len(g.rows[0]) == 0 {
return nil, Nil
}
row := make(map[string]GraphValue, len(g.field))
for i, f := range g.field {
row[f] = g.rows[0][i]
}
return row, nil
}
// Rows return all result set responses, If there is no data response, return redis.Nil.
func (g *GraphResult) Rows() ([]map[string]GraphValue, error) {
if g.err != nil {
return nil, g.err
}
if g.noResult || len(g.field) == 0 || len(g.rows) == 0 || len(g.rows[0]) == 0 {
return nil, Nil
}
rows := make([]map[string]GraphValue, 0, len(g.rows))
for i := 0; i < len(g.rows); i++ {
row := make(map[string]GraphValue, len(g.field))
for x, f := range g.field {
row[f] = g.rows[i][x]
}
rows = append(rows, row)
}
return rows, nil
}
type graphRow struct {
typ graphRowType
basic GraphData
node GraphNode
edge GraphEdge
}
func (g *graphRow) IsNil() bool { return g.basic.IsNil() }
func (g *graphRow) String() string { return g.basic.String() }
func (g *graphRow) Int() int { return g.basic.Int() }
func (g *graphRow) Bool() bool { return g.basic.Bool() }
func (g *graphRow) Float64() float64 { return g.basic.Float64() }
func (g *graphRow) Node() (*GraphNode, bool) { return &g.node, g.typ == graphResultNode }
func (g *graphRow) Edge() (*GraphEdge, bool) { return &g.edge, g.typ == graphResultEdge }
type GraphData struct {
typ graphDataType
integerVal int64
stringVal string
}
func (d GraphData) IsNil() bool {
return d.typ == graphNil
}
func (d GraphData) String() string {
switch d.typ {
case graphInteger:
return strconv.FormatInt(d.integerVal, 10)
case graphNil:
return ""
case graphString:
return d.stringVal
default:
return ""
}
}
func (d GraphData) Int() int {
switch d.typ {
case graphInteger:
return int(d.integerVal)
case graphString:
n, _ := strconv.Atoi(d.stringVal)
return n
default:
return 0
}
}
func (d GraphData) Bool() bool {
switch d.typ {
case graphInteger:
return d.integerVal != 0
case graphNil:
return false
case graphString:
return d.stringVal == "true"
default:
return false
}
}
func (d GraphData) Float64() float64 {
if d.typ == graphInteger {
return float64(d.integerVal)
}
if d.typ == graphString {
v, _ := strconv.ParseFloat(d.stringVal, 64)
return v
}
return 0
}
type GraphNode struct {
ID int64
Labels []string
Properties map[string]GraphData
}
type GraphEdge struct {
ID int64
Typ string
SrcNode int64
DstNode int64
Properties map[string]GraphData
}

134
graph_test.go Normal file
View File

@ -0,0 +1,134 @@
package redis_test
import (
. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
)
var _ = Describe("Client", func() {
var client *redis.Client
var graph = "test-graph"
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
})
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("graph no-result", func() {
res, err := client.GraphQuery(ctx, graph, "CREATE ()").Result()
Expect(err).NotTo(HaveOccurred())
Expect(res.IsResult()).To(BeFalse())
})
It("graph query result-basic", func() {
query := "CREATE (:per {id: 1024, name: 'foo', pr: 3.14, success: true})"
_, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
query = "MATCH (p:per {id: 1024}) RETURN p.id as id, p.name as name, p.pr as pr, p.success as success, p.non as non"
res, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res.IsResult()).To(BeTrue())
Expect(res.Len()).To(Equal(1))
row, err := res.Row()
Expect(err).NotTo(HaveOccurred())
Expect(row).To(HaveLen(5))
Expect(row["id"].Int()).To(Equal(1024))
Expect(row["name"].Int()).To(Equal("foo"))
Expect(row["pr"].Float64()).To(Equal(3.14))
Expect(row["success"].Bool()).To(BeTrue())
Expect(row["non"].IsNil()).To(BeTrue())
})
It("graph query result-node", func() {
query := "CREATE (:per {id: 1024, name: 'foo', pr: 3.14, success: true})"
_, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
query = "MATCH (p:per {id: 1024}) RETURN p"
res, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res.IsResult()).To(BeTrue())
Expect(res.Len()).To(Equal(1))
row, err := res.Row()
Expect(err).NotTo(HaveOccurred())
Expect(row).To(HaveLen(1))
node, ok := row["p"].Node()
Expect(ok).To(BeTrue())
Expect(node.Labels).To(Equal([]string{"per"}))
Expect(node.Properties["id"].Int()).To(Equal(1024))
Expect(node.Properties["name"].String()).To(Equal("foo"))
Expect(node.Properties["pr"].Float64()).To(Equal(3.14))
Expect(node.Properties["success"].Bool()).To(BeTrue())
})
It("graph query result-edge", func() {
query := "CREATE (:per {id: 1024}) - [:FRIENDS {ts: 100, msg: 'txt-msg'}] -> (:per {id: 2048})"
_, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
query = "MATCH (:per {id: 1024}) - [r:FRIENDS] -> (:per {id: 2048}) RETURN r"
res, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res.IsResult()).To(BeTrue())
Expect(res.Len()).To(Equal(1))
row, err := res.Row()
Expect(err).NotTo(HaveOccurred())
Expect(row).To(HaveLen(1))
edge, ok := row["r"].Edge()
Expect(ok).To(BeTrue())
Expect(edge.Typ).To(Equal("FRIENDS"))
Expect(edge.Properties["ts"].Int()).To(Equal(100))
Expect(edge.Properties["msg"].String()).To(Equal("txt-msg"))
})
It("graph query no-row", func() {
query := "MATCH (p:per {id: 999}) RETURN p.name"
res, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res.IsResult()).To(BeTrue())
Expect(res.Len()).To(Equal(0))
row, err := res.Row()
Expect(err).To(Equal(redis.Nil))
Expect(row).To(HaveLen(0))
rows, err := res.Rows()
Expect(err).To(Equal(redis.Nil))
Expect(rows).To(HaveLen(0))
})
It("graph query rows", func() {
query := "CREATE (:per {id: 1024}),(:per {id: 2048}),(:per {id: 4096})"
_, err := client.GraphQuery(ctx, graph, query).Result()
Expect(err).NotTo(HaveOccurred())
query = "MATCH (p:per) return p.id as id"
res, err := client.GraphQuery(ctx, graph, query).Result()
row, err := res.Row()
Expect(err).NotTo(HaveOccurred())
Expect(row).To(HaveLen(1))
Expect(row["id"].Int()).To(Equal(1024))
rows, err := res.Rows()
Expect(err).NotTo(HaveOccurred())
Expect(rows).To(HaveLen(3))
Expect(rows[0]["id"].Int()).To(Equal(1024))
Expect(rows[1]["id"].Int()).To(Equal(2048))
Expect(rows[2]["id"].Int()).To(Equal(4096))
})
})