av/protocol/rtsp/rtsp_test.go

276 lines
8.4 KiB
Go
Raw Normal View History

/*
NAME
2019-04-25 09:00:28 +03:00
rtsp_test.go
DESCRIPTION
2019-04-25 09:00:28 +03:00
rtsp_test.go provides a test to check functionality provided in rtsp.go.
AUTHORS
Saxon A. Nelson-Milton <saxon@ausocean.org>
LICENSE
This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean).
It is free software: you can redistribute it and/or modify them
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 3 of the License, or (at your
option) any later version.
It is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
for more details.
You should have received a copy of the GNU General Public License
in gpl.txt. If not, see http://www.gnu.org/licenses.
*/
package rtsp
import (
"errors"
"fmt"
"io"
"net"
"net/url"
"strings"
"testing"
"time"
"unicode"
)
2019-04-25 09:00:28 +03:00
// TestMethods checks that we can correctly form requests for each of the RTSP
// methods supported in the rtsp pkg. This test also checks that communication
// over a TCP connection is performed correctly.
func TestMethods(t *testing.T) {
2019-04-25 09:00:28 +03:00
const dummyURL = "rtsp://admin:admin@192.168.0.50:8554/CH001.sdp"
url, err := url.Parse(dummyURL)
if err != nil {
t.Fatalf("could not parse dummy address, failed with err: %v", err)
}
2019-04-25 09:00:28 +03:00
// tests holds tests which consist of a function used to create and write a
// request of a particular method, and also the expected request bytes
// to be received on the server side. The bytes in these tests have been
// obtained from a valid RTSP communication cltion..
tests := []struct {
method func(c *Client) (*Response, error)
2019-04-25 09:00:28 +03:00
expected []byte
}{
{
method: func(c *Client) (*Response, error) {
req, err := NewRequest(describe, c.nextCSeq(), url, nil)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/sdp")
return c.Do(req)
},
expected: []byte{
2019-04-25 09:00:28 +03:00
0x44, 0x45, 0x53, 0x43, 0x52, 0x49, 0x42, 0x45, 0x20, 0x72, 0x74, 0x73,
0x70, 0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64,
0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e,
0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48,
0x30, 0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50,
0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20,
0x32, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x3a, 0x20, 0x61,
0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73,
0x64, 0x70, 0x0d, 0x0a, 0x0d, 0x0a,
},
},
{
method: func(c *Client) (*Response, error) {
req, err := NewRequest(options, c.nextCSeq(), url, nil)
if err != nil {
return nil, err
}
return c.Do(req)
},
expected: []byte{
0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x53, 0x20, 0x72, 0x74, 0x73, 0x70,
0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d,
0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x30,
0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30,
0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f,
0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x31,
0x0d, 0x0a, 0x0d, 0x0a,
},
},
{
method: func(c *Client) (*Response, error) {
u, err := url.Parse(dummyURL + "/track1")
if err != nil {
return nil, err
}
req, err := NewRequest(setup, c.nextCSeq(), u, nil)
if err != nil {
return nil, err
}
req.Header.Add("Transport", fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d", 6870, 6871))
return c.Do(req)
},
expected: []byte{
0x53, 0x45, 0x54, 0x55, 0x50, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f,
0x2f, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e,
0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35,
0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31,
0x2e, 0x73, 0x64, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x6b, 0x31, 0x20,
0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x54, 0x72,
0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x3a, 0x20, 0x52, 0x54, 0x50,
0x2f, 0x41, 0x56, 0x50, 0x3b, 0x75, 0x6e, 0x69, 0x63, 0x61, 0x73, 0x74,
0x3b, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74,
0x3d, 0x36, 0x38, 0x37, 0x30, 0x2d, 0x36, 0x38, 0x37, 0x31, 0x0d, 0x0a,
0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x33, 0x0d, 0x0a, 0x0d, 0x0a,
},
},
{
method: func(c *Client) (*Response, error) {
req, err := NewRequest(play, c.nextCSeq(), url, nil)
if err != nil {
return nil, err
}
req.Header.Add("Session", "00000021")
return c.Do(req)
},
expected: []byte{
0x50, 0x4c, 0x41, 0x59, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f,
0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40,
0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, 0x30,
0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, 0x2e,
0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30,
0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x34, 0x0d, 0x0a, 0x53,
0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x30, 0x30, 0x30, 0x30,
0x30, 0x30, 0x32, 0x31, 0x0d, 0x0a, 0x0d, 0x0a,
},
},
}
const serverAddr = "rtsp://localhost:8005"
const retries = 10
2019-04-25 09:00:28 +03:00
clientErr := make(chan error)
serverErr := make(chan error)
done := make(chan struct{})
2019-04-25 09:00:28 +03:00
// This routine acts as the server.
go func() {
l, err := net.Listen("tcp", strings.TrimLeft(serverAddr, "rtsp://"))
if err != nil {
serverErr <- errors.New(fmt.Sprintf("server could not listen, error: %v", err))
}
conn, err := l.Accept()
if err != nil {
serverErr <- errors.New(fmt.Sprintf("server could not accept connection, error: %v", err))
}
buf := make([]byte, 1024)
var n int
for i, test := range tests {
loop:
for {
n, err = conn.Read(buf)
err, ok := err.(net.Error)
2019-04-25 09:00:28 +03:00
switch {
case err == nil:
break loop
case err == io.EOF:
case ok && err.Timeout():
default:
serverErr <- errors.New(fmt.Sprintf("server could not read conn, error: %v", err))
return
}
}
2019-04-25 09:00:28 +03:00
// Write a dummy response, client won't care.
conn.Write([]byte{'\n'})
2019-04-25 09:00:28 +03:00
want := test.expected
got := buf[:n]
if !equal(got, want) {
serverErr <- errors.New(fmt.Sprintf("unexpected result for test: %v. \nGot: %v\n Want: %v\n", i, got, want))
}
}
close(done)
}()
2019-04-25 09:00:28 +03:00
// This routine acts as the client.
go func() {
var clt *Client
var err error
2019-04-25 09:00:28 +03:00
// Keep trying to connect to server.
for retry := 0; ; retry++ {
clt, err = NewClient(serverAddr)
if err == nil {
break
}
2019-04-25 09:00:28 +03:00
if retry > retries {
clientErr <- errors.New(fmt.Sprintf("client could not connect to server, error: %v", err))
}
time.Sleep(10 * time.Millisecond)
}
2019-04-25 09:00:28 +03:00
for i, test := range tests {
_, err = test.method(clt)
2019-04-25 09:00:28 +03:00
if err != nil && err != io.EOF && err != errSmallResponse {
clientErr <- errors.New(fmt.Sprintf("error request for: %v err: %v", i, err))
}
}
}()
2019-04-25 09:00:28 +03:00
// We check for errors or a done signal from the server and client routines.
for {
select {
case err := <-clientErr:
t.Fatalf("client error: %v", err)
case err := <-serverErr:
t.Fatalf("server error: %v", err)
case <-done:
return
default:
}
}
}
2019-04-25 09:00:28 +03:00
// equal reports whether got and want hold equivalent RTSP requests.
//
// Both inputs are split into "\r\n"-separated lines (trailing line endings are
// dropped) and the number on the CSeq line is removed, so differing sequence
// numbers are ignored. The remaining lines are then compared as multisets:
// header order does not matter, but every line, duplicates included, must
// occur the same number of times in both. equal returns false if either input
// has no CSeq line.
func equal(got, want []byte) bool {
	const eol = "\r\n"
	gotParts, ok := rmSeqNum(strings.Split(strings.TrimRight(string(got), eol), eol))
	if !ok {
		return false
	}
	wantParts, ok := rmSeqNum(strings.Split(strings.TrimRight(string(want), eol), eol))
	if !ok {
		return false
	}
	if len(gotParts) != len(wantParts) {
		return false
	}

	// Count the wanted lines, then consume one count per received line. With
	// equal lengths, consuming every received line means all counts hit zero.
	remaining := make(map[string]int, len(wantParts))
	for _, line := range wantParts {
		remaining[line]++
	}
	for _, line := range gotParts {
		if remaining[line] == 0 {
			return false
		}
		remaining[line]--
	}
	return true
}

// rmSeqNum strips the trailing CSeq number from the first element of s that
// contains "CSeq", modifying s in place, and returns s and true. If no element
// contains a CSeq field, nil and false are returned.
func rmSeqNum(s []string) ([]string, bool) {
	for i, line := range s {
		if strings.Contains(line, "CSeq") {
			s[i] = strings.TrimRightFunc(line, unicode.IsNumber)
			return s, true
		}
	}
	return nil, false
}