mirror of https://github.com/tidwall/tile38.git
394 lines
12 KiB
Go
394 lines
12 KiB
Go
|
// Package ndr provides the ability to unmarshal NDR encoded byte streams into Go data structures.
|
||
|
package ndr
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
// Struct tag values that control how the Decoder interprets a field.
const (
	// TagConformant marks a string or slice as conformant: a placeholder for its maximum
	// element count is reserved by conformantScan and read from the start of the structure.
	TagConformant = "conformant"

	// TagVarying marks a slice as varying.
	TagVarying = "varying"

	// TagPointer marks a field as a pointer whose referent is decoded after the enclosing
	// structure has been filled.
	TagPointer = "pointer"

	// TagPipe marks a slice that is decoded as a pipe.
	TagPipe = "pipe"
)
|
||
|
|
||
|
// Decoder unmarshals NDR byte stream data into a Go struct representation.
// Create one with NewDecoder and call Decode with a pointer to the target struct.
type Decoder struct {
	r *bufio.Reader // source of the data
	// initial size of bytes in buffer: the value of r.Buffered() recorded by NewDecoder
	// right after its initial Peek.
	size int
	ch   CommonHeader  // NDR common header
	ph   PrivateHeader // NDR private header
	// conformant max values that were moved to the beginning of the structure; zero-valued
	// slots are appended by conformantScan and then read from the stream by
	// scanConformantArrays.
	conformantMax []uint32
	s             interface{} // pointer to the structure being populated (set by Decode)
	// keeps track of the current field being populated: a stack of type and field names that
	// fill pushes/pops and joins with "/" when building error messages.
	current []string
}
|
||
|
|
||
|
// deferedPtr is a pointer referent whose decoding has been deferred. isPointer queues one
// for every non-zero pointer value it reads, and process decodes the queued referents, in
// order, after the enclosing value has been filled.
type deferedPtr struct {
	v   reflect.Value     // the value to populate with the referent
	tag reflect.StructTag // the field's struct tag with the pointer tag value removed
}
|
||
|
|
||
|
// NewDecoder creates a new instance of a NDR Decoder.
|
||
|
func NewDecoder(r io.Reader) *Decoder {
|
||
|
dec := new(Decoder)
|
||
|
dec.r = bufio.NewReader(r)
|
||
|
dec.r.Peek(int(commonHeaderBytes)) // For some reason an operation is needed on the buffer to initialise it so Buffered() != 0
|
||
|
dec.size = dec.r.Buffered()
|
||
|
return dec
|
||
|
}
|
||
|
|
||
|
// Decode unmarshals the NDR encoded bytes into the pointer of a struct provided.
|
||
|
func (dec *Decoder) Decode(s interface{}) error {
|
||
|
dec.s = s
|
||
|
err := dec.readCommonHeader()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
err = dec.readPrivateHeader()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err = dec.r.Discard(4) //The next 4 bytes are an RPC unique pointer referent. We just skip these.
|
||
|
if err != nil {
|
||
|
return Errorf("unable to process byte stream: %v", err)
|
||
|
}
|
||
|
|
||
|
return dec.process(s, reflect.StructTag(""))
|
||
|
}
|
||
|
|
||
|
func (dec *Decoder) process(s interface{}, tag reflect.StructTag) error {
|
||
|
// Scan for conformant fields as their max counts are moved to the beginning
|
||
|
// http://pubs.opengroup.org/onlinepubs/9629399/chap14.htm#tagfcjh_37
|
||
|
err := dec.scanConformantArrays(s, tag)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
// Recursively fill the struct fields
|
||
|
var localDef []deferedPtr
|
||
|
err = dec.fill(s, tag, &localDef)
|
||
|
if err != nil {
|
||
|
return Errorf("could not decode: %v", err)
|
||
|
}
|
||
|
// Read any deferred referents associated with pointers
|
||
|
for _, p := range localDef {
|
||
|
err = dec.process(p.v, p.tag)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not decode deferred referent: %v", err)
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// scanConformantArrays scans the structure for embedded conformant fields and captures the maximum element counts for
|
||
|
// dimensions of the array that are moved to the beginning of the structure.
|
||
|
func (dec *Decoder) scanConformantArrays(s interface{}, tag reflect.StructTag) error {
|
||
|
err := dec.conformantScan(s, tag)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to scan for embedded conformant arrays: %v", err)
|
||
|
}
|
||
|
for i := range dec.conformantMax {
|
||
|
dec.conformantMax[i], err = dec.readUint32()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not read preceding conformant max count index %d: %v", i, err)
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// conformantScan inspects the structure's fields for whether they are conformant.
|
||
|
func (dec *Decoder) conformantScan(s interface{}, tag reflect.StructTag) error {
|
||
|
ndrTag := parseTags(tag)
|
||
|
if ndrTag.HasValue(TagPointer) {
|
||
|
return nil
|
||
|
}
|
||
|
v := getReflectValue(s)
|
||
|
switch v.Kind() {
|
||
|
case reflect.Struct:
|
||
|
for i := 0; i < v.NumField(); i++ {
|
||
|
err := dec.conformantScan(v.Field(i), v.Type().Field(i).Tag)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
case reflect.String:
|
||
|
if !ndrTag.HasValue(TagConformant) {
|
||
|
break
|
||
|
}
|
||
|
dec.conformantMax = append(dec.conformantMax, uint32(0))
|
||
|
case reflect.Slice:
|
||
|
if !ndrTag.HasValue(TagConformant) {
|
||
|
break
|
||
|
}
|
||
|
d, t := sliceDimensions(v.Type())
|
||
|
for i := 0; i < d; i++ {
|
||
|
dec.conformantMax = append(dec.conformantMax, uint32(0))
|
||
|
}
|
||
|
// For string arrays there is a common max for the strings within the array.
|
||
|
if t.Kind() == reflect.String {
|
||
|
dec.conformantMax = append(dec.conformantMax, uint32(0))
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (dec *Decoder) isPointer(v reflect.Value, tag reflect.StructTag, def *[]deferedPtr) (bool, error) {
|
||
|
// Pointer so defer filling the referent
|
||
|
ndrTag := parseTags(tag)
|
||
|
if ndrTag.HasValue(TagPointer) {
|
||
|
p, err := dec.readUint32()
|
||
|
if err != nil {
|
||
|
return true, fmt.Errorf("could not read pointer: %v", err)
|
||
|
}
|
||
|
ndrTag.delete(TagPointer)
|
||
|
if p != 0 {
|
||
|
// if pointer is not zero add to the deferred items at end of stream
|
||
|
*def = append(*def, deferedPtr{v, ndrTag.StructTag()})
|
||
|
}
|
||
|
return true, nil
|
||
|
}
|
||
|
return false, nil
|
||
|
}
|
||
|
|
||
|
// getReflectValue normalises s to the reflect.Value the decoder should populate:
//   - a reflect.Value is returned unchanged;
//   - a pointer yields the value it points to (the zero Value for a nil pointer);
//   - anything else yields the zero (invalid) Value.
func getReflectValue(s interface{}) (v reflect.Value) {
	if rv, isValue := s.(reflect.Value); isValue {
		return rv
	}
	rv := reflect.ValueOf(s)
	if rv.Kind() != reflect.Ptr {
		return reflect.Value{}
	}
	return rv.Elem()
}
|
||
|
|
||
|
// fill populates fields with values from the NDR byte stream.
|
||
|
func (dec *Decoder) fill(s interface{}, tag reflect.StructTag, localDef *[]deferedPtr) error {
|
||
|
v := getReflectValue(s)
|
||
|
|
||
|
//// Pointer so defer filling the referent
|
||
|
ptr, err := dec.isPointer(v, tag, localDef)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err)
|
||
|
}
|
||
|
if ptr {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Populate the value from the byte stream
|
||
|
switch v.Kind() {
|
||
|
case reflect.Struct:
|
||
|
dec.current = append(dec.current, v.Type().Name()) //Track the current field being filled
|
||
|
// in case struct is a union, track this and the selected union field for efficiency
|
||
|
var unionTag reflect.Value
|
||
|
var unionField string // field to fill if struct is a union
|
||
|
// Go through each field in the struct and recursively fill
|
||
|
for i := 0; i < v.NumField(); i++ {
|
||
|
fieldName := v.Type().Field(i).Name
|
||
|
dec.current = append(dec.current, fieldName) //Track the current field being filled
|
||
|
//fmt.Fprintf(os.Stderr, "DEBUG Decoding: %s\n", strings.Join(dec.current, "/"))
|
||
|
structTag := v.Type().Field(i).Tag
|
||
|
ndrTag := parseTags(structTag)
|
||
|
|
||
|
// Union handling
|
||
|
if !unionTag.IsValid() {
|
||
|
// Is this field a union tag?
|
||
|
unionTag = dec.isUnion(v.Field(i), structTag)
|
||
|
} else {
|
||
|
// What is the selected field value of the union if we don't already know
|
||
|
if unionField == "" {
|
||
|
unionField, err = unionSelectedField(v, unionTag)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not determine selected union value field for %s with discriminat"+
|
||
|
" tag %s: %v", v.Type().Name(), unionTag, err)
|
||
|
}
|
||
|
}
|
||
|
if ndrTag.HasValue(TagUnionField) && fieldName != unionField {
|
||
|
// is a union and this field has not been selected so will skip it.
|
||
|
dec.current = dec.current[:len(dec.current)-1] //This field has been skipped so remove it from the current field tracker
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Check if field is a pointer
|
||
|
if v.Field(i).Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) &&
|
||
|
v.Field(i).Type().Kind() == reflect.Slice && v.Field(i).Type().Elem().Kind() == reflect.Uint8 {
|
||
|
//field is for rawbytes
|
||
|
structTag, err = addSizeToTag(v, v.Field(i), structTag)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not get rawbytes field(%s) size: %v", strings.Join(dec.current, "/"), err)
|
||
|
}
|
||
|
ptr, err := dec.isPointer(v.Field(i), structTag, localDef)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not process struct field(%s): %v", strings.Join(dec.current, "/"), err)
|
||
|
}
|
||
|
if !ptr {
|
||
|
err := dec.readRawBytes(v.Field(i), structTag)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err)
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
err := dec.fill(v.Field(i), structTag, localDef)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill struct field(%s): %v", strings.Join(dec.current, "/"), err)
|
||
|
}
|
||
|
}
|
||
|
dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker
|
||
|
}
|
||
|
dec.current = dec.current[:len(dec.current)-1] //This field has been filled so remove it from the current field tracker
|
||
|
case reflect.Bool:
|
||
|
i, err := dec.readBool()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Uint8:
|
||
|
i, err := dec.readUint8()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Uint16:
|
||
|
i, err := dec.readUint16()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Uint32:
|
||
|
i, err := dec.readUint32()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Uint64:
|
||
|
i, err := dec.readUint64()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Int8:
|
||
|
i, err := dec.readInt8()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Int16:
|
||
|
i, err := dec.readInt16()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Int32:
|
||
|
i, err := dec.readInt32()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Int64:
|
||
|
i, err := dec.readInt64()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %s: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.String:
|
||
|
ndrTag := parseTags(tag)
|
||
|
conformant := ndrTag.HasValue(TagConformant)
|
||
|
// strings are always varying so this is assumed without an explicit tag
|
||
|
var s string
|
||
|
var err error
|
||
|
if conformant {
|
||
|
s, err = dec.readConformantVaryingString(localDef)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill with conformant varying string: %v", err)
|
||
|
}
|
||
|
} else {
|
||
|
s, err = dec.readVaryingString(localDef)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill with varying string: %v", err)
|
||
|
}
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(s))
|
||
|
case reflect.Float32:
|
||
|
i, err := dec.readFloat32()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Float64:
|
||
|
i, err := dec.readFloat64()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill %v: %v", v.Type().Name(), err)
|
||
|
}
|
||
|
v.Set(reflect.ValueOf(i))
|
||
|
case reflect.Array:
|
||
|
err := dec.fillFixedArray(v, tag, localDef)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
case reflect.Slice:
|
||
|
if v.Type().Implements(reflect.TypeOf(new(RawBytes)).Elem()) && v.Type().Elem().Kind() == reflect.Uint8 {
|
||
|
//field is for rawbytes
|
||
|
err := dec.readRawBytes(v, tag)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not fill raw bytes struct field(%s): %v", strings.Join(dec.current, "/"), err)
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
ndrTag := parseTags(tag)
|
||
|
conformant := ndrTag.HasValue(TagConformant)
|
||
|
varying := ndrTag.HasValue(TagVarying)
|
||
|
if ndrTag.HasValue(TagPipe) {
|
||
|
err := dec.fillPipe(v, tag)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
_, t := sliceDimensions(v.Type())
|
||
|
if t.Kind() == reflect.String && !ndrTag.HasValue(subStringArrayValue) {
|
||
|
// String array
|
||
|
err := dec.readStringsArray(v, tag, localDef)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
// varying is assumed as fixed arrays use the Go array type rather than slice
|
||
|
if conformant && varying {
|
||
|
err := dec.fillConformantVaryingArray(v, tag, localDef)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else if !conformant && varying {
|
||
|
err := dec.fillVaryingArray(v, tag, localDef)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
//default to conformant and not varying
|
||
|
err := dec.fillConformantArray(v, tag, localDef)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
default:
|
||
|
return fmt.Errorf("unsupported type")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// readBytes returns a number of bytes from the NDR byte stream.
|
||
|
func (dec *Decoder) readBytes(n int) ([]byte, error) {
|
||
|
//TODO make this take an int64 as input to allow for larger values on all systems?
|
||
|
b := make([]byte, n, n)
|
||
|
m, err := dec.r.Read(b)
|
||
|
if err != nil || m != n {
|
||
|
return b, fmt.Errorf("error reading bytes from stream: %v", err)
|
||
|
}
|
||
|
return b, nil
|
||
|
}
|