diff --git a/internal/proto/scan.go b/internal/proto/scan.go index 0e994765..072b8e62 100644 --- a/internal/proto/scan.go +++ b/internal/proto/scan.go @@ -3,6 +3,7 @@ package proto import ( "encoding" "fmt" + "net" "reflect" "time" @@ -115,6 +116,9 @@ func Scan(b []byte, v interface{}) error { return nil case encoding.BinaryUnmarshaler: return v.UnmarshalBinary(b) + case *net.IP: + *v = b + return nil default: return fmt.Errorf( "redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", v) diff --git a/internal/proto/writer.go b/internal/proto/writer.go index c4260981..c5d48cda 100644 --- a/internal/proto/writer.go +++ b/internal/proto/writer.go @@ -4,6 +4,7 @@ import ( "encoding" "fmt" "io" + "net" "strconv" "time" @@ -106,6 +107,8 @@ func (w *Writer) WriteArg(v interface{}) error { return err } return w.bytes(b) + case net.IP: + return w.bytes(v) default: return fmt.Errorf( "redis: can't marshal %T (implement encoding.BinaryMarshaler)", v) diff --git a/internal/proto/writer_test.go b/internal/proto/writer_test.go index ebae5692..91efa017 100644 --- a/internal/proto/writer_test.go +++ b/internal/proto/writer_test.go @@ -3,6 +3,8 @@ package proto_test import ( "bytes" "encoding" + "fmt" + "net" "testing" "time" @@ -64,6 +66,13 @@ var _ = Describe("WriteBuffer", func() { Expect(buf.Len()).To(Equal(15)) }) + + It("should append net.IP", func() { + ip := net.ParseIP("192.168.1.1") + err := wr.WriteArgs([]interface{}{ip}) + Expect(err).NotTo(HaveOccurred()) + Expect(buf.String()).To(Equal(fmt.Sprintf("*1\r\n$16\r\n%s\r\n", bytes.NewBuffer(ip)))) + }) }) type discard struct{} diff --git a/redis_test.go b/redis_test.go index 095da2db..47792ac8 100644 --- a/redis_test.go +++ b/redis_test.go @@ -316,6 +316,18 @@ var _ = Describe("Client", func() { err := client.Conn(ctx).Get(ctx, "this-key-does-not-exist").Err() Expect(err).To(Equal(redis.Nil)) }) + + It("should set and scan net.IP", func() { + ip := net.ParseIP("192.168.1.1") + err := client.Set(ctx, "ip", ip, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + var ip2 net.IP + err = client.Get(ctx, "ip").Scan(&ip2) + Expect(err).NotTo(HaveOccurred()) + + Expect(ip2).To(Equal(ip)) + }) }) var _ = Describe("Client timeout", func() {