From 860625c0b83402d2dc3452fc0b74468b01f84cbe Mon Sep 17 00:00:00 2001 From: siddontang Date: Wed, 24 Sep 2014 09:38:19 +0800 Subject: [PATCH] add package --- bson/LICENSE | 25 + bson/bson.go | 682 +++++++++++++++++ bson/bson_test.go | 1472 ++++++++++++++++++++++++++++++++++++ bson/decode.go | 806 ++++++++++++++++++++ bson/encode.go | 479 ++++++++++++ log/doc.go | 20 + log/filehandler.go | 200 +++++ log/handler.go | 49 ++ log/log.go | 239 ++++++ log/log_test.go | 52 ++ log/sockethandler.go | 65 ++ snappy/LICENSE | 27 + snappy/decode.go | 124 +++ snappy/encode.go | 174 +++++ snappy/snappy.go | 38 + snappy/snappy_test.go | 261 +++++++ websocket/client.go | 60 ++ websocket/client_test.go | 99 +++ websocket/conn.go | 321 ++++++++ websocket/pingpong_test.go | 51 ++ websocket/server.go | 105 +++ websocket/server_test.go | 97 +++ websocket/util.go | 36 + 23 files changed, 5482 insertions(+) create mode 100644 bson/LICENSE create mode 100644 bson/bson.go create mode 100644 bson/bson_test.go create mode 100644 bson/decode.go create mode 100644 bson/encode.go create mode 100644 log/doc.go create mode 100644 log/filehandler.go create mode 100644 log/handler.go create mode 100644 log/log.go create mode 100644 log/log_test.go create mode 100644 log/sockethandler.go create mode 100644 snappy/LICENSE create mode 100644 snappy/decode.go create mode 100644 snappy/encode.go create mode 100644 snappy/snappy.go create mode 100644 snappy/snappy_test.go create mode 100644 websocket/client.go create mode 100644 websocket/client_test.go create mode 100644 websocket/conn.go create mode 100644 websocket/pingpong_test.go create mode 100644 websocket/server.go create mode 100644 websocket/server_test.go create mode 100644 websocket/util.go diff --git a/bson/LICENSE b/bson/LICENSE new file mode 100644 index 0000000..8903260 --- /dev/null +++ b/bson/LICENSE @@ -0,0 +1,25 @@ +BSON library for Go + +Copyright (c) 2010-2012 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/bson/bson.go b/bson/bson.go new file mode 100644 index 0000000..3ebfd84 --- /dev/null +++ b/bson/bson.go @@ -0,0 +1,682 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package bson is an implementation of the BSON specification for Go: +// +// http://bsonspec.org +// +// It was created as part of the mgo MongoDB driver for Go, but is standalone +// and may be used on its own without the driver. +package bson + +import ( + "crypto/md5" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" +) + +// -------------------------------------------------------------------------- +// The public API. + +// A value implementing the bson.Getter interface will have its GetBSON +// method called when the given value has to be marshalled, and the result +// of this method will be marshaled in place of the actual object. +// +// If GetBSON returns return a non-nil error, the marshalling procedure +// will stop and error out with the provided value. +type Getter interface { + GetBSON() (interface{}, error) +} + +// A value implementing the bson.Setter interface will receive the BSON +// value via the SetBSON method during unmarshaling, and the object +// itself will not be changed as usual. +// +// If setting the value works, the method should return nil or alternatively +// bson.SetZero to set the respective field to its zero value (nil for +// pointer types). If SetBSON returns a value of type bson.TypeError, the +// BSON value will be omitted from a map or slice being decoded and the +// unmarshalling will continue. If it returns any other non-nil error, the +// unmarshalling procedure will stop and error out with the provided value. +// +// This interface is generally useful in pointer receivers, since the method +// will want to change the receiver. A type field that implements the Setter +// interface doesn't have to be a pointer, though. +// +// Unlike the usual behavior, unmarshalling onto a value that implements a +// Setter interface will NOT reset the value to its zero state. This allows +// the value to decide by itself how to be unmarshalled. +// +// For example: +// +// type MyString string +// +// func (s *MyString) SetBSON(raw bson.Raw) error { +// return raw.Unmarshal(s) +// } +// +type Setter interface { + SetBSON(raw Raw) error +} + +// SetZero may be returned from a SetBSON method to have the value set to +// its respective zero value. When used in pointer values, this will set the +// field to nil rather than to the pre-allocated value. +var SetZero = errors.New("set to zero") + +// M is a convenient alias for a map[string]interface{} map, useful for +// dealing with BSON in a native way. For instance: +// +// bson.M{"a": 1, "b": true} +// +// There's no special handling for this type in addition to what's done anyway +// for an equivalent map type. Elements in the map will be dumped in an +// undefined ordered. See also the bson.D type for an ordered alternative. +type M map[string]interface{} + +// D represents a BSON document containing ordered elements. For example: +// +// bson.D{{"a", 1}, {"b", true}} +// +// In some situations, such as when creating indexes for MongoDB, the order in +// which the elements are defined is important. If the order is not important, +// using a map is generally more comfortable. See bson.M and bson.RawD. +type D []DocElem + +// See the D type. +type DocElem struct { + Name string + Value interface{} +} + +// Map returns a map out of the ordered element name/value pairs in d. +func (d D) Map() (m M) { + m = make(M, len(d)) + for _, item := range d { + m[item.Name] = item.Value + } + return m +} + +// The Raw type represents raw unprocessed BSON documents and elements. +// Kind is the kind of element as defined per the BSON specification, and +// Data is the raw unprocessed data for the respective element. +// Using this type it is possible to unmarshal or marshal values partially. +// +// Relevant documentation: +// +// http://bsonspec.org/#/specification +// +type Raw struct { + Kind byte + Data []byte +} + +// RawD represents a BSON document containing raw unprocessed elements. +// This low-level representation may be useful when lazily processing +// documents of uncertain content, or when manipulating the raw content +// documents in general. +type RawD []RawDocElem + +// See the RawD type. +type RawDocElem struct { + Name string + Value Raw +} + +// ObjectId is a unique ID identifying a BSON value. It must be exactly 12 bytes +// long. MongoDB objects by default have such a property set in their "_id" +// property. +// +// http://www.mongodb.org/display/DOCS/Object+IDs +type ObjectId string + +// ObjectIdHex returns an ObjectId from the provided hex representation. +// Calling this function with an invalid hex representation will +// cause a runtime panic. See the IsObjectIdHex function. +func ObjectIdHex(s string) ObjectId { + d, err := hex.DecodeString(s) + if err != nil || len(d) != 12 { + panic(fmt.Sprintf("Invalid input to ObjectIdHex: %q", s)) + } + return ObjectId(d) +} + +// IsObjectIdHex returns whether s is a valid hex representation of +// an ObjectId. See the ObjectIdHex function. +func IsObjectIdHex(s string) bool { + if len(s) != 24 { + return false + } + _, err := hex.DecodeString(s) + return err == nil +} + +// objectIdCounter is atomically incremented when generating a new ObjectId +// using NewObjectId() function. It's used as a counter part of an id. +var objectIdCounter uint32 = 0 + +// machineId stores machine id generated once and used in subsequent calls +// to NewObjectId function. +var machineId = readMachineId() + +// readMachineId generates machine id and puts it into the machineId global +// variable. If this function fails to get the hostname, it will cause +// a runtime error. +func readMachineId() []byte { + var sum [3]byte + id := sum[:] + hostname, err1 := os.Hostname() + if err1 != nil { + _, err2 := io.ReadFull(rand.Reader, id) + if err2 != nil { + panic(fmt.Errorf("cannot get hostname: %v; %v", err1, err2)) + } + return id + } + hw := md5.New() + hw.Write([]byte(hostname)) + copy(id, hw.Sum(nil)) + return id +} + +// NewObjectId returns a new unique ObjectId. +func NewObjectId() ObjectId { + var b [12]byte + // Timestamp, 4 bytes, big endian + binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) + // Machine, first 3 bytes of md5(hostname) + b[4] = machineId[0] + b[5] = machineId[1] + b[6] = machineId[2] + // Pid, 2 bytes, specs don't specify endianness, but we use big endian. + pid := os.Getpid() + b[7] = byte(pid >> 8) + b[8] = byte(pid) + // Increment, 3 bytes, big endian + i := atomic.AddUint32(&objectIdCounter, 1) + b[9] = byte(i >> 16) + b[10] = byte(i >> 8) + b[11] = byte(i) + return ObjectId(b[:]) +} + +// NewObjectIdWithTime returns a dummy ObjectId with the timestamp part filled +// with the provided number of seconds from epoch UTC, and all other parts +// filled with zeroes. It's not safe to insert a document with an id generated +// by this method, it is useful only for queries to find documents with ids +// generated before or after the specified timestamp. +func NewObjectIdWithTime(t time.Time) ObjectId { + var b [12]byte + binary.BigEndian.PutUint32(b[:4], uint32(t.Unix())) + return ObjectId(string(b[:])) +} + +// String returns a hex string representation of the id. +// Example: ObjectIdHex("4d88e15b60f486e428412dc9"). +func (id ObjectId) String() string { + return fmt.Sprintf(`ObjectIdHex("%x")`, string(id)) +} + +// Hex returns a hex representation of the ObjectId. +func (id ObjectId) Hex() string { + return hex.EncodeToString([]byte(id)) +} + +// MarshalJSON turns a bson.ObjectId into a json.Marshaller. +func (id ObjectId) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"%x"`, string(id))), nil +} + +// UnmarshalJSON turns *bson.ObjectId into a json.Unmarshaller. +func (id *ObjectId) UnmarshalJSON(data []byte) error { + if len(data) != 26 || data[0] != '"' || data[25] != '"' { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s", string(data))) + } + var buf [12]byte + _, err := hex.Decode(buf[:], data[1:25]) + if err != nil { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s (%s)", string(data), err)) + } + *id = ObjectId(string(buf[:])) + return nil +} + +// Valid returns true if id is valid. A valid id must contain exactly 12 bytes. +func (id ObjectId) Valid() bool { + return len(id) == 12 +} + +// byteSlice returns byte slice of id from start to end. +// Calling this function with an invalid id will cause a runtime panic. +func (id ObjectId) byteSlice(start, end int) []byte { + if len(id) != 12 { + panic(fmt.Sprintf("Invalid ObjectId: %q", string(id))) + } + return []byte(string(id)[start:end]) +} + +// Time returns the timestamp part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Time() time.Time { + // First 4 bytes of ObjectId is 32-bit big-endian seconds from epoch. + secs := int64(binary.BigEndian.Uint32(id.byteSlice(0, 4))) + return time.Unix(secs, 0) +} + +// Machine returns the 3-byte machine id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Machine() []byte { + return id.byteSlice(4, 7) +} + +// Pid returns the process id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Pid() uint16 { + return binary.BigEndian.Uint16(id.byteSlice(7, 9)) +} + +// Counter returns the incrementing value part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Counter() int32 { + b := id.byteSlice(9, 12) + // Counter is stored as big-endian 3-byte value + return int32(uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])) +} + +// The Symbol type is similar to a string and is used in languages with a +// distinct symbol type. +type Symbol string + +// Now returns the current time with millisecond precision. MongoDB stores +// timestamps with the same precision, so a Time returned from this method +// will not change after a roundtrip to the database. That's the only reason +// why this function exists. Using the time.Now function also works fine +// otherwise. +func Now() time.Time { + return time.Unix(0, time.Now().UnixNano()/1e6*1e6) +} + +// MongoTimestamp is a special internal type used by MongoDB that for some +// strange reason has its own datatype defined in BSON. +type MongoTimestamp int64 + +type orderKey int64 + +// MaxKey is a special value that compares higher than all other possible BSON +// values in a MongoDB database. +var MaxKey = orderKey(1<<63 - 1) + +// MinKey is a special value that compares lower than all other possible BSON +// values in a MongoDB database. +var MinKey = orderKey(-1 << 63) + +type undefined struct{} + +// Undefined represents the undefined BSON value. +var Undefined undefined + +// Binary is a representation for non-standard binary values. Any kind should +// work, but the following are known as of this writing: +// +// 0x00 - Generic. This is decoded as []byte(data), not Binary{0x00, data}. +// 0x01 - Function (!?) +// 0x02 - Obsolete generic. +// 0x03 - UUID +// 0x05 - MD5 +// 0x80 - User defined. +// +type Binary struct { + Kind byte + Data []byte +} + +// RegEx represents a regular expression. The Options field may contain +// individual characters defining the way in which the pattern should be +// applied, and must be sorted. Valid options as of this writing are 'i' for +// case insensitive matching, 'm' for multi-line matching, 'x' for verbose +// mode, 'l' to make \w, \W, and similar be locale-dependent, 's' for dot-all +// mode (a '.' matches everything), and 'u' to make \w, \W, and similar match +// unicode. The value of the Options parameter is not verified before being +// marshaled into the BSON format. +type RegEx struct { + Pattern string + Options string +} + +// JavaScript is a type that holds JavaScript code. If Scope is non-nil, it +// will be marshaled as a mapping from identifiers to values that may be +// used when evaluating the provided Code. +type JavaScript struct { + Code string + Scope interface{} +} + +const initialBufferSize = 64 + +func handleErr(err *error) { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } else if _, ok := r.(externalPanic); ok { + panic(r) + } else if s, ok := r.(string); ok { + *err = errors.New(s) + } else if e, ok := r.(error); ok { + *err = e + } else { + panic(r) + } + } +} + +// Marshal serializes the in value, which may be a map or a struct value. +// In the case of struct values, only exported fields will be serialized. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported: +// +// omitempty Only include the field if it's not set to the zero +// value for the type or to empty slices or maps. +// +// minsize Marshal an int64 value as an int32, if that's feasible +// while preserving the numeric value. +// +// inline Inline the field, which must be a struct or a map, +// causing all of its fields or keys to be processed as if +// they were part of the outer struct. For maps, keys must +// not conflict with the bson keys of other struct fields. +// +// Some examples: +// +// type T struct { +// A bool +// B int "myb" +// C string "myc,omitempty" +// D string `bson:",omitempty" json:"jsonkey"` +// E int64 ",minsize" +// F int64 "myf,omitempty,minsize" +// } +// +func Marshal(in interface{}) (out []byte, err error) { + defer handleErr(&err) + e := &encoder{make([]byte, 0, initialBufferSize)} + e.addDoc(reflect.ValueOf(in)) + return e.out, nil +} + +// Unmarshal deserializes data from in into the out value. The out value +// must be a map, a pointer to a struct, or a pointer to a bson.D value. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported during unmarshal (see the +// Marshal method for other flags): +// +// inline Inline the field, which must be a struct or a map. +// Inlined structs are handled as if its fields were part +// of the outer struct. An inlined map causes keys that do +// not match any other struct field to be inserted in the +// map rather than being discarded as usual. +// +// The target field or element types of out may not necessarily match +// the BSON values of the provided data. The following conversions are +// made automatically: +// +// - Numeric types are converted if at least the integer part of the +// value would be preserved correctly +// - Bools are converted to numeric types as 1 or 0 +// - Numeric types are converted to bools as true if not 0 or false otherwise +// - Binary and string BSON data is converted to a string, array or byte slice +// +// If the value would not fit the type and cannot be converted, it's +// silently skipped. +// +// Pointer values are initialized when necessary. +func Unmarshal(in []byte, out interface{}) (err error) { + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Map, reflect.Ptr: + d := newDecoder(in) + d.readDocTo(v) + case reflect.Struct: + return errors.New("Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Unmarshal needs a map or a pointer to a struct.") + } + return nil +} + +// Unmarshal deserializes raw into the out value. If the out value type +// is not compatible with raw, a *bson.TypeError is returned. +// +// See the Unmarshal function documentation for more details on the +// unmarshalling process. +func (raw Raw) Unmarshal(out interface{}) (err error) { + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + v = v.Elem() + fallthrough + case reflect.Map: + d := newDecoder(raw.Data) + good := d.readElemTo(v, raw.Kind) + if !good { + return &TypeError{v.Type(), raw.Kind} + } + case reflect.Struct: + return errors.New("Raw Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Raw Unmarshal needs a map or a valid pointer.") + } + return nil +} + +type TypeError struct { + Type reflect.Type + Kind byte +} + +func (e *TypeError) Error() string { + return fmt.Sprintf("BSON kind 0x%02x isn't compatible with type %s", e.Kind, e.Type.String()) +} + +// -------------------------------------------------------------------------- +// Maintain a mapping of keys to structure field indexes + +type structInfo struct { + FieldsMap map[string]fieldInfo + FieldsList []fieldInfo + InlineMap int + Zero reflect.Value +} + +type fieldInfo struct { + Key string + Num int + OmitEmpty bool + MinSize bool + Inline []int +} + +var structMap = make(map[reflect.Type]*structInfo) +var structMapMutex sync.RWMutex + +type externalPanic string + +func (e externalPanic) String() string { + return string(e) +} + +func getStructInfo(st reflect.Type) (*structInfo, error) { + structMapMutex.RLock() + sinfo, found := structMap[st] + structMapMutex.RUnlock() + if found { + return sinfo, nil + } + n := st.NumField() + fieldsMap := make(map[string]fieldInfo) + fieldsList := make([]fieldInfo, 0, n) + inlineMap := -1 + for i := 0; i != n; i++ { + field := st.Field(i) + if field.PkgPath != "" { + continue // Private field + } + + info := fieldInfo{Num: i} + + tag := field.Tag.Get("bson") + if tag == "" && strings.Index(string(field.Tag), ":") < 0 { + tag = string(field.Tag) + } + if tag == "-" { + continue + } + + // XXX Drop this after a few releases. + if s := strings.Index(tag, "/"); s >= 0 { + recommend := tag[:s] + for _, c := range tag[s+1:] { + switch c { + case 'c': + recommend += ",omitempty" + case 's': + recommend += ",minsize" + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", string([]byte{uint8(c)}), tag, st) + panic(externalPanic(msg)) + } + } + msg := fmt.Sprintf("Replace tag %q in field %s of type %s by %q", tag, field.Name, st, recommend) + panic(externalPanic(msg)) + } + + inline := false + fields := strings.Split(tag, ",") + if len(fields) > 1 { + for _, flag := range fields[1:] { + switch flag { + case "omitempty": + info.OmitEmpty = true + case "minsize": + info.MinSize = true + case "inline": + inline = true + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", flag, tag, st) + panic(externalPanic(msg)) + } + } + tag = fields[0] + } + + if inline { + switch field.Type.Kind() { + case reflect.Map: + if inlineMap >= 0 { + return nil, errors.New("Multiple ,inline maps in struct " + st.String()) + } + if field.Type.Key() != reflect.TypeOf("") { + return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) + } + inlineMap = info.Num + case reflect.Struct: + sinfo, err := getStructInfo(field.Type) + if err != nil { + return nil, err + } + for _, finfo := range sinfo.FieldsList { + if _, found := fieldsMap[finfo.Key]; found { + msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + if finfo.Inline == nil { + finfo.Inline = []int{i, finfo.Num} + } else { + finfo.Inline = append([]int{i}, finfo.Inline...) + } + fieldsMap[finfo.Key] = finfo + fieldsList = append(fieldsList, finfo) + } + default: + panic("Option ,inline needs a struct value or map field") + } + continue + } + + if tag != "" { + info.Key = tag + } else { + info.Key = strings.ToLower(field.Name) + } + + if _, found = fieldsMap[info.Key]; found { + msg := "Duplicated key '" + info.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + + fieldsList = append(fieldsList, info) + fieldsMap[info.Key] = info + } + sinfo = &structInfo{ + fieldsMap, + fieldsList, + inlineMap, + reflect.New(st).Elem(), + } + structMapMutex.Lock() + structMap[st] = sinfo + structMapMutex.Unlock() + return sinfo, nil +} diff --git a/bson/bson_test.go b/bson/bson_test.go new file mode 100644 index 0000000..3d97998 --- /dev/null +++ b/bson/bson_test.go @@ -0,0 +1,1472 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson_test + +import ( + "encoding/binary" + "encoding/json" + "errors" + "net/url" + "reflect" + "testing" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2/bson" +) + +func TestAll(t *testing.T) { + TestingT(t) +} + +type S struct{} + +var _ = Suite(&S{}) + +// Wrap up the document elements contained in data, prepending the int32 +// length of the data, and appending the '\x00' value closing the document. +func wrapInDoc(data string) string { + result := make([]byte, len(data)+5) + binary.LittleEndian.PutUint32(result, uint32(len(result))) + copy(result[4:], []byte(data)) + return string(result) +} + +func makeZeroDoc(value interface{}) (zero interface{}) { + v := reflect.ValueOf(value) + t := v.Type() + switch t.Kind() { + case reflect.Map: + mv := reflect.MakeMap(t) + zero = mv.Interface() + case reflect.Ptr: + pv := reflect.New(v.Type().Elem()) + zero = pv.Interface() + case reflect.Slice: + zero = reflect.New(t).Interface() + default: + panic("unsupported doc type") + } + return zero +} + +func testUnmarshal(c *C, data string, obj interface{}) { + zero := makeZeroDoc(obj) + err := bson.Unmarshal([]byte(data), zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, obj) +} + +type testItemType struct { + obj interface{} + data string +} + +// -------------------------------------------------------------------------- +// Samples from bsonspec.org: + +var sampleItems = []testItemType{ + {bson.M{"hello": "world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {bson.M{"BSON": []interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalSampleItems(c *C) { + for i, item := range sampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalSampleItems(c *C) { + for i, item := range sampleItems { + value := bson.M{} + err := bson.Unmarshal([]byte(item.data), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Every type, ordered by the type flag. These are not wrapped with the +// length and last \x00 from the document. wrapInDoc() computes them. +// Note that all of them should be supported as two-way conversions. + +var allItems = []testItemType{ + {bson.M{}, + ""}, + {bson.M{"_": float64(5.05)}, + "\x01_\x00333333\x14@"}, + {bson.M{"_": "yo"}, + "\x02_\x00\x03\x00\x00\x00yo\x00"}, + {bson.M{"_": bson.M{"a": true}}, + "\x03_\x00\x09\x00\x00\x00\x08a\x00\x01\x00"}, + {bson.M{"_": []interface{}{true, false}}, + "\x04_\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + {bson.M{"_": []byte("yo")}, + "\x05_\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"_": bson.Binary{0x80, []byte("udef")}}, + "\x05_\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"_": bson.Undefined}, // Obsolete, but still seen in the wild. + "\x06_\x00"}, + {bson.M{"_": bson.ObjectId("0123456789ab")}, + "\x07_\x000123456789ab"}, + {bson.M{"_": false}, + "\x08_\x00\x00"}, + {bson.M{"_": true}, + "\x08_\x00\x01"}, + {bson.M{"_": time.Unix(0, 258e6)}, // Note the NS <=> MS conversion. + "\x09_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": nil}, + "\x0A_\x00"}, + {bson.M{"_": bson.RegEx{"ab", "cd"}}, + "\x0B_\x00ab\x00cd\x00"}, + {bson.M{"_": bson.JavaScript{"code", nil}}, + "\x0D_\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"_": bson.Symbol("sym")}, + "\x0E_\x00\x04\x00\x00\x00sym\x00"}, + {bson.M{"_": bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F_\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + {bson.M{"_": 258}, + "\x10_\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MongoTimestamp(258)}, + "\x11_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258)}, + "\x12_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258 << 32)}, + "\x12_\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MaxKey}, + "\x7F_\x00"}, + {bson.M{"_": bson.MinKey}, + "\xFF_\x00"}, +} + +func (s *S) TestMarshalAllItems(c *C) { + for i, item := range allItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalAllItems(c *C) { + for i, item := range allItems { + value := bson.M{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawAllItems(c *C) { + for i, item := range allItems { + if len(item.data) == 0 { + continue + } + value := item.obj.(bson.M)["_"] + if value == nil { + continue + } + pv := reflect.New(reflect.ValueOf(value).Type()) + raw := bson.Raw{item.data[0], []byte(item.data[3:])} + c.Logf("Unmarshal raw: %#v, %#v", raw, pv.Interface()) + err := raw.Unmarshal(pv.Interface()) + c.Assert(err, IsNil) + c.Assert(pv.Elem().Interface(), DeepEquals, value, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawIncompatible(c *C) { + raw := bson.Raw{0x08, []byte{0x01}} // true + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, ErrorMatches, "BSON kind 0x08 isn't compatible with type struct \\{\\}") +} + +func (s *S) TestUnmarshalZeroesStruct(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + type T struct{ A, B int } + v := T{A: 1} + err = bson.Unmarshal(data, &v) + c.Assert(err, IsNil) + c.Assert(v.A, Equals, 0) + c.Assert(v.B, Equals, 2) +} + +func (s *S) TestUnmarshalZeroesMap(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + err = bson.Unmarshal(data, &m) + c.Assert(err, IsNil) + c.Assert(m, DeepEquals, bson.M{"b": 2}) +} + +func (s *S) TestUnmarshalNonNilInterface(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + var i interface{} + i = m + err = bson.Unmarshal(data, &i) + c.Assert(err, IsNil) + c.Assert(i, DeepEquals, bson.M{"b": 2}) + c.Assert(m, DeepEquals, bson.M{"a": 1}) +} + +// -------------------------------------------------------------------------- +// Some one way marshaling operations which would unmarshal differently. + +var oneWayMarshalItems = []testItemType{ + // These are being passed as pointers, and will unmarshal as values. + {bson.M{"": &bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + {bson.M{"": &bson.Binary{0x80, []byte("udef")}}, + "\x05\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"": &bson.RegEx{"ab", "cd"}}, + "\x0B\x00ab\x00cd\x00"}, + {bson.M{"": &bson.JavaScript{"code", nil}}, + "\x0D\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"": &bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + + // There's no float32 type in BSON. Will encode as a float64. + {bson.M{"": float32(5.05)}, + "\x01\x00\x00\x00\x00@33\x14@"}, + + // The array will be unmarshaled as a slice instead. + {bson.M{"": [2]bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // The typed slice will be unmarshaled as []interface{}. + {bson.M{"": []bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // Will unmarshal as a []byte. + {bson.M{"": bson.Binary{0x00, []byte("yo")}}, + "\x05\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"": bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // No way to preserve the type information here. We might encode as a zero + // value, but this would mean that pointer values in structs wouldn't be + // able to correctly distinguish between unset and set to the zero value. + {bson.M{"": (*byte)(nil)}, + "\x0A\x00"}, + + // No int types smaller than int32 in BSON. Could encode this as a char, + // but it would still be ambiguous, take more, and be awkward in Go when + // loaded without typing information. + {bson.M{"": byte(8)}, + "\x10\x00\x08\x00\x00\x00"}, + + // There are no unsigned types in BSON. Will unmarshal as int32 or int64. + {bson.M{"": uint32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + {bson.M{"": uint64(258)}, + "\x12\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"": uint64(258 << 32)}, + "\x12\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + + // This will unmarshal as int. + {bson.M{"": int32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + + // That's a special case. The unsigned value is too large for an int32, + // so an int64 is used instead. + {bson.M{"": uint32(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, + {bson.M{"": uint(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, +} + +func (s *S) TestOneWayMarshalItems(c *C) { + for i, item := range oneWayMarshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Two-way tests for user-defined structures using the samples +// from bsonspec.org. + +type specSample1 struct { + Hello string +} + +type specSample2 struct { + BSON []interface{} "BSON" +} + +var structSampleItems = []testItemType{ + {&specSample1{"world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {&specSample2{[]interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalStructSampleItems(c *C) { + for i, item := range structSampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructSampleItems(c *C) { + for _, item := range structSampleItems { + testUnmarshal(c, item.data, item.obj) + } +} + +func (s *S) Test64bitInt(c *C) { + var i int64 = (1 << 31) + if int(i) > 0 { + data, err := bson.Marshal(bson.M{"i": int(i)}) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00")) + + var result struct{ I int } + err = bson.Unmarshal(data, &result) + c.Assert(err, IsNil) + c.Assert(int64(result.I), Equals, i) + } +} + +// -------------------------------------------------------------------------- +// Generic two-way struct marshaling tests. + +var bytevar = byte(8) +var byteptr = &bytevar + +var structItems = []testItemType{ + {&struct{ Ptr *byte }{nil}, + "\x0Aptr\x00"}, + {&struct{ Ptr *byte }{&bytevar}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Ptr **byte }{&byteptr}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{8}, + "\x10byte\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{0}, + "\x10byte\x00\x00\x00\x00\x00"}, + {&struct { + V byte "Tag" + }{8}, + "\x10Tag\x00\x08\x00\x00\x00"}, + {&struct { + V *struct { + Byte byte + } + }{&struct{ Byte byte }{8}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ priv byte }{}, ""}, + + // The order of the dumped fields should be the same in the struct. + {&struct{ A, C, B, D, F, E *byte }{}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x0Ae\x00"}, + + {&struct{ V bson.Raw }{bson.Raw{0x03, []byte("\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00")}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ V bson.Raw }{bson.Raw{0x10, []byte("\x00\x00\x00\x00")}}, + "\x10v\x00" + "\x00\x00\x00\x00"}, + + // Byte arrays. + {&struct{ V [2]byte }{[2]byte{'y', 'o'}}, + "\x05v\x00\x02\x00\x00\x00\x00yo"}, +} + +func (s *S) TestMarshalStructItems(c *C) { + for i, item := range structItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructItems(c *C) { + for _, item := range structItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalRawStructItems(c *C) { + for i, item := range structItems { + raw := bson.Raw{0x03, []byte(wrapInDoc(item.data))} + zero := makeZeroDoc(item.obj) + err := raw.Unmarshal(zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawNil(c *C) { + // Regression test: shouldn't try to nil out the pointer itself, + // as it's not settable. + raw := bson.Raw{0x0A, []byte{}} + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, IsNil) +} + +// -------------------------------------------------------------------------- +// One-way marshaling tests. + +type dOnIface struct { + D interface{} +} + +type ignoreField struct { + Before string + Ignore string `bson:"-"` + After string +} + +var marshalItems = []testItemType{ + // Ordered document dump. Will unmarshal as a dictionary by default. + {bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {MyD{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {&dOnIface{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + {bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {MyRawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {&dOnIface{bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00"+"\x0Ac\x00"+"\x08b\x00\x01")}, + + {&ignoreField{"before", "ignore", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Marshalling a Raw document does nothing. + {bson.Raw{0x03, []byte(wrapInDoc("anything"))}, + "anything"}, + {bson.Raw{Data: []byte(wrapInDoc("anything"))}, + "anything"}, +} + +func (s *S) TestMarshalOneWayItems(c *C) { + for _, item := range marshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data)) + } +} + +// -------------------------------------------------------------------------- +// One-way unmarshaling tests. + +var unmarshalItems = []testItemType{ + // Field is private. Should not attempt to unmarshal it. + {&struct{ priv byte }{}, + "\x10priv\x00\x08\x00\x00\x00"}, + + // Wrong casing. Field names are lowercased. + {&struct{ Byte byte }{}, + "\x10Byte\x00\x08\x00\x00\x00"}, + + // Ignore non-existing field. + {&struct{ Byte byte }{9}, + "\x10boot\x00\x08\x00\x00\x00" + "\x10byte\x00\x09\x00\x00\x00"}, + + // Do not unmarshal on ignored field. + {&ignoreField{"before", "", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00" + + "\x02-\x00\a\x00\x00\x00ignore\x00" + + "\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Ignore unsuitable types silently. + {map[string]string{"str": "s"}, + "\x02str\x00\x02\x00\x00\x00s\x00" + "\x10int\x00\x01\x00\x00\x00"}, + {map[string][]int{"array": []int{5, 9}}, + "\x04array\x00" + wrapInDoc("\x100\x00\x05\x00\x00\x00"+"\x021\x00\x02\x00\x00\x00s\x00"+"\x102\x00\x09\x00\x00\x00")}, + + // Wrong type. Shouldn't init pointer. + {&struct{ Str *byte }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + {&struct{ Str *struct{ Str string } }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + + // Ordered document. + {&struct{ bson.D }{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + // Raw document. + {&bson.Raw{0x03, []byte(wrapInDoc("\x10byte\x00\x08\x00\x00\x00"))}, + "\x10byte\x00\x08\x00\x00\x00"}, + + // RawD document. + {&struct{ bson.RawD }{bson.RawD{{"a", bson.Raw{0x0A, []byte{}}}, {"c", bson.Raw{0x0A, []byte{}}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03rawd\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x08b\x00\x01")}, + + // Decode old binary. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // Decode old binary without length. According to the spec, this shouldn't happen. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x03\x00\x00\x00\x02old"}, +} + +func (s *S) TestUnmarshalOneWayItems(c *C) { + for _, item := range unmarshalItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalNilInStruct(c *C) { + // Nil is the default value, so we need to ensure it's indeed being set. + b := byte(1) + v := &struct{ Ptr *byte }{&b} + err := bson.Unmarshal([]byte(wrapInDoc("\x0Aptr\x00")), v) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, &struct{ Ptr *byte }{nil}) +} + +// -------------------------------------------------------------------------- +// Marshalling error cases. + +type structWithDupKeys struct { + Name byte + Other byte "name" // Tag should precede. +} + +var marshalErrorItems = []testItemType{ + {bson.M{"": uint64(1 << 63)}, + "BSON has no uint64 type, and value is too large to fit correctly in an int64"}, + {bson.M{"": bson.ObjectId("tooshort")}, + "ObjectIDs must be exactly 12 bytes long \\(got 8\\)"}, + {int64(123), + "Can't marshal int64 as a BSON document"}, + {bson.M{"": 1i}, + "Can't marshal complex128 in a BSON document"}, + {&structWithDupKeys{}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + {bson.Raw{0x0A, []byte{}}, + "Attempted to unmarshal Raw kind 10 as a document"}, + {&inlineCantPtr{&struct{ A, B int }{1, 2}}, + "Option ,inline needs a struct value or map field"}, + {&inlineDupName{1, struct{ A, B int }{2, 3}}, + "Duplicated key 'a' in struct bson_test.inlineDupName"}, + {&inlineDupMap{}, + "Multiple ,inline maps in struct bson_test.inlineDupMap"}, + {&inlineBadKeyMap{}, + "Option ,inline needs a map with string keys in struct bson_test.inlineBadKeyMap"}, + {&inlineMap{A: 1, M: map[string]interface{}{"a": 1}}, + `Can't have key "a" in inlined map; conflicts with struct field`}, +} + +func (s *S) TestMarshalErrorItems(c *C) { + for _, item := range marshalErrorItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, ErrorMatches, item.data) + c.Assert(data, IsNil) + } +} + +// -------------------------------------------------------------------------- +// Unmarshalling error cases. + +type unmarshalErrorType struct { + obj interface{} + data string + error string +} + +var unmarshalErrorItems = []unmarshalErrorType{ + // Tag name conflicts with existing parameter. + {&structWithDupKeys{}, + "\x10name\x00\x08\x00\x00\x00", + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + // Non-string map key. + {map[int]interface{}{}, + "\x10name\x00\x08\x00\x00\x00", + "BSON map must have string keys. Got: map\\[int\\]interface \\{\\}"}, + + {nil, + "\xEEname\x00", + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal can't deal with struct values. Use a pointer."}, + + {123, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal needs a map or a pointer to a struct."}, +} + +func (s *S) TestUnmarshalErrorItems(c *C) { + for _, item := range unmarshalErrorItems { + data := []byte(wrapInDoc(item.data)) + var value interface{} + switch reflect.ValueOf(item.obj).Kind() { + case reflect.Map, reflect.Ptr: + value = makeZeroDoc(item.obj) + case reflect.Invalid: + value = bson.M{} + default: + value = item.obj + } + err := bson.Unmarshal(data, value) + c.Assert(err, ErrorMatches, item.error) + } +} + +type unmarshalRawErrorType struct { + obj interface{} + raw bson.Raw + error string +} + +var unmarshalRawErrorItems = []unmarshalRawErrorType{ + // Tag name conflicts with existing parameter. + {&structWithDupKeys{}, + bson.Raw{0x03, []byte("\x10byte\x00\x08\x00\x00\x00")}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + {&struct{}{}, + bson.Raw{0xEE, []byte{}}, + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal can't deal with struct values. Use a pointer."}, + + {123, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal needs a map or a valid pointer."}, +} + +func (s *S) TestUnmarshalRawErrorItems(c *C) { + for i, item := range unmarshalRawErrorItems { + err := item.raw.Unmarshal(item.obj) + c.Assert(err, ErrorMatches, item.error, Commentf("Failed on item %d: %#v\n", i, item)) + } +} + +var corruptedData = []string{ + "\x04\x00\x00\x00\x00", // Shorter than minimum + "\x06\x00\x00\x00\x00", // Not enough data + "\x05\x00\x00", // Broken length + "\x05\x00\x00\x00\xff", // Corrupted termination + "\x0A\x00\x00\x00\x0Aooop\x00", // Unfinished C string + + // Array end past end of string (s[2]=0x07 is correct) + wrapInDoc("\x04\x00\x09\x00\x00\x00\x0A\x00\x00"), + + // Array end within string, but past acceptable. + wrapInDoc("\x04\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // Document end within string, but past acceptable. + wrapInDoc("\x03\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // String with corrupted end. + wrapInDoc("\x02\x00\x03\x00\x00\x00yo\xFF"), +} + +func (s *S) TestUnmarshalMapDocumentTooShort(c *C) { + for _, data := range corruptedData { + err := bson.Unmarshal([]byte(data), bson.M{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + + err = bson.Unmarshal([]byte(data), &struct{}{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + } +} + +// -------------------------------------------------------------------------- +// Setter test cases. + +var setterResult = map[string]error{} + +type setterType struct { + received interface{} +} + +func (o *setterType) SetBSON(raw bson.Raw) error { + err := raw.Unmarshal(&o.received) + if err != nil { + panic("The panic:" + err.Error()) + } + if s, ok := o.received.(string); ok { + if result, ok := setterResult[s]; ok { + return result + } + } + return nil +} + +type ptrSetterDoc struct { + Field *setterType "_" +} + +type valSetterDoc struct { + Field setterType "_" +} + +func (s *S) TestUnmarshalAllItemsWithPtrSetter(c *C) { + for _, item := range allItems { + for i := 0; i != 2; i++ { + var field *setterType + if i == 0 { + obj := &ptrSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = obj.Field + } else { + obj := &valSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = &obj.Field + } + if item.data == "" { + // Nothing to unmarshal. Should be untouched. + if i == 0 { + c.Assert(field, IsNil) + } else { + c.Assert(field.received, IsNil) + } + } else { + expected := item.obj.(bson.M)["_"] + c.Assert(field, NotNil, Commentf("Pointer not initialized (%#v)", expected)) + c.Assert(field.received, DeepEquals, expected) + } + } + } +} + +func (s *S) TestUnmarshalWholeDocumentWithSetter(c *C) { + obj := &setterType{} + err := bson.Unmarshal([]byte(sampleItems[0].data), obj) + c.Assert(err, IsNil) + c.Assert(obj.received, DeepEquals, bson.M{"hello": "world"}) +} + +func (s *S) TestUnmarshalSetterOmits(c *C) { + setterResult["2"] = &bson.TypeError{} + setterResult["4"] = &bson.TypeError{} + defer func() { + delete(setterResult, "2") + delete(setterResult, "4") + }() + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00" + + "\x02jkl\x00\x02\x00\x00\x004\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + c.Assert(m["abc"], NotNil) + c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], NotNil) + c.Assert(m["jkl"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") + c.Assert(m["ghi"].received, Equals, "3") +} + +func (s *S) TestUnmarshalSetterErrors(c *C) { + boom := errors.New("BOOM") + setterResult["2"] = boom + defer delete(setterResult, "2") + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, Equals, boom) + c.Assert(m["abc"], NotNil) + c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") +} + +func (s *S) TestDMap(c *C) { + d := bson.D{{"a", 1}, {"b", 2}} + c.Assert(d.Map(), DeepEquals, bson.M{"a": 1, "b": 2}) +} + +func (s *S) TestUnmarshalSetterSetZero(c *C) { + setterResult["foo"] = bson.SetZero + defer delete(setterResult, "field") + + data, err := bson.Marshal(bson.M{"field": "foo"}) + c.Assert(err, IsNil) + + m := map[string]*setterType{} + err = bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + + value, ok := m["field"] + c.Assert(ok, Equals, true) + c.Assert(value, IsNil) +} + +// -------------------------------------------------------------------------- +// Getter test cases. + +type typeWithGetter struct { + result interface{} + err error +} + +func (t *typeWithGetter) GetBSON() (interface{}, error) { + if t == nil { + return "", nil + } + return t.result, t.err +} + +type docWithGetterField struct { + Field *typeWithGetter "_" +} + +func (s *S) TestMarshalAllItemsWithGetter(c *C) { + for i, item := range allItems { + if item.data == "" { + continue + } + obj := &docWithGetterField{} + obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item #%d", i)) + } +} + +func (s *S) TestMarshalWholeDocumentWithGetter(c *C) { + obj := &typeWithGetter{result: sampleItems[0].obj} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, sampleItems[0].data) +} + +func (s *S) TestGetterErrors(c *C) { + e := errors.New("oops") + + obj1 := &docWithGetterField{} + obj1.Field = &typeWithGetter{sampleItems[0].obj, e} + data, err := bson.Marshal(obj1) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) + + obj2 := &typeWithGetter{sampleItems[0].obj, e} + data, err = bson.Marshal(obj2) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) +} + +type intGetter int64 + +func (t intGetter) GetBSON() (interface{}, error) { + return int64(t), nil +} + +type typeWithIntGetter struct { + V intGetter ",minsize" +} + +func (s *S) TestMarshalShortWithGetter(c *C) { + obj := typeWithIntGetter{42} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + m := bson.M{} + err = bson.Unmarshal(data, m) + c.Assert(err, IsNil) + c.Assert(m["v"], Equals, 42) +} + +func (s *S) TestMarshalWithGetterNil(c *C) { + obj := docWithGetterField{} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + m := bson.M{} + err = bson.Unmarshal(data, m) + c.Assert(err, IsNil) + c.Assert(m, DeepEquals, bson.M{"_": ""}) +} + +// -------------------------------------------------------------------------- +// Cross-type conversion tests. + +type crossTypeItem struct { + obj1 interface{} + obj2 interface{} +} + +type condStr struct { + V string ",omitempty" +} +type condStrNS struct { + V string `a:"A" bson:",omitempty" b:"B"` +} +type condBool struct { + V bool ",omitempty" +} +type condInt struct { + V int ",omitempty" +} +type condUInt struct { + V uint ",omitempty" +} +type condFloat struct { + V float64 ",omitempty" +} +type condIface struct { + V interface{} ",omitempty" +} +type condPtr struct { + V *bool ",omitempty" +} +type condSlice struct { + V []string ",omitempty" +} +type condMap struct { + V map[string]int ",omitempty" +} +type namedCondStr struct { + V string "myv,omitempty" +} +type condTime struct { + V time.Time ",omitempty" +} +type condStruct struct { + V struct{ A []int } ",omitempty" +} + +type shortInt struct { + V int64 ",minsize" +} +type shortUint struct { + V uint64 ",minsize" +} +type shortIface struct { + V interface{} ",minsize" +} +type shortPtr struct { + V *int64 ",minsize" +} +type shortNonEmptyInt struct { + V int64 ",minsize,omitempty" +} + +type inlineInt struct { + V struct{ A, B int } ",inline" +} +type inlineCantPtr struct { + V *struct{ A, B int } ",inline" +} +type inlineDupName struct { + A int + V struct{ A, B int } ",inline" +} +type inlineMap struct { + A int + M map[string]interface{} ",inline" +} +type inlineMapInt struct { + A int + M map[string]int ",inline" +} +type inlineMapMyM struct { + A int + M MyM ",inline" +} +type inlineDupMap struct { + M1 map[string]interface{} ",inline" + M2 map[string]interface{} ",inline" +} +type inlineBadKeyMap struct { + M map[int]int ",inline" +} + +type ( + MyString string + MyBytes []byte + MyBool bool + MyD []bson.DocElem + MyRawD []bson.RawDocElem + MyM map[string]interface{} +) + +var ( + truevar = true + falsevar = false + + int64var = int64(42) + int64ptr = &int64var + intvar = int(42) + intptr = &intvar +) + +func parseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + +// That's a pretty fun test. It will dump the first item, generate a zero +// value equivalent to the second one, load the dumped data onto it, and then +// verify that the resulting value is deep-equal to the untouched second value. +// Then, it will do the same in the *opposite* direction! +var twoWayCrossItems = []crossTypeItem{ + // int<=>int + {&struct{ I int }{42}, &struct{ I int8 }{42}}, + {&struct{ I int }{42}, &struct{ I int32 }{42}}, + {&struct{ I int }{42}, &struct{ I int64 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I int32 }{42}, &struct{ I int64 }{42}}, + + // uint<=>uint + {&struct{ I uint }{42}, &struct{ I uint8 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I uint64 }{42}}, + + // float32<=>float64 + {&struct{ I float32 }{42}, &struct{ I float64 }{42}}, + + // int<=>uint + {&struct{ I uint }{42}, &struct{ I int }{42}}, + {&struct{ I uint }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int64 }{42}}, + + // int <=> float + {&struct{ I int }{42}, &struct{ I float64 }{42}}, + + // int <=> bool + {&struct{ I int }{1}, &struct{ I bool }{true}}, + {&struct{ I int }{0}, &struct{ I bool }{false}}, + + // uint <=> float64 + {&struct{ I uint }{42}, &struct{ I float64 }{42}}, + + // uint <=> bool + {&struct{ I uint }{1}, &struct{ I bool }{true}}, + {&struct{ I uint }{0}, &struct{ I bool }{false}}, + + // float64 <=> bool + {&struct{ I float64 }{1}, &struct{ I bool }{true}}, + {&struct{ I float64 }{0}, &struct{ I bool }{false}}, + + // string <=> string and string <=> []byte + {&struct{ S []byte }{[]byte("abc")}, &struct{ S string }{"abc"}}, + {&struct{ S []byte }{[]byte("def")}, &struct{ S bson.Symbol }{"def"}}, + {&struct{ S string }{"ghi"}, &struct{ S bson.Symbol }{"ghi"}}, + + // map <=> struct + {&struct { + A struct { + B, C int + } + }{struct{ B, C int }{1, 2}}, + map[string]map[string]int{"a": map[string]int{"b": 1, "c": 2}}}, + + {&struct{ A bson.Symbol }{"abc"}, map[string]string{"a": "abc"}}, + {&struct{ A bson.Symbol }{"abc"}, map[string][]byte{"a": []byte("abc")}}, + {&struct{ A []byte }{[]byte("abc")}, map[string]string{"a": "abc"}}, + {&struct{ A uint }{42}, map[string]int{"a": 42}}, + {&struct{ A uint }{42}, map[string]float64{"a": 42}}, + {&struct{ A uint }{1}, map[string]bool{"a": true}}, + {&struct{ A int }{42}, map[string]uint{"a": 42}}, + {&struct{ A int }{42}, map[string]float64{"a": 42}}, + {&struct{ A int }{1}, map[string]bool{"a": true}}, + {&struct{ A float64 }{42}, map[string]float32{"a": 42}}, + {&struct{ A float64 }{42}, map[string]int{"a": 42}}, + {&struct{ A float64 }{42}, map[string]uint{"a": 42}}, + {&struct{ A float64 }{1}, map[string]bool{"a": true}}, + {&struct{ A bool }{true}, map[string]int{"a": 1}}, + {&struct{ A bool }{true}, map[string]uint{"a": 1}}, + {&struct{ A bool }{true}, map[string]float64{"a": 1}}, + {&struct{ A **byte }{&byteptr}, map[string]byte{"a": 8}}, + + // url.URL <=> string + {&struct{ URL *url.URL }{parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + {&struct{ URL url.URL }{*parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + + // Slices + {&struct{ S []int }{[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + {&struct{ S *[]int }{&[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + + // Conditionals + {&condBool{true}, map[string]bool{"v": true}}, + {&condBool{}, map[string]bool{}}, + {&condInt{1}, map[string]int{"v": 1}}, + {&condInt{}, map[string]int{}}, + {&condUInt{1}, map[string]uint{"v": 1}}, + {&condUInt{}, map[string]uint{}}, + {&condFloat{}, map[string]int{}}, + {&condStr{"yo"}, map[string]string{"v": "yo"}}, + {&condStr{}, map[string]string{}}, + {&condStrNS{"yo"}, map[string]string{"v": "yo"}}, + {&condStrNS{}, map[string]string{}}, + {&condSlice{[]string{"yo"}}, map[string][]string{"v": []string{"yo"}}}, + {&condSlice{}, map[string][]string{}}, + {&condMap{map[string]int{"k": 1}}, bson.M{"v": bson.M{"k": 1}}}, + {&condMap{}, map[string][]string{}}, + {&condIface{"yo"}, map[string]string{"v": "yo"}}, + {&condIface{""}, map[string]string{"v": ""}}, + {&condIface{}, map[string]string{}}, + {&condPtr{&truevar}, map[string]bool{"v": true}}, + {&condPtr{&falsevar}, map[string]bool{"v": false}}, + {&condPtr{}, map[string]string{}}, + + {&condTime{time.Unix(123456789, 123e6)}, map[string]time.Time{"v": time.Unix(123456789, 123e6)}}, + {&condTime{}, map[string]string{}}, + + {&condStruct{struct{ A []int }{[]int{1}}}, bson.M{"v": bson.M{"a": []interface{}{1}}}}, + {&condStruct{struct{ A []int }{}}, bson.M{}}, + + {&namedCondStr{"yo"}, map[string]string{"myv": "yo"}}, + {&namedCondStr{}, map[string]string{}}, + + {&shortInt{1}, map[string]interface{}{"v": 1}}, + {&shortInt{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortUint{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortUint{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortIface{int64(1) << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortPtr{int64ptr}, map[string]interface{}{"v": intvar}}, + + {&shortNonEmptyInt{1}, map[string]interface{}{"v": 1}}, + {&shortNonEmptyInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortNonEmptyInt{}, map[string]interface{}{}}, + + {&inlineInt{struct{ A, B int }{1, 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: map[string]interface{}{"b": 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: nil}, map[string]interface{}{"a": 1}}, + {&inlineMapInt{A: 1, M: map[string]int{"b": 2}}, map[string]int{"a": 1, "b": 2}}, + {&inlineMapInt{A: 1, M: nil}, map[string]int{"a": 1}}, + {&inlineMapMyM{A: 1, M: MyM{"b": MyM{"c": 3}}}, map[string]interface{}{"a": 1, "b": map[string]interface{}{"c": 3}}}, + + // []byte <=> MyBytes + {&struct{ B MyBytes }{[]byte("abc")}, map[string]string{"b": "abc"}}, + {&struct{ B MyBytes }{[]byte{}}, map[string]string{"b": ""}}, + {&struct{ B MyBytes }{}, map[string]bool{}}, + {&struct{ B []byte }{[]byte("abc")}, map[string]MyBytes{"b": []byte("abc")}}, + + // bool <=> MyBool + {&struct{ B MyBool }{true}, map[string]bool{"b": true}}, + {&struct{ B MyBool }{}, map[string]bool{"b": false}}, + {&struct{ B MyBool }{}, map[string]string{}}, + {&struct{ B bool }{}, map[string]MyBool{"b": false}}, + + // arrays + {&struct{ V [2]int }{[...]int{1, 2}}, map[string][2]int{"v": [2]int{1, 2}}}, + + // zero time + {&struct{ V time.Time }{}, map[string]interface{}{"v": time.Time{}}}, + + // zero time + 1 second + 1 millisecond; overflows int64 as nanoseconds + {&struct{ V time.Time }{time.Unix(-62135596799, 1e6).Local()}, + map[string]interface{}{"v": time.Unix(-62135596799, 1e6).Local()}}, + + // bson.D <=> []DocElem + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}}, + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &MyD{{"a", MyD{{"b", 1}, {"c", 2}}}}}, + + // bson.RawD <=> []RawDocElem + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &MyRawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + + // bson.M <=> map + {bson.M{"a": bson.M{"b": 1, "c": 2}}, MyM{"a": MyM{"b": 1, "c": 2}}}, + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[string]interface{}{"a": map[string]interface{}{"b": 1, "c": 2}}}, + + // bson.M <=> map[MyString] + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[MyString]interface{}{"a": map[MyString]interface{}{"b": 1, "c": 2}}}, + + // json.Number <=> int64, float64 + {&struct{ N json.Number }{"5"}, map[string]interface{}{"n": int64(5)}}, + {&struct{ N json.Number }{"5.05"}, map[string]interface{}{"n": 5.05}}, + {&struct{ N json.Number }{"9223372036854776000"}, map[string]interface{}{"n": float64(1 << 63)}}, +} + +// Same thing, but only one way (obj1 => obj2). +var oneWayCrossItems = []crossTypeItem{ + // map <=> struct + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, map[string]int{"a": 1, "c": 3}}, + + // inline map elides badly typed values + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, &inlineMapInt{A: 1, M: map[string]int{"c": 3}}}, + + // Can't decode int into struct. + {bson.M{"a": bson.M{"b": 2}}, &struct{ A bool }{}}, + + // Would get decoded into a int32 too in the opposite direction. + {&shortIface{int64(1) << 30}, map[string]interface{}{"v": 1 << 30}}, +} + +func testCrossPair(c *C, dump interface{}, load interface{}) { + c.Logf("Dump: %#v", dump) + c.Logf("Load: %#v", load) + zero := makeZeroDoc(load) + data, err := bson.Marshal(dump) + c.Assert(err, IsNil) + c.Logf("Dumped: %#v", string(data)) + err = bson.Unmarshal(data, zero) + c.Assert(err, IsNil) + c.Logf("Loaded: %#v", zero) + c.Assert(zero, DeepEquals, load) +} + +func (s *S) TestTwoWayCrossPairs(c *C) { + for _, item := range twoWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + testCrossPair(c, item.obj2, item.obj1) + } +} + +func (s *S) TestOneWayCrossPairs(c *C) { + for _, item := range oneWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + } +} + +// -------------------------------------------------------------------------- +// ObjectId hex representation test. + +func (s *S) TestObjectIdHex(c *C) { + id := bson.ObjectIdHex("4d88e15b60f486e428412dc9") + c.Assert(id.String(), Equals, `ObjectIdHex("4d88e15b60f486e428412dc9")`) + c.Assert(id.Hex(), Equals, "4d88e15b60f486e428412dc9") +} + +func (s *S) TestIsObjectIdHex(c *C) { + test := []struct { + id string + valid bool + }{ + {"4d88e15b60f486e428412dc9", true}, + {"4d88e15b60f486e428412dc", false}, + {"4d88e15b60f486e428412dc9e", false}, + {"4d88e15b60f486e428412dcx", false}, + } + for _, t := range test { + c.Assert(bson.IsObjectIdHex(t.id), Equals, t.valid) + } +} + +// -------------------------------------------------------------------------- +// ObjectId parts extraction tests. + +type objectIdParts struct { + id bson.ObjectId + timestamp int64 + machine []byte + pid uint16 + counter int32 +} + +var objectIds = []objectIdParts{ + objectIdParts{ + bson.ObjectIdHex("4d88e15b60f486e428412dc9"), + 1300816219, + []byte{0x60, 0xf4, 0x86}, + 0xe428, + 4271561, + }, + objectIdParts{ + bson.ObjectIdHex("000000000000000000000000"), + 0, + []byte{0x00, 0x00, 0x00}, + 0x0000, + 0, + }, + objectIdParts{ + bson.ObjectIdHex("00000000aabbccddee000001"), + 0, + []byte{0xaa, 0xbb, 0xcc}, + 0xddee, + 1, + }, +} + +func (s *S) TestObjectIdPartsExtraction(c *C) { + for i, v := range objectIds { + t := time.Unix(v.timestamp, 0) + c.Assert(v.id.Time(), Equals, t, Commentf("#%d Wrong timestamp value", i)) + c.Assert(v.id.Machine(), DeepEquals, v.machine, Commentf("#%d Wrong machine id value", i)) + c.Assert(v.id.Pid(), Equals, v.pid, Commentf("#%d Wrong pid value", i)) + c.Assert(v.id.Counter(), Equals, v.counter, Commentf("#%d Wrong counter value", i)) + } +} + +func (s *S) TestNow(c *C) { + before := time.Now() + time.Sleep(1e6) + now := bson.Now() + time.Sleep(1e6) + after := time.Now() + c.Assert(now.After(before) && now.Before(after), Equals, true, Commentf("now=%s, before=%s, after=%s", now, before, after)) +} + +// -------------------------------------------------------------------------- +// ObjectId generation tests. + +func (s *S) TestNewObjectId(c *C) { + // Generate 10 ids + ids := make([]bson.ObjectId, 10) + for i := 0; i < 10; i++ { + ids[i] = bson.NewObjectId() + } + for i := 1; i < 10; i++ { + prevId := ids[i-1] + id := ids[i] + // Test for uniqueness among all other 9 generated ids + for j, tid := range ids { + if j != i { + c.Assert(id, Not(Equals), tid, Commentf("Generated ObjectId is not unique")) + } + } + // Check that timestamp was incremented and is within 30 seconds of the previous one + secs := id.Time().Sub(prevId.Time()).Seconds() + c.Assert((secs >= 0 && secs <= 30), Equals, true, Commentf("Wrong timestamp in generated ObjectId")) + // Check that machine ids are the same + c.Assert(id.Machine(), DeepEquals, prevId.Machine()) + // Check that pids are the same + c.Assert(id.Pid(), Equals, prevId.Pid()) + // Test for proper increment + delta := int(id.Counter() - prevId.Counter()) + c.Assert(delta, Equals, 1, Commentf("Wrong increment in generated ObjectId")) + } +} + +func (s *S) TestNewObjectIdWithTime(c *C) { + t := time.Unix(12345678, 0) + id := bson.NewObjectIdWithTime(t) + c.Assert(id.Time(), Equals, t) + c.Assert(id.Machine(), DeepEquals, []byte{0x00, 0x00, 0x00}) + c.Assert(int(id.Pid()), Equals, 0) + c.Assert(int(id.Counter()), Equals, 0) +} + +// -------------------------------------------------------------------------- +// ObjectId JSON marshalling. + +type jsonType struct { + Id *bson.ObjectId +} + +func (s *S) TestObjectIdJSONMarshaling(c *C) { + id := bson.ObjectIdHex("4d88e15b60f486e428412dc9") + v := jsonType{Id: &id} + data, err := json.Marshal(&v) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, `{"Id":"4d88e15b60f486e428412dc9"}`) +} + +func (s *S) TestObjectIdJSONUnmarshaling(c *C) { + data := []byte(`{"Id":"4d88e15b60f486e428412dc9"}`) + v := jsonType{} + err := json.Unmarshal(data, &v) + c.Assert(err, IsNil) + c.Assert(*v.Id, Equals, bson.ObjectIdHex("4d88e15b60f486e428412dc9")) +} + +func (s *S) TestObjectIdJSONUnmarshalingError(c *C) { + v := jsonType{} + err := json.Unmarshal([]byte(`{"Id":"4d88e15b60f486e428412dc9A"}`), &v) + c.Assert(err, ErrorMatches, `Invalid ObjectId in JSON: "4d88e15b60f486e428412dc9A"`) + err = json.Unmarshal([]byte(`{"Id":"4d88e15b60f486e428412dcZ"}`), &v) + c.Assert(err, ErrorMatches, `Invalid ObjectId in JSON: "4d88e15b60f486e428412dcZ" .*`) +} + +// -------------------------------------------------------------------------- +// Some simple benchmarks. + +type BenchT struct { + A, B, C, D, E, F string +} + +func BenchmarkUnmarhsalStruct(b *testing.B) { + v := BenchT{A: "A", D: "D", E: "E"} + data, err := bson.Marshal(&v) + if err != nil { + panic(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = bson.Unmarshal(data, &v) + } + if err != nil { + panic(err) + } +} + +func BenchmarkUnmarhsalMap(b *testing.B) { + m := bson.M{"a": "a", "d": "d", "e": "e"} + data, err := bson.Marshal(&m) + if err != nil { + panic(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = bson.Unmarshal(data, &m) + } + if err != nil { + panic(err) + } +} diff --git a/bson/decode.go b/bson/decode.go new file mode 100644 index 0000000..f1c8b4f --- /dev/null +++ b/bson/decode.go @@ -0,0 +1,806 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "sync" + "time" +) + +type decoder struct { + in []byte + i int + docType reflect.Type +} + +var typeM = reflect.TypeOf(M{}) + +func newDecoder(in []byte) *decoder { + return &decoder{in, 0, typeM} +} + +// -------------------------------------------------------------------------- +// Some helper functions. + +func corrupted() { + panic("Document is corrupted") +} + +func settableValueOf(i interface{}) reflect.Value { + v := reflect.ValueOf(i) + sv := reflect.New(v.Type()).Elem() + sv.Set(v) + return sv +} + +// -------------------------------------------------------------------------- +// Unmarshaling of documents. + +const ( + setterUnknown = iota + setterNone + setterType + setterAddr +) + +var setterStyle map[reflect.Type]int +var setterIface reflect.Type +var setterMutex sync.RWMutex + +func init() { + var iface Setter + setterIface = reflect.TypeOf(&iface).Elem() + setterStyle = make(map[reflect.Type]int) +} + +func getSetter(outt reflect.Type, out reflect.Value) Setter { + setterMutex.RLock() + style := setterStyle[outt] + setterMutex.RUnlock() + if style == setterNone { + return nil + } + if style == setterUnknown { + setterMutex.Lock() + defer setterMutex.Unlock() + if outt.Implements(setterIface) { + setterStyle[outt] = setterType + } else if reflect.PtrTo(outt).Implements(setterIface) { + setterStyle[outt] = setterAddr + } else { + setterStyle[outt] = setterNone + return nil + } + style = setterStyle[outt] + } + if style == setterAddr { + if !out.CanAddr() { + return nil + } + out = out.Addr() + } else if outt.Kind() == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + return out.Interface().(Setter) +} + +func clearMap(m reflect.Value) { + var none reflect.Value + for _, k := range m.MapKeys() { + m.SetMapIndex(k, none) + } +} + +func (d *decoder) readDocTo(out reflect.Value) { + var elemType reflect.Type + outt := out.Type() + outk := outt.Kind() + + for { + if outk == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + if setter := getSetter(outt, out); setter != nil { + var raw Raw + d.readDocTo(reflect.ValueOf(&raw)) + err := setter.SetBSON(raw) + if _, ok := err.(*TypeError); err != nil && !ok { + panic(err) + } + return + } + if outk == reflect.Ptr { + out = out.Elem() + outt = out.Type() + outk = out.Kind() + continue + } + break + } + + var fieldsMap map[string]fieldInfo + var inlineMap reflect.Value + start := d.i + + origout := out + if outk == reflect.Interface { + if d.docType.Kind() == reflect.Map { + mv := reflect.MakeMap(d.docType) + out.Set(mv) + out = mv + } else { + dv := reflect.New(d.docType).Elem() + out.Set(dv) + out = dv + } + outt = out.Type() + outk = outt.Kind() + } + + docType := d.docType + keyType := typeString + convertKey := false + switch outk { + case reflect.Map: + keyType = outt.Key() + if keyType.Kind() != reflect.String { + panic("BSON map must have string keys. Got: " + outt.String()) + } + if keyType != typeString { + convertKey = true + } + elemType = outt.Elem() + if elemType == typeIface { + d.docType = outt + } + if out.IsNil() { + out.Set(reflect.MakeMap(out.Type())) + } else if out.Len() > 0 { + clearMap(out) + } + case reflect.Struct: + if outt != typeRaw { + sinfo, err := getStructInfo(out.Type()) + if err != nil { + panic(err) + } + fieldsMap = sinfo.FieldsMap + out.Set(sinfo.Zero) + if sinfo.InlineMap != -1 { + inlineMap = out.Field(sinfo.InlineMap) + if !inlineMap.IsNil() && inlineMap.Len() > 0 { + clearMap(inlineMap) + } + elemType = inlineMap.Type().Elem() + if elemType == typeIface { + d.docType = inlineMap.Type() + } + } + } + case reflect.Slice: + switch outt.Elem() { + case typeDocElem: + origout.Set(d.readDocElems(outt)) + return + case typeRawDocElem: + origout.Set(d.readRawDocElems(outt)) + return + } + fallthrough + default: + panic("Unsupported document type for unmarshalling: " + out.Type().String()) + } + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + + switch outk { + case reflect.Map: + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + k := reflect.ValueOf(name) + if convertKey { + k = k.Convert(keyType) + } + out.SetMapIndex(k, e) + } + case reflect.Struct: + if outt == typeRaw { + d.dropElem(kind) + } else { + if info, ok := fieldsMap[name]; ok { + if info.Inline == nil { + d.readElemTo(out.Field(info.Num), kind) + } else { + d.readElemTo(out.FieldByIndex(info.Inline), kind) + } + } else if inlineMap.IsValid() { + if inlineMap.IsNil() { + inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + } + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + inlineMap.SetMapIndex(reflect.ValueOf(name), e) + } + } else { + d.dropElem(kind) + } + } + case reflect.Slice: + } + + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + d.docType = docType + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) + } +} + +func (d *decoder) readArrayDocTo(out reflect.Value) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + i := 0 + l := out.Len() + for d.in[d.i] != '\x00' { + if i >= l { + panic("Length mismatch on array field") + } + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + d.readElemTo(out.Index(i), kind) + if d.i >= end { + corrupted() + } + i++ + } + if i != l { + panic("Length mismatch on array field") + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +func (d *decoder) readSliceDoc(t reflect.Type) interface{} { + tmp := make([]reflect.Value, 0, 8) + elemType := t.Elem() + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + tmp = append(tmp, e) + } + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + + n := len(tmp) + slice := reflect.MakeSlice(t, n, n) + for i := 0; i != n; i++ { + slice.Index(i).Set(tmp[i]) + } + return slice.Interface() +} + +var typeSlice = reflect.TypeOf([]interface{}{}) +var typeIface = typeSlice.Elem() + +func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]DocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := DocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]RawDocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := RawDocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readDocWith(f func(kind byte, name string)) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + f(kind, name) + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +// -------------------------------------------------------------------------- +// Unmarshaling of individual elements within a document. + +var blackHole = settableValueOf(struct{}{}) + +func (d *decoder) dropElem(kind byte) { + d.readElemTo(blackHole, kind) +} + +// Attempt to decode an element from the document and put it into out. +// If the types are not compatible, the returned ok value will be +// false and out will be unchanged. +func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { + + start := d.i + + if kind == '\x03' { + // Special case for documents. Delegate to readDocTo(). + switch out.Kind() { + case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: + d.readDocTo(out) + default: + switch out.Interface().(type) { + case D: + out.Set(d.readDocElems(out.Type())) + case RawD: + out.Set(d.readRawDocElems(out.Type())) + default: + d.readDocTo(blackHole) + } + } + return true + } + + var in interface{} + + switch kind { + case 0x01: // Float64 + in = d.readFloat64() + case 0x02: // UTF-8 string + in = d.readStr() + case 0x03: // Document + panic("Can't happen. Handled above.") + case 0x04: // Array + outt := out.Type() + for outt.Kind() == reflect.Ptr { + outt = outt.Elem() + } + switch outt.Kind() { + case reflect.Array: + d.readArrayDocTo(out) + return true + case reflect.Slice: + in = d.readSliceDoc(outt) + default: + in = d.readSliceDoc(typeSlice) + } + case 0x05: // Binary + b := d.readBinary() + if b.Kind == 0x00 || b.Kind == 0x02 { + in = b.Data + } else { + in = b + } + case 0x06: // Undefined (obsolete, but still seen in the wild) + in = Undefined + case 0x07: // ObjectId + in = ObjectId(d.readBytes(12)) + case 0x08: // Bool + in = d.readBool() + case 0x09: // Timestamp + // MongoDB handles timestamps as milliseconds. + i := d.readInt64() + if i == -62135596800000 { + in = time.Time{} // In UTC for convenience. + } else { + in = time.Unix(i/1e3, i%1e3*1e6) + } + case 0x0A: // Nil + in = nil + case 0x0B: // RegEx + in = d.readRegEx() + case 0x0D: // JavaScript without scope + in = JavaScript{Code: d.readStr()} + case 0x0E: // Symbol + in = Symbol(d.readStr()) + case 0x0F: // JavaScript with scope + d.i += 4 // Skip length + js := JavaScript{d.readStr(), make(M)} + d.readDocTo(reflect.ValueOf(js.Scope)) + in = js + case 0x10: // Int32 + in = int(d.readInt32()) + case 0x11: // Mongo-specific timestamp + in = MongoTimestamp(d.readInt64()) + case 0x12: // Int64 + in = d.readInt64() + case 0x7F: // Max key + in = MaxKey + case 0xFF: // Min key + in = MinKey + default: + panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind)) + } + + outt := out.Type() + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{kind, d.in[start:d.i]})) + return true + } + + if setter := getSetter(outt, out); setter != nil { + err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) + if err == SetZero { + out.Set(reflect.Zero(outt)) + return true + } + if err == nil { + return true + } + if _, ok := err.(*TypeError); !ok { + panic(err) + } + return false + } + + if in == nil { + out.Set(reflect.Zero(outt)) + return true + } + + outk := outt.Kind() + + // Dereference and initialize pointer if necessary. + first := true + for outk == reflect.Ptr { + if !out.IsNil() { + out = out.Elem() + } else { + elem := reflect.New(outt.Elem()) + if first { + // Only set if value is compatible. + first = false + defer func(out, elem reflect.Value) { + if good { + out.Set(elem) + } + }(out, elem) + } else { + out.Set(elem) + } + out = elem + } + outt = out.Type() + outk = outt.Kind() + } + + inv := reflect.ValueOf(in) + if outt == inv.Type() { + out.Set(inv) + return true + } + + switch outk { + case reflect.Interface: + out.Set(inv) + return true + case reflect.String: + switch inv.Kind() { + case reflect.String: + out.SetString(inv.String()) + return true + case reflect.Slice: + if b, ok := in.([]byte); ok { + out.SetString(string(b)) + return true + } + case reflect.Int, reflect.Int64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatInt(inv.Int(), 10)) + return true + } + case reflect.Float64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatFloat(inv.Float(), 'f', -1, 64)) + return true + } + } + case reflect.Slice, reflect.Array: + // Remember, array (0x04) slices are built with the correct + // element type. If we are here, must be a cross BSON kind + // conversion (e.g. 0x05 unmarshalling on string). + if outt.Elem().Kind() != reflect.Uint8 { + break + } + switch inv.Kind() { + case reflect.String: + slice := []byte(inv.String()) + out.Set(reflect.ValueOf(slice)) + return true + case reflect.Slice: + switch outt.Kind() { + case reflect.Array: + reflect.Copy(out, inv) + case reflect.Slice: + out.SetBytes(inv.Bytes()) + } + return true + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetInt(inv.Int()) + return true + case reflect.Float32, reflect.Float64: + out.SetInt(int64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetInt(1) + } else { + out.SetInt(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("can't happen: no uint types in BSON (!?)") + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetUint(uint64(inv.Int())) + return true + case reflect.Float32, reflect.Float64: + out.SetUint(uint64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetUint(1) + } else { + out.SetUint(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON.") + } + case reflect.Float32, reflect.Float64: + switch inv.Kind() { + case reflect.Float32, reflect.Float64: + out.SetFloat(inv.Float()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetFloat(float64(inv.Int())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetFloat(1) + } else { + out.SetFloat(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Bool: + switch inv.Kind() { + case reflect.Bool: + out.SetBool(inv.Bool()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetBool(inv.Int() != 0) + return true + case reflect.Float32, reflect.Float64: + out.SetBool(inv.Float() != 0) + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Struct: + if outt == typeURL && inv.Kind() == reflect.String { + u, err := url.Parse(inv.String()) + if err != nil { + panic(err) + } + out.Set(reflect.ValueOf(u).Elem()) + return true + } + } + + return false +} + +// -------------------------------------------------------------------------- +// Parsers of basic types. + +func (d *decoder) readRegEx() RegEx { + re := RegEx{} + re.Pattern = d.readCStr() + re.Options = d.readCStr() + return re +} + +func (d *decoder) readBinary() Binary { + l := d.readInt32() + b := Binary{} + b.Kind = d.readByte() + b.Data = d.readBytes(l) + if b.Kind == 0x02 && len(b.Data) >= 4 { + // Weird obsolete format with redundant length. + b.Data = b.Data[4:] + } + return b +} + +func (d *decoder) readStr() string { + l := d.readInt32() + b := d.readBytes(l - 1) + if d.readByte() != '\x00' { + corrupted() + } + return string(b) +} + +func (d *decoder) readCStr() string { + start := d.i + end := start + l := len(d.in) + for ; end != l; end++ { + if d.in[end] == '\x00' { + break + } + } + d.i = end + 1 + if d.i > l { + corrupted() + } + return string(d.in[start:end]) +} + +func (d *decoder) readBool() bool { + if d.readByte() == 1 { + return true + } + return false +} + +func (d *decoder) readFloat64() float64 { + return math.Float64frombits(uint64(d.readInt64())) +} + +func (d *decoder) readInt32() int32 { + b := d.readBytes(4) + return int32((uint32(b[0]) << 0) | + (uint32(b[1]) << 8) | + (uint32(b[2]) << 16) | + (uint32(b[3]) << 24)) +} + +func (d *decoder) readInt64() int64 { + b := d.readBytes(8) + return int64((uint64(b[0]) << 0) | + (uint64(b[1]) << 8) | + (uint64(b[2]) << 16) | + (uint64(b[3]) << 24) | + (uint64(b[4]) << 32) | + (uint64(b[5]) << 40) | + (uint64(b[6]) << 48) | + (uint64(b[7]) << 56)) +} + +func (d *decoder) readByte() byte { + i := d.i + d.i++ + if d.i > len(d.in) { + corrupted() + } + return d.in[i] +} + +func (d *decoder) readBytes(length int32) []byte { + start := d.i + d.i += int(length) + if d.i > len(d.in) { + corrupted() + } + return d.in[start : start+int(length)] +} diff --git a/bson/encode.go b/bson/encode.go new file mode 100644 index 0000000..6544748 --- /dev/null +++ b/bson/encode.go @@ -0,0 +1,479 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "encoding/json" + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "time" +) + +// -------------------------------------------------------------------------- +// Some internal infrastructure. + +var ( + typeBinary = reflect.TypeOf(Binary{}) + typeObjectId = reflect.TypeOf(ObjectId("")) + typeSymbol = reflect.TypeOf(Symbol("")) + typeMongoTimestamp = reflect.TypeOf(MongoTimestamp(0)) + typeOrderKey = reflect.TypeOf(MinKey) + typeDocElem = reflect.TypeOf(DocElem{}) + typeRawDocElem = reflect.TypeOf(RawDocElem{}) + typeRaw = reflect.TypeOf(Raw{}) + typeURL = reflect.TypeOf(url.URL{}) + typeTime = reflect.TypeOf(time.Time{}) + typeString = reflect.TypeOf("") + typeJSONNumber = reflect.TypeOf(json.Number("")) +) + +const itoaCacheSize = 32 + +var itoaCache []string + +func init() { + itoaCache = make([]string, itoaCacheSize) + for i := 0; i != itoaCacheSize; i++ { + itoaCache[i] = strconv.Itoa(i) + } +} + +func itoa(i int) string { + if i < itoaCacheSize { + return itoaCache[i] + } + return strconv.Itoa(i) +} + +// -------------------------------------------------------------------------- +// Marshaling of the document value itself. + +type encoder struct { + out []byte +} + +func (e *encoder) addDoc(v reflect.Value) { + for { + if vi, ok := v.Interface().(Getter); ok { + getv, err := vi.GetBSON() + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + continue + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + continue + } + break + } + + if v.Type() == typeRaw { + raw := v.Interface().(Raw) + if raw.Kind != 0x03 && raw.Kind != 0x00 { + panic("Attempted to unmarshal Raw kind " + strconv.Itoa(int(raw.Kind)) + " as a document") + } + e.addBytes(raw.Data...) + return + } + + start := e.reserveInt32() + + switch v.Kind() { + case reflect.Map: + e.addMap(v) + case reflect.Struct: + e.addStruct(v) + case reflect.Array, reflect.Slice: + e.addSlice(v) + default: + panic("Can't marshal " + v.Type().String() + " as a BSON document") + } + + e.addBytes(0) + e.setInt32(start, int32(len(e.out)-start)) +} + +func (e *encoder) addMap(v reflect.Value) { + for _, k := range v.MapKeys() { + e.addElem(k.String(), v.MapIndex(k), false) + } +} + +func (e *encoder) addStruct(v reflect.Value) { + sinfo, err := getStructInfo(v.Type()) + if err != nil { + panic(err) + } + var value reflect.Value + if sinfo.InlineMap >= 0 { + m := v.Field(sinfo.InlineMap) + if m.Len() > 0 { + for _, k := range m.MapKeys() { + ks := k.String() + if _, found := sinfo.FieldsMap[ks]; found { + panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks)) + } + e.addElem(ks, m.MapIndex(k), false) + } + } + } + for _, info := range sinfo.FieldsList { + if info.Inline == nil { + value = v.Field(info.Num) + } else { + value = v.FieldByIndex(info.Inline) + } + if info.OmitEmpty && isZero(value) { + continue + } + e.addElem(info.Key, value, info.MinSize) + } +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return len(v.String()) == 0 + case reflect.Ptr, reflect.Interface: + return v.IsNil() + case reflect.Slice: + return v.Len() == 0 + case reflect.Map: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Struct: + if v.Type() == typeTime { + return v.Interface().(time.Time).IsZero() + } + for i := v.NumField()-1; i >= 0; i-- { + if !isZero(v.Field(i)) { + return false + } + } + return true + } + return false +} + +func (e *encoder) addSlice(v reflect.Value) { + vi := v.Interface() + if d, ok := vi.(D); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if d, ok := vi.(RawD); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + l := v.Len() + et := v.Type().Elem() + if et == typeDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(DocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if et == typeRawDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(RawDocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + for i := 0; i < l; i++ { + e.addElem(itoa(i), v.Index(i), false) + } +} + +// -------------------------------------------------------------------------- +// Marshaling of elements in a document. + +func (e *encoder) addElemName(kind byte, name string) { + e.addBytes(kind) + e.addBytes([]byte(name)...) + e.addBytes(0) +} + +func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { + + if !v.IsValid() { + e.addElemName('\x0A', name) + return + } + + if getter, ok := v.Interface().(Getter); ok { + getv, err := getter.GetBSON() + if err != nil { + panic(err) + } + e.addElem(name, reflect.ValueOf(getv), minSize) + return + } + + switch v.Kind() { + + case reflect.Interface: + e.addElem(name, v.Elem(), minSize) + + case reflect.Ptr: + e.addElem(name, v.Elem(), minSize) + + case reflect.String: + s := v.String() + switch v.Type() { + case typeObjectId: + if len(s) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s)) + ")") + } + e.addElemName('\x07', name) + e.addBytes([]byte(s)...) + case typeSymbol: + e.addElemName('\x0E', name) + e.addStr(s) + case typeJSONNumber: + n := v.Interface().(json.Number) + if i, err := n.Int64(); err == nil { + e.addElemName('\x12', name) + e.addInt64(i) + } else if f, err := n.Float64(); err == nil { + e.addElemName('\x01', name) + e.addFloat64(f) + } else { + panic("failed to convert json.Number to a number: " + s) + } + default: + e.addElemName('\x02', name) + e.addStr(s) + } + + case reflect.Float32, reflect.Float64: + e.addElemName('\x01', name) + e.addFloat64(v.Float()) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + u := v.Uint() + if int64(u) < 0 { + panic("BSON has no uint64 type, and value is too large to fit correctly in an int64") + } else if u <= math.MaxInt32 && (minSize || v.Kind() <= reflect.Uint32) { + e.addElemName('\x10', name) + e.addInt32(int32(u)) + } else { + e.addElemName('\x12', name) + e.addInt64(int64(u)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Type() { + case typeMongoTimestamp: + e.addElemName('\x11', name) + e.addInt64(v.Int()) + + case typeOrderKey: + if v.Int() == int64(MaxKey) { + e.addElemName('\x7F', name) + } else { + e.addElemName('\xFF', name) + } + + default: + i := v.Int() + if (minSize || v.Type().Kind() != reflect.Int64) && i >= math.MinInt32 && i <= math.MaxInt32 { + // It fits into an int32, encode as such. + e.addElemName('\x10', name) + e.addInt32(int32(i)) + } else { + e.addElemName('\x12', name) + e.addInt64(i) + } + } + + case reflect.Bool: + e.addElemName('\x08', name) + if v.Bool() { + e.addBytes(1) + } else { + e.addBytes(0) + } + + case reflect.Map: + e.addElemName('\x03', name) + e.addDoc(v) + + case reflect.Slice: + vt := v.Type() + et := vt.Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Bytes()) + } else if et == typeDocElem || et == typeRawDocElem { + e.addElemName('\x03', name) + e.addDoc(v) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Array: + et := v.Type().Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Slice(0, v.Len()).Interface().([]byte)) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Struct: + switch s := v.Interface().(type) { + + case Raw: + kind := s.Kind + if kind == 0x00 { + kind = 0x03 + } + e.addElemName(kind, name) + e.addBytes(s.Data...) + + case Binary: + e.addElemName('\x05', name) + e.addBinary(s.Kind, s.Data) + + case RegEx: + e.addElemName('\x0B', name) + e.addCStr(s.Pattern) + e.addCStr(s.Options) + + case JavaScript: + if s.Scope == nil { + e.addElemName('\x0D', name) + e.addStr(s.Code) + } else { + e.addElemName('\x0F', name) + start := e.reserveInt32() + e.addStr(s.Code) + e.addDoc(reflect.ValueOf(s.Scope)) + e.setInt32(start, int32(len(e.out)-start)) + } + + case time.Time: + // MongoDB handles timestamps as milliseconds. + e.addElemName('\x09', name) + e.addInt64(s.Unix() * 1000 + int64(s.Nanosecond() / 1e6)) + + case url.URL: + e.addElemName('\x02', name) + e.addStr(s.String()) + + case undefined: + e.addElemName('\x06', name) + + default: + e.addElemName('\x03', name) + e.addDoc(v) + } + + default: + panic("Can't marshal " + v.Type().String() + " in a BSON document") + } +} + +// -------------------------------------------------------------------------- +// Marshaling of base types. + +func (e *encoder) addBinary(subtype byte, v []byte) { + if subtype == 0x02 { + // Wonder how that brilliant idea came to life. Obsolete, luckily. + e.addInt32(int32(len(v) + 4)) + e.addBytes(subtype) + e.addInt32(int32(len(v))) + } else { + e.addInt32(int32(len(v))) + e.addBytes(subtype) + } + e.addBytes(v...) +} + +func (e *encoder) addStr(v string) { + e.addInt32(int32(len(v) + 1)) + e.addCStr(v) +} + +func (e *encoder) addCStr(v string) { + e.addBytes([]byte(v)...) + e.addBytes(0) +} + +func (e *encoder) reserveInt32() (pos int) { + pos = len(e.out) + e.addBytes(0, 0, 0, 0) + return pos +} + +func (e *encoder) setInt32(pos int, v int32) { + e.out[pos+0] = byte(v) + e.out[pos+1] = byte(v >> 8) + e.out[pos+2] = byte(v >> 16) + e.out[pos+3] = byte(v >> 24) +} + +func (e *encoder) addInt32(v int32) { + u := uint32(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24)) +} + +func (e *encoder) addInt64(v int64) { + u := uint64(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24), + byte(u>>32), byte(u>>40), byte(u>>48), byte(u>>56)) +} + +func (e *encoder) addFloat64(v float64) { + e.addInt64(int64(math.Float64bits(v))) +} + +func (e *encoder) addBytes(v ...byte) { + e.out = append(e.out, v...) +} diff --git a/log/doc.go b/log/doc.go new file mode 100644 index 0000000..a06ec37 --- /dev/null +++ b/log/doc.go @@ -0,0 +1,20 @@ +// log package supplies more advanced features than go orign log package. +// +// It supports log different level: trace, debug, info, warn, error, fatal. +// +// It also supports different log handlers which you can log to stdout, file, socket, etc... +// +// Use +// +// import "github.com/siddontang/go-log/log" +// +// //log with different level +// log.Info("hello world") +// log.Error("hello world") +// +// //create a logger with specified handler +// h := NewStreamHandler(os.Stdout) +// l := log.NewDefault(h) +// l.Info("hello world") +// +package log diff --git a/log/filehandler.go b/log/filehandler.go new file mode 100644 index 0000000..783b652 --- /dev/null +++ b/log/filehandler.go @@ -0,0 +1,200 @@ +package log + +import ( + "fmt" + "os" + "path" + "time" +) + +//FileHandler writes log to a file. +type FileHandler struct { + fd *os.File +} + +func NewFileHandler(fileName string, flag int) (*FileHandler, error) { + dir := path.Dir(fileName) + os.Mkdir(dir, 0777) + + f, err := os.OpenFile(fileName, flag, 0) + if err != nil { + return nil, err + } + + h := new(FileHandler) + + h.fd = f + + return h, nil +} + +func (h *FileHandler) Write(b []byte) (n int, err error) { + return h.fd.Write(b) +} + +func (h *FileHandler) Close() error { + return h.fd.Close() +} + +//RotatingFileHandler writes log a file, if file size exceeds maxBytes, +//it will backup current file and open a new one. +// +//max backup file number is set by backupCount, it will delete oldest if backups too many. +type RotatingFileHandler struct { + fd *os.File + + fileName string + maxBytes int + backupCount int +} + +func NewRotatingFileHandler(fileName string, maxBytes int, backupCount int) (*RotatingFileHandler, error) { + dir := path.Dir(fileName) + os.Mkdir(dir, 0777) + + h := new(RotatingFileHandler) + + if maxBytes <= 0 { + return nil, fmt.Errorf("invalid max bytes") + } + + h.fileName = fileName + h.maxBytes = maxBytes + h.backupCount = backupCount + + var err error + h.fd, err = os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + return nil, err + } + + return h, nil +} + +func (h *RotatingFileHandler) Write(p []byte) (n int, err error) { + h.doRollover() + return h.fd.Write(p) +} + +func (h *RotatingFileHandler) Close() error { + if h.fd != nil { + return h.fd.Close() + } + return nil +} + +func (h *RotatingFileHandler) doRollover() { + f, err := h.fd.Stat() + if err != nil { + return + } + + if h.maxBytes <= 0 { + return + } else if f.Size() < int64(h.maxBytes) { + return + } + + if h.backupCount > 0 { + h.fd.Close() + + for i := h.backupCount - 1; i > 0; i-- { + sfn := fmt.Sprintf("%s.%d", h.fileName, i) + dfn := fmt.Sprintf("%s.%d", h.fileName, i+1) + + os.Rename(sfn, dfn) + } + + dfn := fmt.Sprintf("%s.1", h.fileName) + os.Rename(h.fileName, dfn) + + h.fd, _ = os.OpenFile(h.fileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + } +} + +//TimeRotatingFileHandler writes log to a file, +//it will backup current and open a new one, with a period time you sepecified. +// +//refer: http://docs.python.org/2/library/logging.handlers.html. +//same like python TimedRotatingFileHandler. +type TimeRotatingFileHandler struct { + fd *os.File + + baseName string + interval int64 + suffix string + rolloverAt int64 +} + +const ( + WhenSecond = iota + WhenMinute + WhenHour + WhenDay +) + +func NewTimeRotatingFileHandler(baseName string, when int8, interval int) (*TimeRotatingFileHandler, error) { + dir := path.Dir(baseName) + os.Mkdir(dir, 0777) + + h := new(TimeRotatingFileHandler) + + h.baseName = baseName + + switch when { + case WhenSecond: + h.interval = 1 + h.suffix = "2006-01-02_15-04-05" + case WhenMinute: + h.interval = 60 + h.suffix = "2006-01-02_15-04" + case WhenHour: + h.interval = 3600 + h.suffix = "2006-01-02_15" + case WhenDay: + h.interval = 3600 * 24 + h.suffix = "2006-01-02" + default: + return nil, fmt.Errorf("invalid when_rotate: %d", when) + } + + h.interval = h.interval * int64(interval) + + var err error + h.fd, err = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + return nil, err + } + + fInfo, _ := h.fd.Stat() + h.rolloverAt = fInfo.ModTime().Unix() + h.interval + + return h, nil +} + +func (h *TimeRotatingFileHandler) doRollover() { + //refer http://hg.python.org/cpython/file/2.7/Lib/logging/handlers.py + now := time.Now() + + if h.rolloverAt <= now.Unix() { + fName := h.baseName + now.Format(h.suffix) + h.fd.Close() + e := os.Rename(h.baseName, fName) + if e != nil { + panic(e) + } + + h.fd, _ = os.OpenFile(h.baseName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + + h.rolloverAt = time.Now().Unix() + h.interval + } +} + +func (h *TimeRotatingFileHandler) Write(b []byte) (n int, err error) { + h.doRollover() + return h.fd.Write(b) +} + +func (h *TimeRotatingFileHandler) Close() error { + return h.fd.Close() +} diff --git a/log/handler.go b/log/handler.go new file mode 100644 index 0000000..352e30c --- /dev/null +++ b/log/handler.go @@ -0,0 +1,49 @@ +package log + +import ( + "io" +) + +//Handler writes logs to somewhere +type Handler interface { + Write(p []byte) (n int, err error) + Close() error +} + +//StreamHandler writes logs to a specified io Writer, maybe stdout, stderr, etc... +type StreamHandler struct { + w io.Writer +} + +func NewStreamHandler(w io.Writer) (*StreamHandler, error) { + h := new(StreamHandler) + + h.w = w + + return h, nil +} + +func (h *StreamHandler) Write(b []byte) (n int, err error) { + return h.w.Write(b) +} + +func (h *StreamHandler) Close() error { + return nil +} + + +//NullHandler does nothing, it discards anything. +type NullHandler struct { +} + +func NewNullHandler() (*NullHandler, error) { + return new(NullHandler), nil +} + +func (h *NullHandler) Write(b []byte) (n int, err error) { + return len(b), nil +} + +func (h *NullHandler) Close() { + +} diff --git a/log/log.go b/log/log.go new file mode 100644 index 0000000..2e6619b --- /dev/null +++ b/log/log.go @@ -0,0 +1,239 @@ +package log + +import ( + "fmt" + "os" + "runtime" + "strconv" + "sync" + "time" +) + +//log level, from low to high, more high means more serious +const ( + LevelTrace = iota + LevelDebug + LevelInfo + LevelWarn + LevelError + LevelFatal +) + +const ( + Ltime = 1 << iota //time format "2006/01/02 15:04:05" + Lfile //file.go:123 + Llevel //[Trace|Debug|Info...] +) + +var LevelName [6]string = [6]string{"Trace", "Debug", "Info", "Warn", "Error", "Fatal"} + +const TimeFormat = "2006/01/02 15:04:05" + +const maxBufPoolSize = 16 + +type Logger struct { + sync.Mutex + + level int + flag int + + handler Handler + + quit chan struct{} + msg chan []byte + + bufs [][]byte +} + +//new a logger with specified handler and flag +func New(handler Handler, flag int) *Logger { + var l = new(Logger) + + l.level = LevelInfo + l.handler = handler + + l.flag = flag + + l.quit = make(chan struct{}) + + l.msg = make(chan []byte, 1024) + + l.bufs = make([][]byte, 0, 16) + + go l.run() + + return l +} + +//new a default logger with specified handler and flag: Ltime|Lfile|Llevel +func NewDefault(handler Handler) *Logger { + return New(handler, Ltime|Lfile|Llevel) +} + +func newStdHandler() *StreamHandler { + h, _ := NewStreamHandler(os.Stdout) + return h +} + +var std = NewDefault(newStdHandler()) + +func (l *Logger) run() { + for { + select { + case msg := <-l.msg: + l.handler.Write(msg) + l.putBuf(msg) + case <-l.quit: + l.handler.Close() + } + } +} + +func (l *Logger) popBuf() []byte { + l.Lock() + var buf []byte + if len(l.bufs) == 0 { + buf = make([]byte, 0, 1024) + } else { + buf = l.bufs[len(l.bufs)-1] + l.bufs = l.bufs[0 : len(l.bufs)-1] + } + l.Unlock() + + return buf +} + +func (l *Logger) putBuf(buf []byte) { + l.Lock() + if len(l.bufs) < maxBufPoolSize { + buf = buf[0:0] + l.bufs = append(l.bufs, buf) + } + l.Unlock() +} + +func (l *Logger) Close() { + if l.quit == nil { + return + } + + close(l.quit) + l.quit = nil +} + +//set log level, any log level less than it will not log +func (l *Logger) SetLevel(level int) { + l.level = level +} + +//a low interface, maybe you can use it for your special log format +//but it may be not exported later...... +func (l *Logger) Output(callDepth int, level int, format string, v ...interface{}) { + if l.level > level { + return + } + + buf := l.popBuf() + + if l.flag&Ltime > 0 { + now := time.Now().Format(TimeFormat) + buf = append(buf, '[') + buf = append(buf, now...) + buf = append(buf, "] "...) + } + + if l.flag&Lfile > 0 { + _, file, line, ok := runtime.Caller(callDepth) + if !ok { + file = "???" + line = 0 + } else { + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + file = file[i+1:] + break + } + } + } + + buf = append(buf, file...) + buf = append(buf, ':') + + buf = strconv.AppendInt(buf, int64(line), 10) + buf = append(buf, ' ') + } + + if l.flag&Llevel > 0 { + buf = append(buf, '[') + buf = append(buf, LevelName[level]...) + buf = append(buf, "] "...) + } + + s := fmt.Sprintf(format, v...) + + buf = append(buf, s...) + + if s[len(s)-1] != '\n' { + buf = append(buf, '\n') + } + + l.msg <- buf +} + +//log with Trace level +func (l *Logger) Trace(format string, v ...interface{}) { + l.Output(2, LevelTrace, format, v...) +} + +//log with Debug level +func (l *Logger) Debug(format string, v ...interface{}) { + l.Output(2, LevelDebug, format, v...) +} + +//log with info level +func (l *Logger) Info(format string, v ...interface{}) { + l.Output(2, LevelInfo, format, v...) +} + +//log with warn level +func (l *Logger) Warn(format string, v ...interface{}) { + l.Output(2, LevelWarn, format, v...) +} + +//log with error level +func (l *Logger) Error(format string, v ...interface{}) { + l.Output(2, LevelError, format, v...) +} + +//log with fatal level +func (l *Logger) Fatal(format string, v ...interface{}) { + l.Output(2, LevelFatal, format, v...) +} + +func SetLevel(level int) { + std.SetLevel(level) +} + +func Trace(format string, v ...interface{}) { + std.Output(2, LevelTrace, format, v...) +} + +func Debug(format string, v ...interface{}) { + std.Output(2, LevelDebug, format, v...) +} + +func Info(format string, v ...interface{}) { + std.Output(2, LevelInfo, format, v...) +} + +func Warn(format string, v ...interface{}) { + std.Output(2, LevelWarn, format, v...) +} + +func Error(format string, v ...interface{}) { + std.Output(2, LevelError, format, v...) +} + +func Fatal(format string, v ...interface{}) { + std.Output(2, LevelFatal, format, v...) +} diff --git a/log/log_test.go b/log/log_test.go new file mode 100644 index 0000000..67d3b0b --- /dev/null +++ b/log/log_test.go @@ -0,0 +1,52 @@ +package log + +import ( + "os" + "testing" +) + +func TestStdStreamLog(t *testing.T) { + h, _ := NewStreamHandler(os.Stdout) + s := NewDefault(h) + s.Info("hello world") + + s.Close() + + Info("hello world") +} + +func TestRotatingFileLog(t *testing.T) { + path := "./test_log" + os.RemoveAll(path) + + os.Mkdir(path, 0777) + fileName := path + "/test" + + h, err := NewRotatingFileHandler(fileName, 10, 2) + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, 10) + + h.Write(buf) + + h.Write(buf) + + if _, err := os.Stat(fileName + ".1"); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(fileName + ".2"); err == nil { + t.Fatal(err) + } + + h.Write(buf) + if _, err := os.Stat(fileName + ".2"); err != nil { + t.Fatal(err) + } + + h.Close() + + os.RemoveAll(path) +} diff --git a/log/sockethandler.go b/log/sockethandler.go new file mode 100644 index 0000000..ad81ccd --- /dev/null +++ b/log/sockethandler.go @@ -0,0 +1,65 @@ +package log + +import ( + "encoding/binary" + "net" + "time" +) + +//SocketHandler writes log to a connectionl. +//Network protocol is simple: log length + log | log length + log. log length is uint32, bigendian. +//you must implement your own log server, maybe you can use logd instead simply. +type SocketHandler struct { + c net.Conn + protocol string + addr string +} + +func NewSocketHandler(protocol string, addr string) (*SocketHandler, error) { + s := new(SocketHandler) + + s.protocol = protocol + s.addr = addr + + return s, nil +} + +func (h *SocketHandler) Write(p []byte) (n int, err error) { + if err = h.connect(); err != nil { + return + } + + buf := make([]byte, len(p)+4) + + binary.BigEndian.PutUint32(buf, uint32(len(p))) + + copy(buf[4:], p) + + n, err = h.c.Write(buf) + if err != nil { + h.c.Close() + h.c = nil + } + return +} + +func (h *SocketHandler) Close() error { + if h.c != nil { + h.c.Close() + } + return nil +} + +func (h *SocketHandler) connect() error { + if h.c != nil { + return nil + } + + var err error + h.c, err = net.DialTimeout(h.protocol, h.addr, 20*time.Second) + if err != nil { + return err + } + + return nil +} diff --git a/snappy/LICENSE b/snappy/LICENSE new file mode 100644 index 0000000..6050c10 --- /dev/null +++ b/snappy/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2011 The Snappy-Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/snappy/decode.go b/snappy/decode.go new file mode 100644 index 0000000..d93c1b9 --- /dev/null +++ b/snappy/decode.go @@ -0,0 +1,124 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "encoding/binary" + "errors" +) + +// ErrCorrupt reports that the input is invalid. +var ErrCorrupt = errors.New("snappy: corrupt input") + +// DecodedLen returns the length of the decoded block. +func DecodedLen(src []byte) (int, error) { + v, _, err := decodedLen(src) + return v, err +} + +// decodedLen returns the length of the decoded block and the number of bytes +// that the length header occupied. +func decodedLen(src []byte) (blockLen, headerLen int, err error) { + v, n := binary.Uvarint(src) + if n == 0 { + return 0, 0, ErrCorrupt + } + if uint64(int(v)) != v { + return 0, 0, errors.New("snappy: decoded block is too large") + } + return int(v), n, nil +} + +// Decode returns the decoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire decoded block. +// Otherwise, a newly allocated slice will be returned. +// It is valid to pass a nil dst. +func Decode(dst, src []byte) ([]byte, error) { + dLen, s, err := decodedLen(src) + if err != nil { + return nil, err + } + if len(dst) < dLen { + dst = make([]byte, dLen) + } + + var d, offset, length int + for s < len(src) { + switch src[s] & 0x03 { + case tagLiteral: + x := uint(src[s] >> 2) + switch { + case x < 60: + s += 1 + case x == 60: + s += 2 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-1]) + case x == 61: + s += 3 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-2]) | uint(src[s-1])<<8 + case x == 62: + s += 4 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-3]) | uint(src[s-2])<<8 | uint(src[s-1])<<16 + case x == 63: + s += 5 + if s > len(src) { + return nil, ErrCorrupt + } + x = uint(src[s-4]) | uint(src[s-3])<<8 | uint(src[s-2])<<16 | uint(src[s-1])<<24 + } + length = int(x + 1) + if length <= 0 { + return nil, errors.New("snappy: unsupported literal length") + } + if length > len(dst)-d || length > len(src)-s { + return nil, ErrCorrupt + } + copy(dst[d:], src[s:s+length]) + d += length + s += length + continue + + case tagCopy1: + s += 2 + if s > len(src) { + return nil, ErrCorrupt + } + length = 4 + int(src[s-2])>>2&0x7 + offset = int(src[s-2])&0xe0<<3 | int(src[s-1]) + + case tagCopy2: + s += 3 + if s > len(src) { + return nil, ErrCorrupt + } + length = 1 + int(src[s-3])>>2 + offset = int(src[s-2]) | int(src[s-1])<<8 + + case tagCopy4: + return nil, errors.New("snappy: unsupported COPY_4 tag") + } + + end := d + length + if offset > d || end > len(dst) { + return nil, ErrCorrupt + } + for ; d < end; d++ { + dst[d] = dst[d-offset] + } + } + if d != dLen { + return nil, ErrCorrupt + } + return dst[:d], nil +} diff --git a/snappy/encode.go b/snappy/encode.go new file mode 100644 index 0000000..b2371db --- /dev/null +++ b/snappy/encode.go @@ -0,0 +1,174 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "encoding/binary" +) + +// We limit how far copy back-references can go, the same as the C++ code. +const maxOffset = 1 << 15 + +// emitLiteral writes a literal chunk and returns the number of bytes written. +func emitLiteral(dst, lit []byte) int { + i, n := 0, uint(len(lit)-1) + switch { + case n < 60: + dst[0] = uint8(n)<<2 | tagLiteral + i = 1 + case n < 1<<8: + dst[0] = 60<<2 | tagLiteral + dst[1] = uint8(n) + i = 2 + case n < 1<<16: + dst[0] = 61<<2 | tagLiteral + dst[1] = uint8(n) + dst[2] = uint8(n >> 8) + i = 3 + case n < 1<<24: + dst[0] = 62<<2 | tagLiteral + dst[1] = uint8(n) + dst[2] = uint8(n >> 8) + dst[3] = uint8(n >> 16) + i = 4 + case int64(n) < 1<<32: + dst[0] = 63<<2 | tagLiteral + dst[1] = uint8(n) + dst[2] = uint8(n >> 8) + dst[3] = uint8(n >> 16) + dst[4] = uint8(n >> 24) + i = 5 + default: + panic("snappy: source buffer is too long") + } + if copy(dst[i:], lit) != len(lit) { + panic("snappy: destination buffer is too short") + } + return i + len(lit) +} + +// emitCopy writes a copy chunk and returns the number of bytes written. +func emitCopy(dst []byte, offset, length int) int { + i := 0 + for length > 0 { + x := length - 4 + if 0 <= x && x < 1<<3 && offset < 1<<11 { + dst[i+0] = uint8(offset>>8)&0x07<<5 | uint8(x)<<2 | tagCopy1 + dst[i+1] = uint8(offset) + i += 2 + break + } + + x = length + if x > 1<<6 { + x = 1 << 6 + } + dst[i+0] = uint8(x-1)<<2 | tagCopy2 + dst[i+1] = uint8(offset) + dst[i+2] = uint8(offset >> 8) + i += 3 + length -= x + } + return i +} + +// Encode returns the encoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire encoded block. +// Otherwise, a newly allocated slice will be returned. +// It is valid to pass a nil dst. +func Encode(dst, src []byte) ([]byte, error) { + if n := MaxEncodedLen(len(src)); len(dst) < n { + dst = make([]byte, n) + } + + // The block starts with the varint-encoded length of the decompressed bytes. + d := binary.PutUvarint(dst, uint64(len(src))) + + // Return early if src is short. + if len(src) <= 4 { + if len(src) != 0 { + d += emitLiteral(dst[d:], src) + } + return dst[:d], nil + } + + // Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive. + const maxTableSize = 1 << 14 + shift, tableSize := uint(32-8), 1<<8 + for tableSize < maxTableSize && tableSize < len(src) { + shift-- + tableSize *= 2 + } + var table [maxTableSize]int + + // Iterate over the source bytes. + var ( + s int // The iterator position. + t int // The last position with the same hash as s. + lit int // The start position of any pending literal bytes. + ) + for s+3 < len(src) { + // Update the hash table. + b0, b1, b2, b3 := src[s], src[s+1], src[s+2], src[s+3] + h := uint32(b0) | uint32(b1)<<8 | uint32(b2)<<16 | uint32(b3)<<24 + p := &table[(h*0x1e35a7bd)>>shift] + // We need to to store values in [-1, inf) in table. To save + // some initialization time, (re)use the table's zero value + // and shift the values against this zero: add 1 on writes, + // subtract 1 on reads. + t, *p = *p-1, s+1 + // If t is invalid or src[s:s+4] differs from src[t:t+4], accumulate a literal byte. + if t < 0 || s-t >= maxOffset || b0 != src[t] || b1 != src[t+1] || b2 != src[t+2] || b3 != src[t+3] { + s++ + continue + } + // Otherwise, we have a match. First, emit any pending literal bytes. + if lit != s { + d += emitLiteral(dst[d:], src[lit:s]) + } + // Extend the match to be as long as possible. + s0 := s + s, t = s+4, t+4 + for s < len(src) && src[s] == src[t] { + s++ + t++ + } + // Emit the copied bytes. + d += emitCopy(dst[d:], s-t, s-s0) + lit = s + } + + // Emit any final pending literal bytes and return. + if lit != len(src) { + d += emitLiteral(dst[d:], src[lit:]) + } + return dst[:d], nil +} + +// MaxEncodedLen returns the maximum length of a snappy block, given its +// uncompressed length. +func MaxEncodedLen(srcLen int) int { + // Compressed data can be defined as: + // compressed := item* literal* + // item := literal* copy + // + // The trailing literal sequence has a space blowup of at most 62/60 + // since a literal of length 60 needs one tag byte + one extra byte + // for length information. + // + // Item blowup is trickier to measure. Suppose the "copy" op copies + // 4 bytes of data. Because of a special check in the encoding code, + // we produce a 4-byte copy only if the offset is < 65536. Therefore + // the copy op takes 3 bytes to encode, and this type of item leads + // to at most the 62/60 blowup for representing literals. + // + // Suppose the "copy" op copies 5 bytes of data. If the offset is big + // enough, it will take 5 bytes to encode the copy op. Therefore the + // worst case here is a one-byte literal followed by a five-byte copy. + // That is, 6 bytes of input turn into 7 bytes of "compressed" data. + // + // This last factor dominates the blowup, so the final estimate is: + return 32 + srcLen + srcLen/6 +} diff --git a/snappy/snappy.go b/snappy/snappy.go new file mode 100644 index 0000000..2f1b790 --- /dev/null +++ b/snappy/snappy.go @@ -0,0 +1,38 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package snappy implements the snappy block-based compression format. +// It aims for very high speeds and reasonable compression. +// +// The C++ snappy implementation is at http://code.google.com/p/snappy/ +package snappy + +/* +Each encoded block begins with the varint-encoded length of the decoded data, +followed by a sequence of chunks. Chunks begin and end on byte boundaries. The +first byte of each chunk is broken into its 2 least and 6 most significant bits +called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag. +Zero means a literal tag. All other values mean a copy tag. + +For literal tags: + - If m < 60, the next 1 + m bytes are literal bytes. + - Otherwise, let n be the little-endian unsigned integer denoted by the next + m - 59 bytes. The next 1 + n bytes after that are literal bytes. + +For copy tags, length bytes are copied from offset bytes ago, in the style of +Lempel-Ziv compression algorithms. In particular: + - For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12). + The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10 + of the offset. The next byte is bits 0-7 of the offset. + - For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65). + The length is 1 + m. The offset is the little-endian unsigned integer + denoted by the next 2 bytes. + - For l == 3, this tag is a legacy format that is no longer supported. +*/ +const ( + tagLiteral = 0x00 + tagCopy1 = 0x01 + tagCopy2 = 0x02 + tagCopy4 = 0x03 +) diff --git a/snappy/snappy_test.go b/snappy/snappy_test.go new file mode 100644 index 0000000..7ba8392 --- /dev/null +++ b/snappy/snappy_test.go @@ -0,0 +1,261 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package snappy + +import ( + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net/http" + "os" + "path/filepath" + "strings" + "testing" +) + +var download = flag.Bool("download", false, "If true, download any missing files before running benchmarks") + +func roundtrip(b, ebuf, dbuf []byte) error { + e, err := Encode(ebuf, b) + if err != nil { + return fmt.Errorf("encoding error: %v", err) + } + d, err := Decode(dbuf, e) + if err != nil { + return fmt.Errorf("decoding error: %v", err) + } + if !bytes.Equal(b, d) { + return fmt.Errorf("roundtrip mismatch:\n\twant %v\n\tgot %v", b, d) + } + return nil +} + +func TestEmpty(t *testing.T) { + if err := roundtrip(nil, nil, nil); err != nil { + t.Fatal(err) + } +} + +func TestSmallCopy(t *testing.T) { + for _, ebuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} { + for _, dbuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} { + for i := 0; i < 32; i++ { + s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb" + if err := roundtrip([]byte(s), ebuf, dbuf); err != nil { + t.Errorf("len(ebuf)=%d, len(dbuf)=%d, i=%d: %v", len(ebuf), len(dbuf), i, err) + } + } + } + } +} + +func TestSmallRand(t *testing.T) { + rand.Seed(27354294) + for n := 1; n < 20000; n += 23 { + b := make([]byte, n) + for i, _ := range b { + b[i] = uint8(rand.Uint32()) + } + if err := roundtrip(b, nil, nil); err != nil { + t.Fatal(err) + } + } +} + +func TestSmallRegular(t *testing.T) { + for n := 1; n < 20000; n += 23 { + b := make([]byte, n) + for i, _ := range b { + b[i] = uint8(i%10 + 'a') + } + if err := roundtrip(b, nil, nil); err != nil { + t.Fatal(err) + } + } +} + +func benchDecode(b *testing.B, src []byte) { + encoded, err := Encode(nil, src) + if err != nil { + b.Fatal(err) + } + // Bandwidth is in amount of uncompressed data. + b.SetBytes(int64(len(src))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Decode(src, encoded) + } +} + +func benchEncode(b *testing.B, src []byte) { + // Bandwidth is in amount of uncompressed data. + b.SetBytes(int64(len(src))) + dst := make([]byte, MaxEncodedLen(len(src))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Encode(dst, src) + } +} + +func readFile(b *testing.B, filename string) []byte { + src, err := ioutil.ReadFile(filename) + if err != nil { + b.Fatalf("failed reading %s: %s", filename, err) + } + if len(src) == 0 { + b.Fatalf("%s has zero length", filename) + } + return src +} + +// expand returns a slice of length n containing repeated copies of src. +func expand(src []byte, n int) []byte { + dst := make([]byte, n) + for x := dst; len(x) > 0; { + i := copy(x, src) + x = x[i:] + } + return dst +} + +func benchWords(b *testing.B, n int, decode bool) { + // Note: the file is OS-language dependent so the resulting values are not + // directly comparable for non-US-English OS installations. + data := expand(readFile(b, "/usr/share/dict/words"), n) + if decode { + benchDecode(b, data) + } else { + benchEncode(b, data) + } +} + +func BenchmarkWordsDecode1e3(b *testing.B) { benchWords(b, 1e3, true) } +func BenchmarkWordsDecode1e4(b *testing.B) { benchWords(b, 1e4, true) } +func BenchmarkWordsDecode1e5(b *testing.B) { benchWords(b, 1e5, true) } +func BenchmarkWordsDecode1e6(b *testing.B) { benchWords(b, 1e6, true) } +func BenchmarkWordsEncode1e3(b *testing.B) { benchWords(b, 1e3, false) } +func BenchmarkWordsEncode1e4(b *testing.B) { benchWords(b, 1e4, false) } +func BenchmarkWordsEncode1e5(b *testing.B) { benchWords(b, 1e5, false) } +func BenchmarkWordsEncode1e6(b *testing.B) { benchWords(b, 1e6, false) } + +// testFiles' values are copied directly from +// https://code.google.com/p/snappy/source/browse/trunk/snappy_unittest.cc. +// The label field is unused in snappy-go. +var testFiles = []struct { + label string + filename string +}{ + {"html", "html"}, + {"urls", "urls.10K"}, + {"jpg", "house.jpg"}, + {"pdf", "mapreduce-osdi-1.pdf"}, + {"html4", "html_x_4"}, + {"cp", "cp.html"}, + {"c", "fields.c"}, + {"lsp", "grammar.lsp"}, + {"xls", "kennedy.xls"}, + {"txt1", "alice29.txt"}, + {"txt2", "asyoulik.txt"}, + {"txt3", "lcet10.txt"}, + {"txt4", "plrabn12.txt"}, + {"bin", "ptt5"}, + {"sum", "sum"}, + {"man", "xargs.1"}, + {"pb", "geo.protodata"}, + {"gaviota", "kppkn.gtb"}, +} + +// The test data files are present at this canonical URL. +const baseURL = "https://snappy.googlecode.com/svn/trunk/testdata/" + +func downloadTestdata(basename string) (errRet error) { + filename := filepath.Join("testdata", basename) + f, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to create %s: %s", filename, err) + } + defer f.Close() + defer func() { + if errRet != nil { + os.Remove(filename) + } + }() + resp, err := http.Get(baseURL + basename) + if err != nil { + return fmt.Errorf("failed to download %s: %s", baseURL+basename, err) + } + defer resp.Body.Close() + _, err = io.Copy(f, resp.Body) + if err != nil { + return fmt.Errorf("failed to write %s: %s", filename, err) + } + return nil +} + +func benchFile(b *testing.B, n int, decode bool) { + filename := filepath.Join("testdata", testFiles[n].filename) + if stat, err := os.Stat(filename); err != nil || stat.Size() == 0 { + if !*download { + b.Fatal("test data not found; skipping benchmark without the -download flag") + } + // Download the official snappy C++ implementation reference test data + // files for benchmarking. + if err := os.Mkdir("testdata", 0777); err != nil && !os.IsExist(err) { + b.Fatalf("failed to create testdata: %s", err) + } + for _, tf := range testFiles { + if err := downloadTestdata(tf.filename); err != nil { + b.Fatalf("failed to download testdata: %s", err) + } + } + } + data := readFile(b, filename) + if decode { + benchDecode(b, data) + } else { + benchEncode(b, data) + } +} + +// Naming convention is kept similar to what snappy's C++ implementation uses. +func Benchmark_UFlat0(b *testing.B) { benchFile(b, 0, true) } +func Benchmark_UFlat1(b *testing.B) { benchFile(b, 1, true) } +func Benchmark_UFlat2(b *testing.B) { benchFile(b, 2, true) } +func Benchmark_UFlat3(b *testing.B) { benchFile(b, 3, true) } +func Benchmark_UFlat4(b *testing.B) { benchFile(b, 4, true) } +func Benchmark_UFlat5(b *testing.B) { benchFile(b, 5, true) } +func Benchmark_UFlat6(b *testing.B) { benchFile(b, 6, true) } +func Benchmark_UFlat7(b *testing.B) { benchFile(b, 7, true) } +func Benchmark_UFlat8(b *testing.B) { benchFile(b, 8, true) } +func Benchmark_UFlat9(b *testing.B) { benchFile(b, 9, true) } +func Benchmark_UFlat10(b *testing.B) { benchFile(b, 10, true) } +func Benchmark_UFlat11(b *testing.B) { benchFile(b, 11, true) } +func Benchmark_UFlat12(b *testing.B) { benchFile(b, 12, true) } +func Benchmark_UFlat13(b *testing.B) { benchFile(b, 13, true) } +func Benchmark_UFlat14(b *testing.B) { benchFile(b, 14, true) } +func Benchmark_UFlat15(b *testing.B) { benchFile(b, 15, true) } +func Benchmark_UFlat16(b *testing.B) { benchFile(b, 16, true) } +func Benchmark_UFlat17(b *testing.B) { benchFile(b, 17, true) } +func Benchmark_ZFlat0(b *testing.B) { benchFile(b, 0, false) } +func Benchmark_ZFlat1(b *testing.B) { benchFile(b, 1, false) } +func Benchmark_ZFlat2(b *testing.B) { benchFile(b, 2, false) } +func Benchmark_ZFlat3(b *testing.B) { benchFile(b, 3, false) } +func Benchmark_ZFlat4(b *testing.B) { benchFile(b, 4, false) } +func Benchmark_ZFlat5(b *testing.B) { benchFile(b, 5, false) } +func Benchmark_ZFlat6(b *testing.B) { benchFile(b, 6, false) } +func Benchmark_ZFlat7(b *testing.B) { benchFile(b, 7, false) } +func Benchmark_ZFlat8(b *testing.B) { benchFile(b, 8, false) } +func Benchmark_ZFlat9(b *testing.B) { benchFile(b, 9, false) } +func Benchmark_ZFlat10(b *testing.B) { benchFile(b, 10, false) } +func Benchmark_ZFlat11(b *testing.B) { benchFile(b, 11, false) } +func Benchmark_ZFlat12(b *testing.B) { benchFile(b, 12, false) } +func Benchmark_ZFlat13(b *testing.B) { benchFile(b, 13, false) } +func Benchmark_ZFlat14(b *testing.B) { benchFile(b, 14, false) } +func Benchmark_ZFlat15(b *testing.B) { benchFile(b, 15, false) } +func Benchmark_ZFlat16(b *testing.B) { benchFile(b, 16, false) } +func Benchmark_ZFlat17(b *testing.B) { benchFile(b, 17, false) } diff --git a/websocket/client.go b/websocket/client.go new file mode 100644 index 0000000..baa92b5 --- /dev/null +++ b/websocket/client.go @@ -0,0 +1,60 @@ +package websocket + +import ( + "bytes" + "errors" + "net" + "net/http" + "net/url" + "strings" +) + +var ( + ErrBadHandshake = errors.New("bad handshake") +) + +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header) (c *Conn, response *http.Response, err error) { + key, err := calcKey() + if err != nil { + return nil, nil, err + } + acceptKey := calcAcceptKey(key) + + c = NewConn(netConn, false) + + buf := bytes.NewBufferString("GET ") + buf.WriteString(u.RequestURI()) + buf.WriteString(" HTTP/1.1\r\nHost: ") + buf.WriteString(u.Host) + buf.WriteString("\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: ") + buf.WriteString(key) + buf.WriteString("\r\n") + + for k, vs := range requestHeader { + for _, v := range vs { + buf.WriteString(k) + buf.WriteString(": ") + buf.WriteString(v) + buf.WriteString("\r\n") + } + } + + buf.WriteString("\r\n") + p := buf.Bytes() + if _, err := netConn.Write(p); err != nil { + return nil, nil, err + } + + resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) + if err != nil { + return nil, nil, err + } + + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != acceptKey { + return nil, resp, ErrBadHandshake + } + return c, resp, nil +} diff --git a/websocket/client_test.go b/websocket/client_test.go new file mode 100644 index 0000000..9b58959 --- /dev/null +++ b/websocket/client_test.go @@ -0,0 +1,99 @@ +package websocket + +import ( + "github.com/gorilla/websocket" + "net" + "net/http" + "net/url" + "testing" + "time" +) + +func TestWSClient(t *testing.T) { + http.HandleFunc("/test/client", func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Upgrade(w, r, nil, 1024, 1024) + if err != nil { + t.Fatal(err.Error()) + } + + msgType, msg, err := conn.ReadMessage() + conn.WriteMessage(websocket.TextMessage, msg) + + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != websocket.TextMessage { + t.Fatal("invalid msg type", msgType) + } + + msgType, msg, err = conn.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != websocket.PingMessage { + t.Fatal("invalid msg type", msgType) + } + + conn.WriteMessage(websocket.PongMessage, []byte{}) + + conn.WriteMessage(websocket.PingMessage, []byte{}) + + msgType, msg, err = conn.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + println(msgType) + if msgType != websocket.PongMessage { + + t.Fatal("invalid msg type", msgType) + } + }) + + go http.ListenAndServe(":65500", nil) + + time.Sleep(time.Second * 1) + + conn, err := net.Dial("tcp", "127.0.0.1:65500") + + if err != nil { + t.Fatal(err.Error()) + } + ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/client"}, nil) + + if err != nil { + t.Fatal(err.Error()) + } + + payload := make([]byte, 4*1024) + for i := 0; i < 4*1024; i++ { + payload[i] = 'x' + } + + ws.WriteString(payload) + + msgType, msg, err := ws.Read() + if err != nil { + t.Fatal(err.Error()) + } + if msgType != TextMessage { + t.Fatal("invalid msg type", msgType) + } + + if string(msg) != string(payload) { + t.Fatal("invalid msg", string(msg)) + + } + + //test ping + ws.Ping([]byte{}) + msgType, msg, err = ws.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + if msgType != PongMessage { + t.Fatal("invalid msg type", msgType) + } + +} diff --git a/websocket/conn.go b/websocket/conn.go new file mode 100644 index 0000000..05d63eb --- /dev/null +++ b/websocket/conn.go @@ -0,0 +1,321 @@ +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "math/rand" + "net" + "time" +) + +//refer RFC6455 + +const ( + TextMessage byte = 1 + BinaryMessage byte = 2 + CloseMessage byte = 8 + PingMessage byte = 9 + PongMessage byte = 10 +) + +var ( + ErrControlTooLong = errors.New("control message too long") + ErrRSVNotSupport = errors.New("reserved bit not support") + ErrPayloadError = errors.New("payload length error") + ErrControlFragmented = errors.New("control message can not be fragmented") + ErrNotTCPConn = errors.New("not a tcp connection") + ErrWriteError = errors.New("write error") +) + +type Conn struct { + conn net.Conn + + br *bufio.Reader + + isServer bool +} + +func NewConn(conn net.Conn, isServer bool) *Conn { + c := new(Conn) + + c.conn = conn + + c.br = bufio.NewReader(conn) + + c.isServer = isServer + + return c +} + +func (c *Conn) ReadMessage() (messageType byte, message []byte, err error) { + return c.Read() +} + +func (c *Conn) Read() (messageType byte, message []byte, err error) { + buf := make([]byte, 8, 8) + + message = []byte{} + + messageType = 0 + + for { + opcode, data, err := c.readFrame(buf) + + if err != nil { + return messageType, message, err + } + + message = append(message, data...) + + if opcode&0x80 != 0 { + //final + if opcode&0x0F > 0 { + //not continue frame + messageType = opcode & 0x0F + } + return messageType, message, nil + + } else { + if opcode&0x0F > 0 { + //first continue frame + messageType = opcode & 0x0F + } + } + } + + return +} + +func (c *Conn) Write(message []byte, binary bool) error { + if binary { + return c.sendFrame(BinaryMessage, message) + } else { + return c.sendFrame(TextMessage, message) + } +} + +func (c *Conn) WriteMessage(messageType byte, message []byte) error { + return c.sendFrame(messageType, message) +} + +//write utf-8 text message +func (c *Conn) WriteString(message []byte) error { + return c.Write(message, false) +} + +//write binary message +func (c *Conn) WriteBinary(message []byte) error { + return c.Write(message, true) +} + +func (c *Conn) Ping(message []byte) error { + return c.sendFrame(PingMessage, message) +} + +func (c *Conn) Pong(message []byte) error { + return c.sendFrame(PongMessage, message) +} + +//close socket, not send websocket close message +func (c *Conn) Close() error { + return c.conn.Close() +} + +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *Conn) SetReadBuffer(bytes int) error { + if tcpConn, ok := c.conn.(*net.TCPConn); ok { + return tcpConn.SetReadBuffer(bytes) + } else { + return ErrNotTCPConn + } +} + +func (c *Conn) SetWriteBuffer(bytes int) error { + if tcpConn, ok := c.conn.(*net.TCPConn); ok { + return tcpConn.SetWriteBuffer(bytes) + } else { + return ErrNotTCPConn + } +} + +func (c *Conn) readPayloadLen(length byte, buf []byte) (payloadLen uint64, err error) { + if length < 126 { + payloadLen = uint64(length) + } else if length == 126 { + err = c.read(buf[:2]) + if err != nil { + return + } + payloadLen = uint64(binary.BigEndian.Uint16(buf[:2])) + } else if length == 127 { + err = c.read(buf[:8]) + if err != nil { + return + } + payloadLen = uint64(binary.BigEndian.Uint16(buf[:8])) + } + + return +} + +func (c *Conn) readFrame(buf []byte) (opcode byte, messsage []byte, err error) { + //minimum head may 2 byte + + err = c.read(buf[:2]) + if err != nil { + return + } + + opcode = buf[0] + + if opcode&0x70 > 0 { + err = ErrRSVNotSupport + return + } + + //isMasking := (0x80 & buf[1]) > 0 + isMasking := (0x80 & buf[1]) > 0 + + var payloadLen uint64 + payloadLen, err = c.readPayloadLen(buf[1]&0x7F, buf) + if err != nil { + return + } + + if opcode&0x08 > 0 && payloadLen > 125 { + err = ErrControlTooLong + return + } + + var masking []byte + + if isMasking { + err = c.read(buf[:4]) + if err != nil { + return + } + + masking = buf[:4] + } + + messsage = make([]byte, payloadLen) + err = c.read(messsage) + + if err != nil { + return + } + + if isMasking { + //maskingKey := c.newMaskingKey() + c.maskingData(messsage, masking) + } + + return +} + +func (c *Conn) sendFrame(opcode byte, message []byte) error { + //max frame header may 14 length + buf := make([]byte, 0, len(message)+14) + //here we don not support continue frame, all are final + opcode |= 0x80 + + if opcode&0x08 > 0 && len(message) >= 126 { + return ErrControlTooLong + } + + buf = append(buf, opcode) + + //no mask, because chrome may not support + var mask byte = 0x00 + + if !c.isServer { + //for client, we will mask data + mask = 0x80 + } + + payloadLen := len(message) + + if payloadLen < 126 { + buf = append(buf, mask|byte(payloadLen)) + } else if payloadLen <= 0xFFFF { + buf = append(buf, mask|byte(126), 0, 0) + + binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(payloadLen)) + } else { + buf = append(buf, mask|byte(127), 0, 0, 0, 0, 0, 0, 0, 0) + + binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(payloadLen)) + } + + if !c.isServer { + maskingKey := c.newMaskingKey() + buf = append(buf, maskingKey...) + + pos := len(buf) + buf = append(buf, message...) + + c.maskingData(buf[pos:], maskingKey) + + } else { + buf = append(buf, message...) + } + + tmpBuf := buf + for i := 0; i < 3; i++ { + n, err := c.conn.Write(tmpBuf) + if err != nil { + return err + } + if n == len(tmpBuf) { + return nil + } else { + tmpBuf = tmpBuf[n:] + } + } + return ErrWriteError +} + +func (c *Conn) read(buf []byte) error { + var err error + for len(buf) > 0 && err == nil { + var nn int + nn, err = c.br.Read(buf) + buf = buf[nn:] + } + if err == io.EOF { + if len(buf) == 0 { + err = nil + } else { + err = io.ErrUnexpectedEOF + } + } + return err +} + +func (c *Conn) maskingData(data []byte, maskingKey []byte) { + for i := range data { + data[i] ^= maskingKey[i%4] + } +} + +func (c *Conn) newMaskingKey() []byte { + n := rand.Uint32() + return []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 32)} +} diff --git a/websocket/pingpong_test.go b/websocket/pingpong_test.go new file mode 100644 index 0000000..e188b82 --- /dev/null +++ b/websocket/pingpong_test.go @@ -0,0 +1,51 @@ +package websocket + +import ( + "net" + "net/http" + "net/url" + "testing" + "time" +) + +func TestWSPing(t *testing.T) { + http.HandleFunc("/test/ping", func(w http.ResponseWriter, r *http.Request) { + conn, err := Upgrade(w, r, nil) + if err != nil { + t.Fatal(err.Error()) + } + //conn := NewConn(c, true) + conn.Read() + conn.Pong([]byte{}) + conn.Ping([]byte{}) + msgType, _, _ := conn.Read() + println(msgType) + }) + + go http.ListenAndServe(":65500", nil) + time.Sleep(time.Second * 1) + + conn, err := net.Dial("tcp", "127.0.0.1:65500") + + if err != nil { + t.Fatal(err.Error()) + } + ws, _, err := NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/ping"}, nil) + + if err != nil { + t.Fatal(err.Error()) + } + ws.Ping([]byte{}) + + msgType, _, _ := ws.Read() + if msgType != PongMessage { + t.Fatal("invalid msg type", msgType) + } + + msgType, _, _ = ws.Read() + if msgType != PingMessage { + t.Fatal("invalid msg type", msgType) + } + ws.Pong([]byte{}) + time.Sleep(time.Second * 1) +} diff --git a/websocket/server.go b/websocket/server.go new file mode 100644 index 0000000..360233e --- /dev/null +++ b/websocket/server.go @@ -0,0 +1,105 @@ +package websocket + +import ( + "bufio" + "bytes" + "errors" + "net" + "net/http" + "strings" +) + +var ( + ErrInvalidMethod = errors.New("Only GET Supported") + ErrInvalidVersion = errors.New("Sec-Websocket-Version: 13") + ErrInvalidUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"") + ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"") + ErrMissingKey = errors.New("Missing Key") + ErrHijacker = errors.New("Not implement http.Hijacker") + ErrNoEmptyConn = errors.New("Conn ReadBuf must be empty") +) + +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + if r.Method != "GET" { + return nil, ErrInvalidMethod + } + + if r.Header.Get("Sec-Websocket-Version") != "13" { + return nil, ErrInvalidVersion + } + + if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { + return nil, ErrInvalidUpgrade + } + + if strings.ToLower(r.Header.Get("Connection")) != "upgrade" { + return nil, ErrInvalidConnection + } + + var acceptKey string + + if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 { + return nil, ErrMissingKey + } else { + acceptKey = calcAcceptKey(key) + } + + var ( + netConn net.Conn + br *bufio.Reader + err error + ) + + h, ok := w.(http.Hijacker) + if !ok { + return nil, ErrHijacker + } + + var rw *bufio.ReadWriter + netConn, rw, err = h.Hijack() + br = rw.Reader + + if br.Buffered() > 0 { + netConn.Close() + return nil, ErrNoEmptyConn + } + + c := NewConn(netConn, true) + + buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") + + buf.WriteString(acceptKey) + buf.WriteString("\r\n") + + subProtol := selectSubProtocol(r) + if len(subProtol) > 0 { + buf.WriteString("Sec-Websocket-Protocol: ") + buf.WriteString(subProtol) + buf.WriteString("\r\n") + } + + for k, vs := range responseHeader { + for _, v := range vs { + buf.WriteString(k) + buf.WriteString(": ") + buf.WriteString(v) + buf.WriteString("\r\n") + } + } + buf.WriteString("\r\n") + + if _, err = netConn.Write(buf.Bytes()); err != nil { + netConn.Close() + return nil, err + } + + return c, nil +} + +func selectSubProtocol(r *http.Request) string { + h := r.Header.Get("Sec-Websocket-Protocol") + if len(h) == 0 { + return "" + } + return strings.Split(h, ",")[0] +} diff --git a/websocket/server_test.go b/websocket/server_test.go new file mode 100644 index 0000000..73eba96 --- /dev/null +++ b/websocket/server_test.go @@ -0,0 +1,97 @@ +package websocket + +import ( + "github.com/gorilla/websocket" + "net" + "net/http" + "net/url" + "testing" + "time" +) + +func TestWSServer(t *testing.T) { + http.HandleFunc("/test/server", func(w http.ResponseWriter, r *http.Request) { + conn, err := Upgrade(w, r, nil) + + if err != nil { + t.Fatal(err.Error()) + } + //err = conn.SetReadBuffer(1024 * 1024 * 4) + //if err != nil { + // println(err.Error()) + //} + //err = conn.SetWriteBuffer(1024 * 1024 * 4) + + //if err != nil { + // println(err.Error()) + //} + + msgType, msg, err := conn.Read() + conn.Write(msg, false) + + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != TextMessage { + t.Fatal("wrong msg type", msgType) + } + + msgType, msg, err = conn.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + + if msgType != PingMessage { + t.Fatal("wrong msg type", msgType) + } + + err = conn.Pong([]byte("abc")) + + if err != nil { + t.Fatal(err.Error()) + } + + }) + + go http.ListenAndServe(":65500", nil) + time.Sleep(time.Second * 1) + + conn, err := net.Dial("tcp", "127.0.0.1:65500") + + if err != nil { + t.Fatal(err.Error()) + } + ws, _, err := websocket.NewClient(conn, &url.URL{Host: "127.0.0.1:65500", Path: "/test/server"}, nil, 1024, 1024) + + ws.SetPongHandler(func(string) error { + println("pong") + return nil + }) + + if err != nil { + t.Fatal(err.Error()) + } + + payload := make([]byte, 4*1024*1024) + for i := 0; i < 4*1024*1024; i++ { + payload[i] = 'x' + } + + ws.WriteMessage(websocket.TextMessage, payload) + + msgType, msg, err := ws.ReadMessage() + if err != nil { + t.Fatal(err.Error()) + } + if msgType != websocket.TextMessage { + t.Fatal("invalid msg type", msgType) + } + + if string(msg) != string(payload) { + t.Fatal("invalid msg", string(msg)) + + } + + time.Sleep(time.Second * 1) +} diff --git a/websocket/util.go b/websocket/util.go new file mode 100644 index 0000000..9ad5ec9 --- /dev/null +++ b/websocket/util.go @@ -0,0 +1,36 @@ +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "errors" + "io" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func calcAcceptKey(key string) string { + h := sha1.New() + h.Write([]byte(key)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func calcKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +func HandleCloseFrame(buf []byte) (int16, string, error) { + + if len(buf) < 2 { + return 0, "", errors.New("close frame msg's length less than 2") + } + code := int16(buf[0])<<8 + int16(buf[1]) + reason := string(buf[2:]) + return code, reason, nil +}