374 lines
9.5 KiB
Go
374 lines
9.5 KiB
Go
|
package wsutil
|
|||
|
|
|||
|
import (
|
|||
|
"encoding/binary"
|
|||
|
"errors"
|
|||
|
"io"
|
|||
|
"io/ioutil"
|
|||
|
|
|||
|
"github.com/gobwas/ws"
|
|||
|
)
|
|||
|
|
|||
|
var (
	// ErrNoFrameAdvance means that Reader's Read() method was called without
	// a preceding NextFrame() call.
	ErrNoFrameAdvance = errors.New("no frame advance")

	// ErrFrameTooLarge indicates that a frame whose payload length exceeds
	// Reader.MaxFrameSize was about to be read.
	ErrFrameTooLarge = errors.New("frame too large")
)
|
|||
|
|
|||
|
// FrameHandlerFunc handles parsed frame header and its body represented by
|
|||
|
// io.Reader.
|
|||
|
//
|
|||
|
// Note that reader represents already unmasked body.
|
|||
|
type FrameHandlerFunc func(ws.Header, io.Reader) error
|
|||
|
|
|||
|
// Reader is a wrapper around source io.Reader which represents WebSocket
|
|||
|
// connection. It contains options for reading messages from source.
|
|||
|
//
|
|||
|
// Reader implements io.Reader, which Read() method reads payload of incoming
|
|||
|
// WebSocket frames. It also takes care on fragmented frames and possibly
|
|||
|
// intermediate control frames between them.
|
|||
|
//
|
|||
|
// Note that Reader's methods are not goroutine safe.
|
|||
|
type Reader struct {
|
|||
|
Source io.Reader
|
|||
|
State ws.State
|
|||
|
|
|||
|
// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
|
|||
|
SkipHeaderCheck bool
|
|||
|
|
|||
|
// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
|
|||
|
// bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
|
|||
|
CheckUTF8 bool
|
|||
|
|
|||
|
// Extensions is a list of negotiated extensions for reader Source.
|
|||
|
// It is used to meet the specs and clear appropriate bits in fragment
|
|||
|
// header RSV segment.
|
|||
|
Extensions []RecvExtension
|
|||
|
|
|||
|
// MaxFrameSize controls the maximum frame size in bytes
|
|||
|
// that can be read. A message exceeding that size will return
|
|||
|
// a ErrFrameTooLarge to the application.
|
|||
|
//
|
|||
|
// Not setting this field means there is no limit.
|
|||
|
MaxFrameSize int64
|
|||
|
|
|||
|
OnContinuation FrameHandlerFunc
|
|||
|
OnIntermediate FrameHandlerFunc
|
|||
|
|
|||
|
opCode ws.OpCode // Used to store message op code on fragmentation.
|
|||
|
frame io.Reader // Used to as frame reader.
|
|||
|
raw io.LimitedReader // Used to discard frames without cipher.
|
|||
|
utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
|
|||
|
tmp [ws.MaxHeaderSize - 2]byte // Used for reading headers.
|
|||
|
cr *CipherReader // Used by NextFrame() to unmask frame payload.
|
|||
|
}
|
|||
|
|
|||
|
// NewReader creates new frame reader that reads from r keeping given state to
|
|||
|
// make some protocol validity checks when it needed.
|
|||
|
func NewReader(r io.Reader, s ws.State) *Reader {
|
|||
|
return &Reader{
|
|||
|
Source: r,
|
|||
|
State: s,
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// NewClientSideReader is a helper function that calls NewReader with r and
|
|||
|
// ws.StateClientSide.
|
|||
|
func NewClientSideReader(r io.Reader) *Reader {
|
|||
|
return NewReader(r, ws.StateClientSide)
|
|||
|
}
|
|||
|
|
|||
|
// NewServerSideReader is a helper function that calls NewReader with r and
|
|||
|
// ws.StateServerSide.
|
|||
|
func NewServerSideReader(r io.Reader) *Reader {
|
|||
|
return NewReader(r, ws.StateServerSide)
|
|||
|
}
|
|||
|
|
|||
|
// Read implements io.Reader. It reads the next message payload into p.
|
|||
|
// It takes care on fragmented messages.
|
|||
|
//
|
|||
|
// The error is io.EOF only if all of message bytes were read.
|
|||
|
// If an io.EOF happens during reading some but not all the message bytes
|
|||
|
// Read() returns io.ErrUnexpectedEOF.
|
|||
|
//
|
|||
|
// The error is ErrNoFrameAdvance if no NextFrame() call was made before
|
|||
|
// reading next message bytes.
|
|||
|
func (r *Reader) Read(p []byte) (n int, err error) {
|
|||
|
if r.frame == nil {
|
|||
|
if !r.fragmented() {
|
|||
|
// Every new Read() must be preceded by NextFrame() call.
|
|||
|
return 0, ErrNoFrameAdvance
|
|||
|
}
|
|||
|
// Read next continuation or intermediate control frame.
|
|||
|
_, err := r.NextFrame()
|
|||
|
if err != nil {
|
|||
|
return 0, err
|
|||
|
}
|
|||
|
if r.frame == nil {
|
|||
|
// We handled intermediate control and now got nothing to read.
|
|||
|
return 0, nil
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
n, err = r.frame.Read(p)
|
|||
|
if err != nil && err != io.EOF {
|
|||
|
return n, err
|
|||
|
}
|
|||
|
if err == nil && r.raw.N != 0 {
|
|||
|
return n, nil
|
|||
|
}
|
|||
|
|
|||
|
// EOF condition (either err is io.EOF or r.raw.N is zero).
|
|||
|
switch {
|
|||
|
case r.raw.N != 0:
|
|||
|
err = io.ErrUnexpectedEOF
|
|||
|
|
|||
|
case r.fragmented():
|
|||
|
err = nil
|
|||
|
r.resetFragment()
|
|||
|
|
|||
|
case r.CheckUTF8 && !r.utf8.Valid():
|
|||
|
// NOTE: check utf8 only when full message received, since partial
|
|||
|
// reads may be invalid.
|
|||
|
n = r.utf8.Accepted()
|
|||
|
err = ErrInvalidUTF8
|
|||
|
|
|||
|
default:
|
|||
|
r.reset()
|
|||
|
err = io.EOF
|
|||
|
}
|
|||
|
|
|||
|
return n, err
|
|||
|
}
|
|||
|
|
|||
|
// Discard discards current message unread bytes.
|
|||
|
// It discards all frames of fragmented message.
|
|||
|
func (r *Reader) Discard() (err error) {
|
|||
|
for {
|
|||
|
_, err = io.Copy(ioutil.Discard, &r.raw)
|
|||
|
if err != nil {
|
|||
|
break
|
|||
|
}
|
|||
|
if !r.fragmented() {
|
|||
|
break
|
|||
|
}
|
|||
|
if _, err = r.NextFrame(); err != nil {
|
|||
|
break
|
|||
|
}
|
|||
|
}
|
|||
|
r.reset()
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
// NextFrame prepares r to read next message. It returns received frame header
|
|||
|
// and non-nil error on failure.
|
|||
|
//
|
|||
|
// Note that next NextFrame() call must be done after receiving or discarding
|
|||
|
// all current message bytes.
|
|||
|
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
|
|||
|
hdr, err = r.readHeader(r.Source)
|
|||
|
if err == io.EOF && r.fragmented() {
|
|||
|
// If we are in fragmented state EOF means that is was totally
|
|||
|
// unexpected.
|
|||
|
//
|
|||
|
// NOTE: This is necessary to prevent callers such that
|
|||
|
// ioutil.ReadAll to receive some amount of bytes without an error.
|
|||
|
// ReadAll() ignores an io.EOF error, thus caller may think that
|
|||
|
// whole message fetched, but actually only part of it.
|
|||
|
err = io.ErrUnexpectedEOF
|
|||
|
}
|
|||
|
if err == nil && !r.SkipHeaderCheck {
|
|||
|
err = ws.CheckHeader(hdr, r.State)
|
|||
|
}
|
|||
|
if err != nil {
|
|||
|
return hdr, err
|
|||
|
}
|
|||
|
|
|||
|
if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
|
|||
|
return hdr, ErrFrameTooLarge
|
|||
|
}
|
|||
|
|
|||
|
// Save raw reader to use it on discarding frame without ciphering and
|
|||
|
// other streaming checks.
|
|||
|
r.raw = io.LimitedReader{
|
|||
|
R: r.Source,
|
|||
|
N: hdr.Length,
|
|||
|
}
|
|||
|
|
|||
|
frame := io.Reader(&r.raw)
|
|||
|
if hdr.Masked {
|
|||
|
if r.cr == nil {
|
|||
|
r.cr = NewCipherReader(frame, hdr.Mask)
|
|||
|
} else {
|
|||
|
r.cr.Reset(frame, hdr.Mask)
|
|||
|
}
|
|||
|
frame = r.cr
|
|||
|
}
|
|||
|
|
|||
|
for _, x := range r.Extensions {
|
|||
|
hdr, err = x.UnsetBits(hdr)
|
|||
|
if err != nil {
|
|||
|
return hdr, err
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
if r.fragmented() {
|
|||
|
if hdr.OpCode.IsControl() {
|
|||
|
if cb := r.OnIntermediate; cb != nil {
|
|||
|
err = cb(hdr, frame)
|
|||
|
}
|
|||
|
if err == nil {
|
|||
|
// Ensure that src is empty.
|
|||
|
_, err = io.Copy(ioutil.Discard, &r.raw)
|
|||
|
}
|
|||
|
return hdr, err
|
|||
|
}
|
|||
|
} else {
|
|||
|
r.opCode = hdr.OpCode
|
|||
|
}
|
|||
|
if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
|
|||
|
r.utf8.Source = frame
|
|||
|
frame = &r.utf8
|
|||
|
}
|
|||
|
|
|||
|
// Save reader with ciphering and other streaming checks.
|
|||
|
r.frame = frame
|
|||
|
|
|||
|
if hdr.OpCode == ws.OpContinuation {
|
|||
|
if cb := r.OnContinuation; cb != nil {
|
|||
|
err = cb(hdr, frame)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
if hdr.Fin {
|
|||
|
r.State = r.State.Clear(ws.StateFragmented)
|
|||
|
} else {
|
|||
|
r.State = r.State.Set(ws.StateFragmented)
|
|||
|
}
|
|||
|
|
|||
|
return hdr, err
|
|||
|
}
|
|||
|
|
|||
|
func (r *Reader) fragmented() bool {
|
|||
|
return r.State.Fragmented()
|
|||
|
}
|
|||
|
|
|||
|
func (r *Reader) resetFragment() {
|
|||
|
r.raw = io.LimitedReader{}
|
|||
|
r.frame = nil
|
|||
|
// Reset source of the UTF8Reader, but not the state.
|
|||
|
r.utf8.Source = nil
|
|||
|
}
|
|||
|
|
|||
|
func (r *Reader) reset() {
|
|||
|
r.raw = io.LimitedReader{}
|
|||
|
r.frame = nil
|
|||
|
r.utf8 = UTF8Reader{}
|
|||
|
r.opCode = 0
|
|||
|
}
|
|||
|
|
|||
|
// readHeader reads a frame header from in.
|
|||
|
func (r *Reader) readHeader(in io.Reader) (h ws.Header, err error) {
|
|||
|
// Make slice of bytes with capacity 12 that could hold any header.
|
|||
|
//
|
|||
|
// The maximum header size is 14, but due to the 2 hop reads,
|
|||
|
// after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
|
|||
|
// So 14 - 2 = 12.
|
|||
|
bts := r.tmp[:2]
|
|||
|
|
|||
|
// Prepare to hold first 2 bytes to choose size of next read.
|
|||
|
_, err = io.ReadFull(in, bts)
|
|||
|
if err != nil {
|
|||
|
return h, err
|
|||
|
}
|
|||
|
const bit0 = 0x80
|
|||
|
|
|||
|
h.Fin = bts[0]&bit0 != 0
|
|||
|
h.Rsv = (bts[0] & 0x70) >> 4
|
|||
|
h.OpCode = ws.OpCode(bts[0] & 0x0f)
|
|||
|
|
|||
|
var extra int
|
|||
|
|
|||
|
if bts[1]&bit0 != 0 {
|
|||
|
h.Masked = true
|
|||
|
extra += 4
|
|||
|
}
|
|||
|
|
|||
|
length := bts[1] & 0x7f
|
|||
|
switch {
|
|||
|
case length < 126:
|
|||
|
h.Length = int64(length)
|
|||
|
|
|||
|
case length == 126:
|
|||
|
extra += 2
|
|||
|
|
|||
|
case length == 127:
|
|||
|
extra += 8
|
|||
|
|
|||
|
default:
|
|||
|
err = ws.ErrHeaderLengthUnexpected
|
|||
|
return h, err
|
|||
|
}
|
|||
|
|
|||
|
if extra == 0 {
|
|||
|
return h, err
|
|||
|
}
|
|||
|
|
|||
|
// Increase len of bts to extra bytes need to read.
|
|||
|
// Overwrite first 2 bytes that was read before.
|
|||
|
bts = bts[:extra]
|
|||
|
_, err = io.ReadFull(in, bts)
|
|||
|
if err != nil {
|
|||
|
return h, err
|
|||
|
}
|
|||
|
|
|||
|
switch {
|
|||
|
case length == 126:
|
|||
|
h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
|
|||
|
bts = bts[2:]
|
|||
|
|
|||
|
case length == 127:
|
|||
|
if bts[0]&0x80 != 0 {
|
|||
|
err = ws.ErrHeaderLengthMSB
|
|||
|
return h, err
|
|||
|
}
|
|||
|
h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
|
|||
|
bts = bts[8:]
|
|||
|
}
|
|||
|
|
|||
|
if h.Masked {
|
|||
|
copy(h.Mask[:], bts)
|
|||
|
}
|
|||
|
|
|||
|
return h, nil
|
|||
|
}
|
|||
|
|
|||
|
// NextReader prepares next message read from r. It returns header that
|
|||
|
// describes the message and io.Reader to read message's payload. It returns
|
|||
|
// non-nil error when it is not possible to read message's initial frame.
|
|||
|
//
|
|||
|
// Note that next NextReader() on the same r should be done after reading all
|
|||
|
// bytes from previously returned io.Reader. For more performant way to discard
|
|||
|
// message use Reader and its Discard() method.
|
|||
|
//
|
|||
|
// Note that it will not handle any "intermediate" frames, that possibly could
|
|||
|
// be received between text/binary continuation frames. That is, if peer sent
|
|||
|
// text/binary frame with fin flag "false", then it could send ping frame, and
|
|||
|
// eventually remaining part of text/binary frame with fin "true" – with
|
|||
|
// NextReader() the ping frame will be dropped without any notice. To handle
|
|||
|
// this rare, but possible situation (and if you do not know exactly which
|
|||
|
// frames peer could send), you could use Reader with OnIntermediate field set.
|
|||
|
func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
|
|||
|
rd := &Reader{
|
|||
|
Source: r,
|
|||
|
State: s,
|
|||
|
}
|
|||
|
header, err := rd.NextFrame()
|
|||
|
if err != nil {
|
|||
|
return header, nil, err
|
|||
|
}
|
|||
|
return header, rd, nil
|
|||
|
}
|