Add redis.Scan() to scan results from redis maps into structs.

The package uses reflection to decode default types (int, string
etc.) from Redis map results (key-value pair sequences) into
struct fields where the fields are matched to Redis keys by tags.

Similar to how `encoding/json` allows custom decoders using
`UnmarshalJSON()`, the package supports decoding of arbitrary
types into struct fields by defining a `Decode(string) error`
function on types.

The field/type spec of every struct that's passed to Scan() is
cached in the package so that subsequent scans avoid iteration
and reflection of the struct's fields.
This commit is contained in:
Kailash Nadh 2021-01-28 10:58:24 +05:30
parent bf010a705d
commit 6f96bebac7
2 changed files with 243 additions and 0 deletions

156
internal/hscan/hscan.go Normal file
View File

@ -0,0 +1,156 @@
package hscan
import (
"errors"
"fmt"
"reflect"
"strconv"
)
// decoderFunc represents decoding functions for default built-in types.
type decoderFunc func(reflect.Value, string) error
var (
// List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1)
decoders = []decoderFunc{
reflect.Bool: decodeBool,
reflect.Int: decodeInt,
reflect.Int8: decodeInt,
reflect.Int16: decodeInt,
reflect.Int32: decodeInt,
reflect.Int64: decodeInt,
reflect.Uint: decodeUint,
reflect.Uint8: decodeUint,
reflect.Uint16: decodeUint,
reflect.Uint32: decodeUint,
reflect.Uint64: decodeUint,
reflect.Float32: decodeFloat,
reflect.Float64: decodeFloat,
reflect.Complex64: decodeUnsupported,
reflect.Complex128: decodeUnsupported,
reflect.Array: decodeUnsupported,
reflect.Chan: decodeUnsupported,
reflect.Func: decodeUnsupported,
reflect.Interface: decodeUnsupported,
reflect.Map: decodeUnsupported,
reflect.Ptr: decodeUnsupported,
reflect.Slice: decodeStringSlice,
reflect.String: decodeString,
reflect.Struct: decodeUnsupported,
reflect.UnsafePointer: decodeUnsupported,
}
// Global map of struct field specs that is populated once for every new
// struct type that is scanned. This caches the field types and the corresponding
// decoder functions to avoid iterating through struct fields on subsequent scans.
structSpecs = newStructMap()
)
// Scan scans the results from a key-value Redis map result set ([]interface{})
// to a destination struct. The Redis keys are matched to the struct's field
// with the `redis` tag.
func Scan(vals []interface{}, dest interface{}) error {
if len(vals)%2 != 0 {
return errors.New("args should have an even number of items (key-val)")
}
// The destination to scan into should be a struct pointer.
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr || v.IsNil() {
return fmt.Errorf("redis.Scan(non-pointer %T)", dest)
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return fmt.Errorf("redis.Scan(non-struct %T)", dest)
}
// If the struct field spec is not cached, build and cache it to avoid
// iterating through the fields of a struct type every time values are
// scanned into it.
typ := v.Type()
fMap, ok := structSpecs.get(typ)
if !ok {
fMap = makeStructSpecs(v, "redis")
structSpecs.set(typ, fMap)
}
// Iterate through the (key, value) sequence.
for i := 0; i < len(vals); i += 2 {
key, ok := vals[i].(string)
if !ok {
continue
}
val, ok := vals[i+1].(string)
if !ok {
continue
}
// Check if the field name is in the field spec map.
field, ok := fMap.get(key)
if !ok {
continue
}
if err := field.fn(v.Field(field.index), val); err != nil {
return err
}
}
return nil
}
func decodeBool(f reflect.Value, s string) error {
b, err := strconv.ParseBool(s)
if err != nil {
return err
}
f.SetBool(b)
return nil
}
func decodeInt(f reflect.Value, s string) error {
v, err := strconv.ParseInt(s, 10, 0)
if err != nil {
return err
}
f.SetInt(v)
return nil
}
func decodeUint(f reflect.Value, s string) error {
v, err := strconv.ParseUint(s, 10, 0)
if err != nil {
return err
}
f.SetUint(v)
return nil
}
func decodeFloat(f reflect.Value, s string) error {
v, err := strconv.ParseFloat(s, 0)
if err != nil {
return err
}
f.SetFloat(v)
return nil
}
func decodeString(f reflect.Value, s string) error {
f.SetString(s)
return nil
}
func decodeStringSlice(f reflect.Value, s string) error {
// []byte slice ([]uint8).
if f.Type().Elem().Kind() == reflect.Uint8 {
f.SetBytes([]byte(s))
}
return nil
}
func decodeUnsupported(f reflect.Value, s string) error {
return fmt.Errorf("redis.Scan(unsupported type %v)", f)
}

View File

@ -0,0 +1,87 @@
package hscan
import (
"reflect"
"strings"
"sync"
)
// structField represents a single field in a target struct.
type structField struct {
index int
fn decoderFunc
}
// structFields contains the list of all fields in a target struct.
type structFields struct {
m map[string]*structField
}
// structMap contains the map of struct fields for target structs
// indexed by the struct type.
type structMap struct {
m sync.Map
}
func newStructMap() *structMap {
return &structMap{
m: sync.Map{},
}
}
func (s *structMap) get(t reflect.Type) (*structFields, bool) {
m, ok := s.m.Load(t)
if !ok {
return nil, ok
}
return m.(*structFields), true
}
func (s *structMap) set(t reflect.Type, sf *structFields) {
s.m.Store(t, sf)
}
func newStructFields() *structFields {
return &structFields{
m: make(map[string]*structField),
}
}
func (s *structFields) set(tag string, sf *structField) {
s.m[tag] = sf
}
func (s *structFields) get(tag string) (*structField, bool) {
f, ok := s.m[tag]
return f, ok
}
func makeStructSpecs(ob reflect.Value, fieldTag string) *structFields {
var (
num = ob.NumField()
out = newStructFields()
)
for i := 0; i < num; i++ {
f := ob.Field(i)
if !f.IsValid() || !f.CanSet() {
continue
}
tag := ob.Type().Field(i).Tag.Get(fieldTag)
if tag == "" || tag == "-" {
continue
}
tag = strings.Split(tag, ",")[0]
if tag == "" {
continue
}
// Use the built-in decoder.
out.set(tag, &structField{index: i, fn: decoders[f.Kind()]})
}
return out
}