av/protocol/rtsp/rtsp.go

240 lines
4.5 KiB
Go

package rtsp
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strconv"
"strings"
)
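// The RTSP protocol is specified by RFC 2326 (RTSP 1.0, the version this
// package writes in its request line) and RFC 7826 (RTSP 2.0):
// https://tools.ietf.org/html/rfc2326
// https://tools.ietf.org/html/rfc7826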
// RTSP request method tokens used by Session. Each value is written verbatim
// as the method of the request line.
const (
	// options is sent by Session.Options.
	options = "OPTIONS"
	// describe is sent by Session.Describe together with "Accept: application/sdp".
	describe = "DESCRIBE"
	// setup is sent by Session.Setup, whose response supplies the session header.
	setup = "SETUP"
	// play is sent by Session.Play along with the session recorded by Setup.
	play = "PLAY"
)
// Session is an RTSP client that sends requests over a single TCP connection.
// The connection is dialed lazily by the first request and reused afterwards.
// Session has no internal locking, so callers must not issue requests from
// several goroutines at once.
type Session struct {
	cSeq    int      // CSeq of the most recently built request
	conn    net.Conn // nil until the first request (or after Close)
	session string   // session identifier from the last SETUP response
}

// NewSession returns a Session with no open connection.
func NewSession() *Session {
	return &Session{}
}

// Describe sends DESCRIBE for urlStr, advertising "Accept: application/sdp".
// Any description the server returns is available from the response Body.
func (s *Session) Describe(urlStr string) (*Response, error) {
	return s.writeRequest(describe, urlStr, func(req *Request) {
		req.Header.Add("Accept", "application/sdp")
	}, nil)
}

// Options sends OPTIONS for urlStr.
func (s *Session) Options(urlStr string) (*Response, error) {
	return s.writeRequest(options, urlStr, nil, nil)
}

// Setup sends SETUP for urlStr with the given Transport header value and
// remembers the session identifier from the response for later requests.
func (s *Session) Setup(urlStr, transport string) (*Response, error) {
	return s.writeRequest(
		setup,
		urlStr,
		func(req *Request) {
			req.Header.Add("Transport", transport)
		},
		func(resp *Response) {
			// The header may carry parameters, e.g. "12345678;timeout=60"
			// (RFC 2326 §12.37); only the identifier is echoed back later.
			id := resp.Header.Get("Session")
			if i := strings.IndexByte(id, ';'); i >= 0 {
				id = id[:i]
			}
			// Keep any previous identifier when the response carries none
			// (e.g. an error status).
			if id = strings.TrimSpace(id); id != "" {
				s.session = id
			}
		},
	)
}

// Play sends PLAY for urlStr, including the Session header when Setup has
// recorded a session identifier.
func (s *Session) Play(urlStr string) (*Response, error) {
	return s.writeRequest(play, urlStr, func(req *Request) {
		if s.session != "" {
			req.Header.Add("Session", s.session)
		}
	}, nil)
}

// Close closes the session's connection, if one is open. A later request
// dials a fresh connection.
func (s *Session) Close() error {
	if s.conn == nil {
		return nil
	}
	err := s.conn.Close()
	s.conn = nil
	return err
}

// writeRequest builds a request with the next CSeq, lets headerOp add
// headers, sends it, reads one response and passes it to respOp. Either hook
// may be nil. On a write or read error the connection is closed so that the
// next request starts on a clean stream instead of a misaligned one.
func (s *Session) writeRequest(method, urlStr string, headerOp func(*Request), respOp func(*Response)) (*Response, error) {
	req, err := NewRequest(method, urlStr, s.nextCSeq(), nil)
	if err != nil {
		return nil, err
	}
	if headerOp != nil {
		headerOp(req)
	}
	if s.conn == nil {
		if req.URL.Host == "" {
			return nil, fmt.Errorf("rtsp: no host in URL %q", urlStr)
		}
		addr := req.URL.Host
		if req.URL.Port() == "" {
			// RTSP URLs may omit the port (default 554, RFC 2326 §3.2),
			// but net.Dial requires one.
			addr = net.JoinHostPort(req.URL.Hostname(), "554")
		}
		conn, dialErr := net.Dial("tcp", addr)
		if dialErr != nil {
			return nil, dialErr
		}
		s.conn = conn
	}
	if _, err = io.WriteString(s.conn, req.String()); err != nil {
		_ = s.Close() // the write error is the one worth reporting
		return nil, err
	}
	// Expose the conn only as an io.Reader: ReadResponse's Body closes its
	// source when that source is an io.ReadCloser, and closing a response
	// body must not tear down the shared session connection.
	res, err := ReadResponse(struct{ io.Reader }{s.conn})
	if err != nil {
		_ = s.Close() // the read/parse error is the one worth reporting
		return nil, err
	}
	if respOp != nil {
		respOp(res)
	}
	return res, nil
}

// nextCSeq increments the session's CSeq counter and returns it as a string.
func (s *Session) nextCSeq() string {
	s.cSeq++
	return strconv.Itoa(s.cSeq)
}
// Request is an outgoing RTSP request.
type Request struct {
	Method        string
	URL           *url.URL
	Proto         string // protocol name written in the request line, e.g. "RTSP"
	ProtoMajor    int
	ProtoMinor    int
	Header        http.Header
	ContentLength int // not consulted by String; Content-Length is derived from the body
	Body          io.ReadCloser
}

// String renders r in wire format: request line, headers, a blank line, then
// the body (if any), with CRLF line endings.
//
// Headers are written by http.Header.Write, which emits them in sorted key
// order and replaces CR/LF inside values so a value cannot inject extra
// header lines. If r.Body is non-nil it is read to EOF, so String consumes
// the body: a second call renders no body. When the body is non-empty and no
// Content-Length header is set, one is added (RFC 2326 §12.14 requires it
// whenever a message carries a body).
func (r Request) String() string {
	var body []byte
	if r.Body != nil {
		// String cannot report errors; on a failed read whatever was read is used.
		body, _ = ioutil.ReadAll(r.Body)
	}

	var b strings.Builder
	fmt.Fprintf(&b, "%s %s %s/%d.%d\r\n", r.Method, r.URL, r.Proto, r.ProtoMajor, r.ProtoMinor)
	_ = r.Header.Write(&b) // writes to a strings.Builder never fail
	if len(body) > 0 && r.Header.Get("Content-Length") == "" {
		b.WriteString("Content-Length: " + strconv.Itoa(len(body)) + "\r\n")
	}
	b.WriteString("\r\n")
	b.Write(body)
	return b.String()
}

// NewRequest returns an RTSP/1.0 request for method and urlStr carrying the
// given CSeq header value and optional body. It fails only if urlStr cannot
// be parsed.
func NewRequest(method, urlStr, cSeq string, body io.ReadCloser) (*Request, error) {
	u, err := url.Parse(urlStr)
	if err != nil {
		return nil, err
	}
	return &Request{
		Method:     method,
		URL:        u,
		Proto:      "RTSP",
		ProtoMajor: 1,
		ProtoMinor: 0,
		// A literal key keeps the RFC spelling "CSeq"; Header.Add would
		// canonicalize it to "Cseq".
		Header: http.Header{"CSeq": {cSeq}},
		Body:   body,
	}, nil
}
// Response is an RTSP response as parsed by ReadResponse.
type Response struct {
	Proto         string // protocol name from the status line, e.g. "RTSP"
	ProtoMajor    int
	ProtoMinor    int
	StatusCode    int
	Status        string // reason phrase only, e.g. "OK"; empty if absent
	ContentLength int64  // Content-Length header value; 0 when absent
	Header        http.Header
	Body          io.ReadCloser // yields exactly ContentLength bytes
}

// String renders res as its status line followed by one "Key: value" line
// per header value. Lines end in "\n" rather than CRLF and header order
// follows Go map iteration, so the output suits logging rather than
// re-serialization.
func (res Response) String() string {
	var b strings.Builder
	fmt.Fprintf(&b, "%s/%d.%d %d %s\n", res.Proto, res.ProtoMajor, res.ProtoMinor, res.StatusCode, res.Status)
	for k, vs := range res.Header {
		for _, v := range vs {
			fmt.Fprintf(&b, "%s: %s\n", k, v)
		}
	}
	return b.String()
}

// ReadResponse reads one RTSP response (status line, headers and body) from r.
//
// r is wrapped in a bufio.Reader (reused if r already is a large enough
// *bufio.Reader), so bytes after the response may be buffered inside the
// returned Body rather than left in r. The Body is limited to Content-Length
// bytes; per RFC 2326 §12.14 a missing Content-Length means no body, so
// reading Body never runs into the next message. Closing Body closes r if r
// is an io.ReadCloser.
//
// Lines may end in CRLF or bare LF; a bare CR terminator is not supported.
// On error the returned *Response is nil.
func ReadResponse(r io.Reader) (*Response, error) {
	b := bufio.NewReader(r)

	// Status line, e.g. "RTSP/1.0 200 OK"; the reason phrase is optional.
	line, err := b.ReadString('\n')
	if err != nil {
		return nil, err
	}
	line = strings.TrimRight(line, "\r\n")
	parts := strings.SplitN(line, " ", 3)
	if len(parts) < 2 {
		return nil, fmt.Errorf("rtsp: malformed status line %q", line)
	}
	res := &Response{Header: make(http.Header)}
	if res.Proto, res.ProtoMajor, res.ProtoMinor, err = ParseRTSPVersion(parts[0]); err != nil {
		return nil, err
	}
	if res.StatusCode, err = strconv.Atoi(parts[1]); err != nil {
		return nil, err
	}
	if len(parts) == 3 {
		res.Status = strings.TrimSpace(parts[2])
	}

	// Headers, up to the blank line that ends them.
	for {
		if line, err = b.ReadString('\n'); err != nil {
			return nil, err
		}
		if line = strings.TrimRight(line, "\r\n"); line == "" {
			break
		}
		colon := strings.IndexByte(line, ':')
		if colon < 0 {
			return nil, fmt.Errorf("rtsp: malformed header line %q", line)
		}
		res.Header.Add(strings.TrimSpace(line[:colon]), strings.TrimSpace(line[colon+1:]))
	}

	if cl := res.Header.Get("Content-Length"); cl != "" {
		n, perr := strconv.ParseInt(cl, 10, 64)
		if perr != nil || n < 0 {
			return nil, fmt.Errorf("rtsp: invalid Content-Length %q", cl)
		}
		res.ContentLength = n
	}
	res.Body = &closer{Reader: io.LimitReader(b, res.ContentLength), r: r}
	return res, nil
}

// closer is the Body returned by ReadResponse: reads come from the embedded
// Reader, and Close closes the original source r when it is an io.ReadCloser.
type closer struct {
	io.Reader
	r      io.Reader
	closed bool
}

// Close closes the underlying source at most once; later calls return nil.
// It has a pointer receiver so the closed flag persists between calls.
func (c *closer) Close() error {
	if c.closed || c.Reader == nil {
		return nil
	}
	c.closed = true
	if rc, ok := c.r.(io.ReadCloser); ok {
		return rc.Close()
	}
	return nil
}

// ParseRTSPVersion splits a protocol version such as "RTSP/1.0" into its
// name ("RTSP") and numeric major and minor parts. It returns an error,
// rather than panicking, when the '/' or '.' separator is missing or either
// number is not a decimal integer. The protocol name is not validated.
func ParseRTSPVersion(s string) (proto string, major int, minor int, err error) {
	slash := strings.IndexByte(s, '/')
	if slash < 0 {
		return "", 0, 0, fmt.Errorf("rtsp: malformed protocol version %q", s)
	}
	ver := s[slash+1:]
	dot := strings.IndexByte(ver, '.')
	if dot < 0 {
		return "", 0, 0, fmt.Errorf("rtsp: malformed protocol version %q", s)
	}
	if major, err = strconv.Atoi(ver[:dot]); err != nil {
		return "", 0, 0, err
	}
	if minor, err = strconv.Atoi(ver[dot+1:]); err != nil {
		return "", 0, 0, err
	}
	return s[:slash], major, minor, nil
}