diff --git a/protocol/rtsp/rtsp.go b/protocol/rtsp/rtsp.go index 90dd806e..4f11d134 100644 --- a/protocol/rtsp/rtsp.go +++ b/protocol/rtsp/rtsp.go @@ -57,6 +57,7 @@ package rtsp import ( "bufio" + "errors" "fmt" "io" "io/ioutil" @@ -78,6 +79,10 @@ const ( setup = "SETUP" ) +const minResponse = 15 + +var ErrSmallResponse = errors.New("response too small") + // Session describes an RTSP Session. type Session struct { cSeq int @@ -242,25 +247,24 @@ func (r Response) String() string { func ReadResponse(r io.Reader) (res *Response, err error) { res = new(Response) res.Header = make(map[string][]string) - b := bufio.NewReader(r) var s string - // TODO: allow CR, LF, or CRLF if s, err = b.ReadString('\n'); err != nil { + return } - + if len(s) < minResponse { + return nil, ErrSmallResponse + } parts := strings.SplitN(s, " ", 3) res.Proto, res.ProtoMajor, res.ProtoMinor, err = parseVersion(parts[0]) if err != nil { return } - if res.StatusCode, err = strconv.Atoi(parts[1]); err != nil { return } - res.Status = strings.TrimSpace(parts[2]) // read headers diff --git a/protocol/rtsp/rtsp_test.go b/protocol/rtsp/rtsp_test.go index 6fb93012..c70dfdf6 100644 --- a/protocol/rtsp/rtsp_test.go +++ b/protocol/rtsp/rtsp_test.go @@ -28,21 +28,215 @@ LICENSE package rtsp import ( + "errors" + "fmt" + "io" + "net" + "net/url" + "strings" "testing" + "time" + "unicode" ) -func TestDescribe(t *testing.T) { +func TestMethods(t *testing.T) { + const dummyAddr = "rtsp://admin:admin@192.168.0.50:8554/CH001.sdp" + dummyURL, err := url.Parse(dummyAddr) + if err != nil { + t.Fatalf("could not parse dummy address, failed with err: %v", err) + } + tests := []struct { + method func(s *Session) (*Response, error) + serverRes []byte + expected []byte + }{ + { + method: func(s *Session) (*Response, error) { + return s.writeRequest(dummyURL, describe, func(req *Request) { req.Header.Add("Accept", "application/sdp") }, nil) + }, + expected: []byte{ + 0x44, 0x45, 0x53, 0x43, 0x52, 0x49, 0x42, 0x45, 0x20, 0x72, 0x74, + 0x73, 0x70, 0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, + 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, + 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, + 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, + 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, + 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x32, 0x0d, 0x0a, 0x41, 0x63, + 0x63, 0x65, 0x70, 0x74, 0x3a, 0x20, 0x61, 0x70, 0x70, 0x6c, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, 0x64, 0x70, 0x0d, + 0x0a, 0x0d, 0x0a, + }, + }, + { + method: func(s *Session) (*Response, error) { + return s.writeRequest(dummyURL, options, nil, nil) + }, + expected: []byte{ + 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x53, 0x20, 0x72, 0x74, 0x73, 0x70, + 0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, + 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x30, + 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, + 0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, + 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x31, + 0x0d, 0x0a, 0x0d, 0x0a, + }, + }, + { + method: func(s *Session) (*Response, error) { + url, err := url.Parse(dummyAddr + "/track1") + if err != nil { + t.Fatalf("could not parse url with track, failed with err: %v", err) + } + return s.writeRequest( + url, + setup, + func(req *Request) { + req.Header.Add("Transport", fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d", 6870, 6871)) + }, + nil, + ) + }, + expected: []byte{ + 0x53, 0x45, 0x54, 0x55, 0x50, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, + 0x2f, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, + 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, + 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, + 0x2e, 0x73, 0x64, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x6b, 0x31, 0x20, + 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x54, 0x72, + 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x3a, 0x20, 0x52, 0x54, 0x50, + 0x2f, 0x41, 0x56, 0x50, 0x3b, 0x75, 0x6e, 0x69, 0x63, 0x61, 0x73, 0x74, + 0x3b, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, + 0x3d, 0x36, 0x38, 0x37, 0x30, 0x2d, 0x36, 0x38, 0x37, 0x31, 0x0d, 0x0a, + 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x33, 0x0d, 0x0a, 0x0d, 0x0a, + }, + }, + { + method: func(s *Session) (*Response, error) { + return s.writeRequest(dummyURL, play, func(req *Request) { req.Header.Add("Session", "00000021") }, nil) + }, + expected: []byte{ + 0x50, 0x4c, 0x41, 0x59, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, + 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, + 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, 0x30, + 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, 0x2e, + 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, + 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x34, 0x0d, 0x0a, 0x53, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x32, 0x31, 0x0d, 0x0a, 0x0d, 0x0a, + }, + }, + } + + const serverAddr = "rtsp://localhost:8005" + const retries = 10 + clientErr := make(chan error) + serverErr := make(chan error) + done := make(chan struct{}) + + // start server + go func() { + l, err := net.Listen("tcp", strings.TrimLeft(serverAddr, "rtsp://")) + if err != nil { + serverErr <- errors.New(fmt.Sprintf("server could not listen, error: %v", err)) + } + + conn, err := l.Accept() + if err != nil { + serverErr <- errors.New(fmt.Sprintf("server could not accept connection, error: %v", err)) + } + + buf := make([]byte, 1024) + var n int + for i, test := range tests { + loop: + for { + n, err = conn.Read(buf) + err, ok := err.(net.Error) + switch { + case err == nil: + break loop + case err == io.EOF: + case ok && err.Timeout(): + default: + serverErr <- errors.New(fmt.Sprintf("server could not read conn, error: %v", err)) + return + } + } + conn.Write([]byte{'\n'}) + want := test.expected + got := buf[:n] + if !equal(got, want) { + serverErr <- errors.New(fmt.Sprintf("unexpected result for test: %v. \nGot: %v\n Want: %v\n", i, got, want)) + } + } + close(done) + }() + + // start client + go func() { + var sess *Session + var err error + for retry := 0; ; retry++ { + sess, err = NewSession(serverAddr) + if err == nil { + break + } + if retry > 10 { + clientErr <- errors.New(fmt.Sprintf("client could not connect to server, error: %v", err)) + } + time.Sleep(10 * time.Millisecond) + } + for i, test := range tests { + _, err = test.method(sess) + if err != nil && err != io.EOF && err != ErrSmallResponse { + clientErr <- errors.New(fmt.Sprintf("error request for: %v err: %v", i, err)) + } + } + }() + + // start error checking + for { + select { + case err := <-clientErr: + t.Fatalf("client error: %v", err) + case err := <-serverErr: + t.Fatalf("server error: %v", err) + case <-done: + return + default: + } + } } -func TestOptions(t *testing.T) { - +func equal(got, want []byte) bool { + const eol = "\r\n" + gotParts := strings.Split(strings.TrimRight(string(got), eol), eol) + wantParts := strings.Split(strings.TrimRight(string(want), eol), eol) + gotParts, ok := rmSeqNum(gotParts) + if !ok { + return false + } + wantParts, ok = rmSeqNum(wantParts) + if !ok { + return false + } + for _, gotStr := range gotParts { + for i, wantStr := range wantParts { + if gotStr == wantStr { + wantParts = append(wantParts[:i], wantParts[i+1:]...) + } + } + } + return len(wantParts) == 0 } -func TestSetup(t *testing.T) { - -} - -func TestPlay(t *testing.T) { - +func rmSeqNum(s []string) ([]string, bool) { + for i, _s := range s { + if strings.Contains(_s, "CSeq") { + s[i] = strings.TrimFunc(s[i], func(r rune) bool { return unicode.IsNumber(r) }) + return s, true + } + } + return nil, false }