diff --git a/protocol/rtsp/rtsp.go b/protocol/rtsp/rtsp.go index 006d2375..ebe5d978 100644 --- a/protocol/rtsp/rtsp.go +++ b/protocol/rtsp/rtsp.go @@ -34,7 +34,7 @@ func NewSession() *Session { } func (s *Session) Describe(urlStr string) (*Response, error) { - return s.writeRequest(describe, urlStr, func(req *request) { req.Header.Add("Accept", "application/sdp") }, nil) + return s.writeRequest(describe, urlStr, func(req *Request) { req.Header.Add("Accept", "application/sdp") }, nil) } func (s *Session) Options(urlStr string) (*Response, error) { @@ -45,7 +45,7 @@ func (s *Session) Setup(urlStr, transport string) (*Response, error) { return s.writeRequest( setup, urlStr, - func(req *request) { + func(req *Request) { req.Header.Add("Transport", transport) }, func(resp *Response) { @@ -55,25 +55,15 @@ func (s *Session) Setup(urlStr, transport string) (*Response, error) { } func (s *Session) Play(urlStr string) (*Response, error) { - return s.writeRequest(play, urlStr, func(req *request) { req.Header.Add("Session", s.session) }, nil) + return s.writeRequest(play, urlStr, func(req *Request) { req.Header.Add("Session", s.session) }, nil) } -func (s *Session) writeRequest(method, urlStr string, headerOp func(*request), respOp func(*Response)) (*Response, error) { - u, err := url.Parse(urlStr) +func (s *Session) writeRequest(method, urlStr string, headerOp func(*Request), respOp func(*Response)) (*Response, error) { + req, err := NewRequest(method, urlStr, s.nextCSeq(), nil) if err != nil { return nil, err } - req := (*request)(&http.Request{ - Method: method, - URL: u, - Proto: "RTSP", - ProtoMajor: 1, - ProtoMinor: 0, - Header: map[string][]string{"CSeq": []string{s.nextCSeq()}}, - Body: nil, - }) - if headerOp != nil { headerOp(req) } @@ -85,7 +75,7 @@ func (s *Session) writeRequest(method, urlStr string, headerOp func(*request), r } } - _, err = io.WriteString(s.conn, req.String()) + _, err = io.WriteString(s.conn, (*Request)(req).String()) if err != nil { return nil, err } @@ -106,9 +96,18 @@ func (s *Session) nextCSeq() string { return strconv.Itoa(s.cSeq) } -type request http.Request +type Request struct { + Method string + URL *url.URL + Proto string + ProtoMajor int + ProtoMinor int + Header http.Header + ContentLength int + Body io.ReadCloser +} -func (r *request) String() string { +func (r Request) String() string { s := fmt.Sprintf("%s %s %s/%d.%d\r\n", r.Method, r.URL, r.Proto, r.ProtoMajor, r.ProtoMinor) for k, v := range r.Header { for _, v := range v { @@ -123,36 +122,22 @@ func (r *request) String() string { return s } -type closer struct { - *bufio.Reader - r io.Reader -} +func NewRequest(method, urlStr, cSeq string, body io.ReadCloser) (*Request, error) { + u, err := url.Parse(urlStr) + if err != nil { + return nil, err + } -func (c closer) Close() error { - if c.Reader == nil { - return nil + req := &Request{ + Method: method, + URL: u, + Proto: "RTSP", + ProtoMajor: 1, + ProtoMinor: 0, + Header: map[string][]string{"CSeq": []string{cSeq}}, + Body: body, } - defer func() { - c.Reader = nil - c.r = nil - }() - if r, ok := c.r.(io.ReadCloser); ok { - return r.Close() - } - return nil -} - -func ParseRTSPVersion(s string) (proto string, major int, minor int, err error) { - parts := strings.SplitN(s, "/", 2) - proto = parts[0] - parts = strings.SplitN(parts[1], ".", 2) - if major, err = strconv.Atoi(parts[0]); err != nil { - return - } - if minor, err = strconv.Atoi(parts[0]); err != nil { - return - } - return + return req, nil } type Response struct { @@ -220,3 +205,35 @@ func ReadResponse(r io.Reader) (res *Response, err error) { res.Body = closer{b, r} return } + +type closer struct { + *bufio.Reader + r io.Reader +} + +func (c closer) Close() error { + if c.Reader == nil { + return nil + } + defer func() { + c.Reader = nil + c.r = nil + }() + if r, ok := c.r.(io.ReadCloser); ok { + return r.Close() + } + return nil +} + +func ParseRTSPVersion(s string) (proto string, major int, minor int, err error) { + parts := strings.SplitN(s, "/", 2) + proto = parts[0] + parts = strings.SplitN(parts[1], ".", 2) + if major, err = strconv.Atoi(parts[0]); err != nil { + return + } + if minor, err = strconv.Atoi(parts[0]); err != nil { + return + } + return +}