507 lines
13 KiB
Go
507 lines
13 KiB
Go
package ws
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"strconv"
|
||
|
||
"github.com/gobwas/httphead"
|
||
)
|
||
|
||
// Separators used when serializing HTTP header lines.
const (
	crlf          = "\r\n" // terminates the status line and every header line
	colonAndSpace = ": "   // placed between a header name and its value
	commaAndSpace = ", "   // placed between the items of a list-valued header
)
|
||
|
||
// textHeadUpgrade is the fixed beginning of a successful handshake response:
// the 101 status line followed by the Upgrade and Connection headers. The
// remaining headers and the terminating blank line are appended by the caller.
const (
	textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
)
|
||
|
||
var (
|
||
textHeadBadRequest = statusText(http.StatusBadRequest)
|
||
textHeadInternalServerError = statusText(http.StatusInternalServerError)
|
||
textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
|
||
|
||
textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
|
||
textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
|
||
textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
|
||
textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
|
||
textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
|
||
textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
|
||
textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
|
||
textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
|
||
textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
|
||
)
|
||
|
||
// Names of the headers involved in the handshake.
const (
	// Every new header must be added to TestHeaderNames test.

	// Spelling used when headers are written to the wire.
	headerHost          = "Host"
	headerUpgrade       = "Upgrade"
	headerConnection    = "Connection"
	headerSecVersion    = "Sec-WebSocket-Version"
	headerSecProtocol   = "Sec-WebSocket-Protocol"
	headerSecExtensions = "Sec-WebSocket-Extensions"
	headerSecKey        = "Sec-WebSocket-Key"
	headerSecAccept     = "Sec-WebSocket-Accept"

	// Canonical spelling of the same names: only the first letter of each
	// dash-separated word is upper case (note "Websocket", not "WebSocket").
	headerHostCanonical          = headerHost
	headerUpgradeCanonical       = headerUpgrade
	headerConnectionCanonical    = headerConnection
	headerSecVersionCanonical    = "Sec-Websocket-Version"
	headerSecProtocolCanonical   = "Sec-Websocket-Protocol"
	headerSecExtensionsCanonical = "Sec-Websocket-Extensions"
	headerSecKeyCanonical        = "Sec-Websocket-Key"
	headerSecAcceptCanonical     = "Sec-Websocket-Accept"
)
|
||
|
||
// Header values used by the handshake, kept as byte slices so they can be
// written and compared without converting from string.
var (
	specHeaderValueUpgrade         = []byte("websocket")
	specHeaderValueConnection      = []byte("Upgrade")
	specHeaderValueConnectionLower = []byte("upgrade")
	specHeaderValueSecVersion      = []byte("13")
)
|
||
|
||
// Protocol version tokens recognized by httpParseVersion.
var (
	httpVersion1_0    = []byte("HTTP/1.0")
	httpVersion1_1    = []byte("HTTP/1.1")
	httpVersionPrefix = []byte("HTTP/") // common prefix of every version token
)
|
||
|
||
// httpRequestLine holds the parts of an HTTP request line
// ("<method> <uri> HTTP/<major>.<minor>").
type httpRequestLine struct {
	method, uri  []byte
	major, minor int
}
|
||
|
||
// httpResponseLine holds the parts of an HTTP status line
// ("HTTP/<major>.<minor> <status> <reason>").
type httpResponseLine struct {
	major, minor int
	status       int
	reason       []byte
}
|
||
|
||
// httpParseRequestLine parses http request line like "GET / HTTP/1.0".
|
||
func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
|
||
var proto []byte
|
||
req.method, req.uri, proto = bsplit3(line, ' ')
|
||
|
||
var ok bool
|
||
req.major, req.minor, ok = httpParseVersion(proto)
|
||
if !ok {
|
||
err = ErrMalformedRequest
|
||
}
|
||
return req, err
|
||
}
|
||
|
||
func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
|
||
var (
|
||
proto []byte
|
||
status []byte
|
||
)
|
||
proto, status, resp.reason = bsplit3(line, ' ')
|
||
|
||
var ok bool
|
||
resp.major, resp.minor, ok = httpParseVersion(proto)
|
||
if !ok {
|
||
return resp, ErrMalformedResponse
|
||
}
|
||
|
||
var convErr error
|
||
resp.status, convErr = asciiToInt(status)
|
||
if convErr != nil {
|
||
return resp, ErrMalformedResponse
|
||
}
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// httpParseVersion parses major and minor version of HTTP protocol. It returns
|
||
// parsed values and true if parse is ok.
|
||
func httpParseVersion(bts []byte) (major, minor int, ok bool) {
|
||
switch {
|
||
case bytes.Equal(bts, httpVersion1_0):
|
||
return 1, 0, true
|
||
case bytes.Equal(bts, httpVersion1_1):
|
||
return 1, 1, true
|
||
case len(bts) < 8:
|
||
return 0, 0, false
|
||
case !bytes.Equal(bts[:5], httpVersionPrefix):
|
||
return 0, 0, false
|
||
}
|
||
|
||
bts = bts[5:]
|
||
|
||
dot := bytes.IndexByte(bts, '.')
|
||
if dot == -1 {
|
||
return 0, 0, false
|
||
}
|
||
var err error
|
||
major, err = asciiToInt(bts[:dot])
|
||
if err != nil {
|
||
return major, 0, false
|
||
}
|
||
minor, err = asciiToInt(bts[dot+1:])
|
||
if err != nil {
|
||
return major, minor, false
|
||
}
|
||
|
||
return major, minor, true
|
||
}
|
||
|
||
// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
|
||
// values and true if parse is ok.
|
||
func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
|
||
colon := bytes.IndexByte(line, ':')
|
||
if colon == -1 {
|
||
return nil, nil, false
|
||
}
|
||
|
||
k = btrim(line[:colon])
|
||
// TODO(gobwas): maybe use just lower here?
|
||
canonicalizeHeaderKey(k)
|
||
|
||
v = btrim(line[colon+1:])
|
||
|
||
return k, v, true
|
||
}
|
||
|
||
// httpGetHeader is the same as textproto.MIMEHeader.Get, except that key must
// already be in canonical form: it is used for the lookup as is, which saves
// the canonicalization step.
//
// It returns the first value stored under key, or "" when h is nil or has no
// value for key.
func httpGetHeader(h http.Header, key string) string {
	// Indexing a nil map yields a nil slice, so h needs no separate nil check.
	if v := h[key]; len(v) > 0 {
		return v[0]
	}
	return ""
}
|
||
|
||
// The request MAY include a header field with the name
|
||
// |Sec-WebSocket-Protocol|. If present, this value indicates one or more
|
||
// comma-separated subprotocol the client wishes to speak, ordered by
|
||
// preference. The elements that comprise this value MUST be non-empty strings
|
||
// with characters in the range U+0021 to U+007E not including separator
|
||
// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
|
||
// for the value of this header field is 1#token, where the definitions of
|
||
// constructs and rules are as given in [RFC2616].
|
||
func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
|
||
ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
|
||
if check(btsToString(v)) {
|
||
ret = string(v)
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
return ret, ok
|
||
}
|
||
|
||
func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
|
||
var selected []byte
|
||
ok = httphead.ScanTokens(h, func(v []byte) bool {
|
||
if check(v) {
|
||
selected = v
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
if ok && selected != nil {
|
||
return string(selected), true
|
||
}
|
||
return ret, ok
|
||
}
|
||
|
||
func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
|
||
s := httphead.OptionSelector{
|
||
Flags: httphead.SelectCopy,
|
||
Check: check,
|
||
}
|
||
return s.Select(h, selected)
|
||
}
|
||
|
||
func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
|
||
if in.Size() == 0 {
|
||
return dest, nil
|
||
}
|
||
opt, err := f(in)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if opt.Size() > 0 {
|
||
dest = append(dest, opt)
|
||
}
|
||
return dest, nil
|
||
}
|
||
|
||
func negotiateExtensions(
|
||
h []byte, dest []httphead.Option,
|
||
f func(httphead.Option) (httphead.Option, error),
|
||
) (_ []httphead.Option, err error) {
|
||
index := -1
|
||
var current httphead.Option
|
||
ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
|
||
if i != index {
|
||
dest, err = negotiateMaybe(current, dest, f)
|
||
if err != nil {
|
||
return httphead.ControlBreak
|
||
}
|
||
index = i
|
||
current = httphead.Option{Name: name}
|
||
}
|
||
if attr != nil {
|
||
current.Parameters.Set(attr, val)
|
||
}
|
||
return httphead.ControlContinue
|
||
})
|
||
if !ok {
|
||
return nil, ErrMalformedRequest
|
||
}
|
||
return negotiateMaybe(current, dest, f)
|
||
}
|
||
|
||
func httpWriteHeader(bw *bufio.Writer, key, value string) {
|
||
httpWriteHeaderKey(bw, key)
|
||
bw.WriteString(value)
|
||
bw.WriteString(crlf)
|
||
}
|
||
|
||
func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
|
||
httpWriteHeaderKey(bw, key)
|
||
bw.Write(value)
|
||
bw.WriteString(crlf)
|
||
}
|
||
|
||
func httpWriteHeaderKey(bw *bufio.Writer, key string) {
|
||
bw.WriteString(key)
|
||
bw.WriteString(colonAndSpace)
|
||
}
|
||
|
||
func httpWriteUpgradeRequest(
|
||
bw *bufio.Writer,
|
||
u *url.URL,
|
||
nonce []byte,
|
||
protocols []string,
|
||
extensions []httphead.Option,
|
||
header HandshakeHeader,
|
||
host string,
|
||
) {
|
||
bw.WriteString("GET ")
|
||
bw.WriteString(u.RequestURI())
|
||
bw.WriteString(" HTTP/1.1\r\n")
|
||
|
||
if host == "" {
|
||
host = u.Host
|
||
}
|
||
httpWriteHeader(bw, headerHost, host)
|
||
|
||
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
|
||
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
|
||
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
|
||
|
||
// NOTE: write nonce bytes as a string to prevent heap allocation –
|
||
// WriteString() copy given string into its inner buffer, unlike Write()
|
||
// which may write p directly to the underlying io.Writer – which in turn
|
||
// will lead to p escape.
|
||
httpWriteHeader(bw, headerSecKey, btsToString(nonce))
|
||
|
||
if len(protocols) > 0 {
|
||
httpWriteHeaderKey(bw, headerSecProtocol)
|
||
for i, p := range protocols {
|
||
if i > 0 {
|
||
bw.WriteString(commaAndSpace)
|
||
}
|
||
bw.WriteString(p)
|
||
}
|
||
bw.WriteString(crlf)
|
||
}
|
||
|
||
if len(extensions) > 0 {
|
||
httpWriteHeaderKey(bw, headerSecExtensions)
|
||
httphead.WriteOptions(bw, extensions)
|
||
bw.WriteString(crlf)
|
||
}
|
||
|
||
if header != nil {
|
||
header.WriteTo(bw)
|
||
}
|
||
|
||
bw.WriteString(crlf)
|
||
}
|
||
|
||
func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
|
||
bw.WriteString(textHeadUpgrade)
|
||
|
||
httpWriteHeaderKey(bw, headerSecAccept)
|
||
writeAccept(bw, nonce)
|
||
bw.WriteString(crlf)
|
||
|
||
if hs.Protocol != "" {
|
||
httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
|
||
}
|
||
if len(hs.Extensions) > 0 {
|
||
httpWriteHeaderKey(bw, headerSecExtensions)
|
||
httphead.WriteOptions(bw, hs.Extensions)
|
||
bw.WriteString(crlf)
|
||
}
|
||
if header != nil {
|
||
header(bw)
|
||
}
|
||
|
||
bw.WriteString(crlf)
|
||
}
|
||
|
||
func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
|
||
switch code {
|
||
case http.StatusBadRequest:
|
||
bw.WriteString(textHeadBadRequest)
|
||
case http.StatusInternalServerError:
|
||
bw.WriteString(textHeadInternalServerError)
|
||
case http.StatusUpgradeRequired:
|
||
bw.WriteString(textHeadUpgradeRequired)
|
||
default:
|
||
writeStatusText(bw, code)
|
||
}
|
||
|
||
// Write custom headers.
|
||
if header != nil {
|
||
header(bw)
|
||
}
|
||
|
||
switch err {
|
||
case ErrHandshakeBadProtocol:
|
||
bw.WriteString(textTailErrHandshakeBadProtocol)
|
||
case ErrHandshakeBadMethod:
|
||
bw.WriteString(textTailErrHandshakeBadMethod)
|
||
case ErrHandshakeBadHost:
|
||
bw.WriteString(textTailErrHandshakeBadHost)
|
||
case ErrHandshakeBadUpgrade:
|
||
bw.WriteString(textTailErrHandshakeBadUpgrade)
|
||
case ErrHandshakeBadConnection:
|
||
bw.WriteString(textTailErrHandshakeBadConnection)
|
||
case ErrHandshakeBadSecAccept:
|
||
bw.WriteString(textTailErrHandshakeBadSecAccept)
|
||
case ErrHandshakeBadSecKey:
|
||
bw.WriteString(textTailErrHandshakeBadSecKey)
|
||
case ErrHandshakeBadSecVersion:
|
||
bw.WriteString(textTailErrHandshakeBadSecVersion)
|
||
case ErrHandshakeUpgradeRequired:
|
||
bw.WriteString(textTailErrUpgradeRequired)
|
||
case nil:
|
||
bw.WriteString(crlf)
|
||
default:
|
||
writeErrorText(bw, err)
|
||
}
|
||
}
|
||
|
||
func writeStatusText(bw *bufio.Writer, code int) {
|
||
bw.WriteString("HTTP/1.1 ")
|
||
bw.WriteString(strconv.Itoa(code))
|
||
bw.WriteByte(' ')
|
||
bw.WriteString(http.StatusText(code))
|
||
bw.WriteString(crlf)
|
||
bw.WriteString("Content-Type: text/plain; charset=utf-8")
|
||
bw.WriteString(crlf)
|
||
}
|
||
|
||
func writeErrorText(bw *bufio.Writer, err error) {
|
||
body := err.Error()
|
||
bw.WriteString("Content-Length: ")
|
||
bw.WriteString(strconv.Itoa(len(body)))
|
||
bw.WriteString(crlf)
|
||
bw.WriteString(crlf)
|
||
bw.WriteString(body)
|
||
}
|
||
|
||
// httpError is like the http.Error with WebSocket context exception.
//
// It replies through w with the given status code and body as plain text,
// setting the Content-Type and Content-Length headers before the status code
// is written.
func httpError(w http.ResponseWriter, body string, code int) {
	hdr := w.Header()
	hdr.Set("Content-Type", "text/plain; charset=utf-8")
	hdr.Set("Content-Length", strconv.Itoa(len(body)))

	w.WriteHeader(code)
	// A failed write cannot be reported to the client at this point, so the
	// result is intentionally discarded.
	_, _ = w.Write([]byte(body))
}
|
||
|
||
// statusText is a non-performant status text generator.
|
||
// NOTE: Used only to generate constants.
|
||
func statusText(code int) string {
|
||
var buf bytes.Buffer
|
||
bw := bufio.NewWriter(&buf)
|
||
writeStatusText(bw, code)
|
||
bw.Flush()
|
||
return buf.String()
|
||
}
|
||
|
||
// errorText is a non-performant error text generator.
|
||
// NOTE: Used only to generate constants.
|
||
func errorText(err error) string {
|
||
var buf bytes.Buffer
|
||
bw := bufio.NewWriter(&buf)
|
||
writeErrorText(bw, err)
|
||
bw.Flush()
|
||
return buf.String()
|
||
}
|
||
|
||
// HandshakeHeader is the interface that writes both upgrade request or
// response headers into a given io.Writer.
//
// It is satisfied by any io.WriterTo.
type HandshakeHeader interface {
	io.WriterTo
}
|
||
|
||
// HandshakeHeaderString is an adapter to allow the use of headers represented
// by ordinary string as HandshakeHeader.
//
// The string is written as is, so it has to consist of complete header lines.
type HandshakeHeaderString string

// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
// It writes the string to w and reports the number of bytes written.
func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
	n, err := io.WriteString(w, string(s))
	return int64(n), err
}
|
||
|
||
// HandshakeHeaderBytes is an adapter to allow the use of headers represented
// by ordinary slice of bytes as HandshakeHeader.
//
// The bytes are written as is, so they have to consist of complete header
// lines.
type HandshakeHeaderBytes []byte

// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
// It writes the bytes to w and reports the number of bytes written.
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
	n, err := w.Write(b)
	return int64(n), err
}
|
||
|
||
// HandshakeHeaderFunc is an adapter to allow the use of headers represented by
// ordinary function as HandshakeHeader.
//
// The function receives the destination writer and reports the number of
// bytes it wrote together with any error.
type HandshakeHeaderFunc func(io.Writer) (int64, error)

// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
// It calls f(w) and returns its results unchanged; f must not be nil.
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
	return f(w)
}
|
||
|
||
// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
|
||
// HandshakeHeader.
|
||
type HandshakeHeaderHTTP http.Header
|
||
|
||
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
|
||
wr := writer{w: w}
|
||
err := http.Header(h).Write(&wr)
|
||
return wr.n, err
|
||
}
|
||
|
||
// writer wraps an io.Writer and counts the bytes written through it.
type writer struct {
	n int64     // total number of bytes written so far
	w io.Writer // destination of all writes
}

// WriteString writes s to the wrapped writer via io.WriteString and adds the
// number of bytes written to the running total.
func (w *writer) WriteString(s string) (int, error) {
	n, err := io.WriteString(w.w, s)
	w.n += int64(n)
	return n, err
}

// Write writes p to the wrapped writer and adds the number of bytes written to
// the running total.
func (w *writer) Write(p []byte) (int, error) {
	n, err := w.w.Write(p)
	w.n += int64(n)
	return n, err
}
|