well-goknown/vendor/github.com/gobwas/ws/wsutil/reader.go

373 lines
9.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package wsutil
import (
"encoding/binary"
"errors"
"io"
"io/ioutil"
"github.com/gobwas/ws"
)
var (
	// ErrNoFrameAdvance is returned by Reader.Read when it is called without
	// a preceding NextFrame() call that starts a message.
	ErrNoFrameAdvance = errors.New("no frame advance")

	// ErrFrameTooLarge is returned by Reader.NextFrame when a frame's payload
	// length is greater than a positive Reader.MaxFrameSize.
	ErrFrameTooLarge = errors.New("frame too large")
)
// FrameHandlerFunc is a callback that receives a parsed frame header together
// with a reader over that frame's payload.
//
// The payload reader yields already unmasked bytes.
type FrameHandlerFunc func(hdr ws.Header, payload io.Reader) error
// Reader is a wrapper around source io.Reader which represents WebSocket
// connection. It contains options for reading messages from source.
//
// Reader implements io.Reader, whose Read() method reads the payload of
// incoming WebSocket frames. It also takes care of fragmented messages and
// of control frames that may arrive between their fragments.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
	// Source is the underlying stream that frame headers and payloads are
	// read from.
	Source io.Reader
	// State is passed to ws.CheckHeader for validity checks. NextFrame also
	// sets or clears ws.StateFragmented in it depending on the FIN bit of
	// each received data frame.
	State ws.State

	// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
	SkipHeaderCheck bool

	// CheckUTF8 enables UTF-8 checks for text frames payload. If the complete
	// message is not a valid UTF-8 sequence, Read returns ErrInvalidUTF8.
	CheckUTF8 bool

	// Extensions is a list of negotiated extensions for reader Source.
	// It is used to meet the specs and clear appropriate bits in fragment
	// header RSV segment.
	Extensions []RecvExtension

	// MaxFrameSize controls the maximum frame size in bytes
	// that can be read. A frame exceeding that size makes NextFrame return
	// ErrFrameTooLarge to the application.
	//
	// Not setting this field means there is no limit.
	MaxFrameSize int64

	// OnContinuation, if non-nil, is called by NextFrame for every
	// continuation frame with its header and its payload reader (unmasked,
	// and wrapped by the UTF-8 checker when CheckUTF8 applies). Its error is
	// returned from NextFrame.
	OnContinuation FrameHandlerFunc
	// OnIntermediate, if non-nil, is called by NextFrame for control frames
	// received while a fragmented message is in progress. If it returns nil,
	// any payload bytes it did not consume are discarded. If it is nil, the
	// control frame's payload is discarded silently.
	OnIntermediate FrameHandlerFunc

	// opCode is the op code of the first frame of the current message; it
	// is used to recognize text continuation frames.
	opCode ws.OpCode
	// frame reads the current frame's payload, with unmasking and UTF-8
	// checks applied. It is nil when no frame is being read.
	frame io.Reader
	// raw limits reads to the current frame's payload; it is drained
	// directly (without unmasking) when discarding frames.
	raw io.LimitedReader
	// utf8 checks UTF-8 sequences if CheckUTF8 is true. Its state is kept
	// across the fragments of one message.
	utf8 UTF8Reader
	// tmp is scratch space for readHeader. The first 2 header bytes are
	// decoded before the remaining (at most ws.MaxHeaderSize-2) bytes are
	// read into it.
	tmp [ws.MaxHeaderSize - 2]byte
	// cr is used by NextFrame() to unmask frame payload; it is created on
	// first use and reused afterwards.
	cr *CipherReader
}
// NewReader returns a Reader that reads frames from r and starts in state s.
// The state drives the protocol validity checks performed on every received
// frame header and tracks message fragmentation.
func NewReader(r io.Reader, s ws.State) *Reader {
	rd := new(Reader)
	rd.Source = r
	rd.State = s
	return rd
}
// NewClientSideReader returns a Reader for the client end of a connection.
// It is equivalent to NewReader(r, ws.StateClientSide).
func NewClientSideReader(r io.Reader) *Reader {
	side := ws.StateClientSide
	return NewReader(r, side)
}
// NewServerSideReader returns a Reader for the server end of a connection.
// It is equivalent to NewReader(r, ws.StateServerSide).
func NewServerSideReader(r io.Reader) *Reader {
	side := ws.StateServerSide
	return NewReader(r, side)
}
// Read implements io.Reader by reading the payload of the current message
// into p. Fragmented messages are handled transparently: once a frame of a
// fragmented message is exhausted, the next Read fetches the following
// continuation (or intermediate control) frame by itself.
//
// A message must be started with NextFrame(); calling Read without it
// returns ErrNoFrameAdvance.
//
// io.EOF is returned only after every byte of the message has been read.
// If the source ends before the current frame is complete, Read returns
// io.ErrUnexpectedEOF. When CheckUTF8 is enabled and the complete message
// fails UTF-8 validation, ErrInvalidUTF8 is returned with n replaced by the
// validator's accepted byte count.
//
// Read may return (0, nil) right after consuming an intermediate control
// frame.
func (r *Reader) Read(p []byte) (n int, err error) {
	if r.frame == nil {
		// Only a message that is already in progress may advance to its
		// next frame implicitly.
		if !r.fragmented() {
			return 0, ErrNoFrameAdvance
		}
		if _, err := r.NextFrame(); err != nil {
			return 0, err
		}
		if r.frame == nil {
			// An intermediate control frame was consumed; nothing to hand
			// out on this call.
			return 0, nil
		}
	}

	n, err = r.frame.Read(p)
	switch {
	case err != nil && err != io.EOF:
		return n, err
	case err == nil && r.raw.N != 0:
		// The current frame still has payload left.
		return n, nil
	}

	// From here on the current frame reader is exhausted.
	if r.raw.N != 0 {
		// Source ended before the frame's declared length was reached.
		return n, io.ErrUnexpectedEOF
	}
	if r.fragmented() {
		// More frames follow: drop the per-frame readers but keep the UTF-8
		// state so validation spans the whole message.
		r.resetFragment()
		return n, nil
	}
	if r.CheckUTF8 && !r.utf8.Valid() {
		// Validity is judged only on the complete message, because a
		// partial read may legitimately stop in the middle of a sequence.
		return r.utf8.Accepted(), ErrInvalidUTF8
	}
	r.reset()
	return n, io.EOF
}
// Discard skips the unread remainder of the current message. For a
// fragmented message it also consumes every remaining frame. Payload is
// drained from the raw frame reader, so no unmasking or UTF-8 checks run.
//
// r is reset afterwards whether or not an error occurred; the first error
// encountered, if any, is returned.
func (r *Reader) Discard() (err error) {
	for err == nil {
		if _, err = io.Copy(ioutil.Discard, &r.raw); err != nil {
			break
		}
		if !r.fragmented() {
			break
		}
		_, err = r.NextFrame()
	}
	r.reset()
	return err
}
// NextFrame reads the next frame header from r.Source and prepares r to read
// that frame's payload. It returns the received header (after every entry in
// Extensions has had a chance to clear its RSV bits) and a non-nil error on
// failure.
//
// Possible errors come from reading the header (io.EOF is turned into
// io.ErrUnexpectedEOF while a fragmented message is in progress), from
// ws.CheckHeader unless SkipHeaderCheck is set, ErrFrameTooLarge when a
// positive MaxFrameSize is exceeded, an extension's UnsetBits, or the
// OnIntermediate / OnContinuation callbacks.
//
// A control frame received while a fragmented message is in progress is
// passed to OnIntermediate (if set) and its remaining payload is discarded.
// In that case r.frame and the fragmentation state are left unchanged.
//
// Note that the next NextFrame() call must only be made after all bytes of
// the current message have been read or discarded: NextFrame replaces the
// raw payload reader, so leftover payload bytes would be parsed as a header.
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
	hdr, err = r.readHeader(r.Source)
	if err == io.EOF && r.fragmented() {
		// If we are in fragmented state EOF means that it was totally
		// unexpected.
		//
		// NOTE: This is necessary to prevent callers such as
		// ioutil.ReadAll from receiving some amount of bytes without an error.
		// ReadAll() ignores an io.EOF error, thus caller may think that the
		// whole message was fetched, but actually only part of it was.
		err = io.ErrUnexpectedEOF
	}
	if err == nil && !r.SkipHeaderCheck {
		err = ws.CheckHeader(hdr, r.State)
	}
	if err != nil {
		return hdr, err
	}

	// NOTE(review): this returns before r.raw is replaced, so the oversized
	// payload is left unread in Source; the caller must not keep reading
	// frames from this stream.
	if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
		return hdr, ErrFrameTooLarge
	}

	// Save raw reader to use it on discarding frame without ciphering and
	// other streaming checks.
	r.raw = io.LimitedReader{
		R: r.Source,
		N: hdr.Length,
	}

	frame := io.Reader(&r.raw)
	if hdr.Masked {
		// Reuse a single cipher reader across frames.
		if r.cr == nil {
			r.cr = NewCipherReader(frame, hdr.Mask)
		} else {
			r.cr.Reset(frame, hdr.Mask)
		}
		frame = r.cr
	}

	for _, x := range r.Extensions {
		hdr, err = x.UnsetBits(hdr)
		if err != nil {
			return hdr, err
		}
	}

	if r.fragmented() {
		if hdr.OpCode.IsControl() {
			// Control frame interleaved with a fragmented message: hand it
			// to the callback and leave the message state untouched.
			if cb := r.OnIntermediate; cb != nil {
				err = cb(hdr, frame)
			}
			if err == nil {
				// Ensure that src is empty.
				_, err = io.Copy(ioutil.Discard, &r.raw)
			}
			return hdr, err
		}
	} else {
		// First frame of a new message: remember its op code so that text
		// continuation frames can be recognized later.
		r.opCode = hdr.OpCode
	}
	if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
		r.utf8.Source = frame
		frame = &r.utf8
	}

	// Save reader with ciphering and other streaming checks.
	r.frame = frame

	if hdr.OpCode == ws.OpContinuation {
		if cb := r.OnContinuation; cb != nil {
			err = cb(hdr, frame)
		}
	}

	// Track whether more frames of this message are expected.
	if hdr.Fin {
		r.State = r.State.Clear(ws.StateFragmented)
	} else {
		r.State = r.State.Set(ws.StateFragmented)
	}

	return hdr, err
}
// fragmented reports whether ws.StateFragmented is set in r.State, i.e.
// whether NextFrame last accepted a non-final data frame of a message.
func (r *Reader) fragmented() bool {
	state := r.State
	return state.Fragmented()
}
// resetFragment detaches r from a frame that has been fully read while the
// message it belongs to is still in progress. The UTF-8 validator keeps its
// state so validation continues across frame boundaries; only its source is
// dropped.
func (r *Reader) resetFragment() {
	r.utf8.Source = nil
	r.frame = nil
	r.raw = io.LimitedReader{}
}
// reset clears all per-message state of r: the remembered op code, the
// UTF-8 validator, and the current frame readers.
func (r *Reader) reset() {
	r.opCode = 0
	r.utf8 = UTF8Reader{}
	r.frame = nil
	r.raw = io.LimitedReader{}
}
// readHeader reads and decodes a single WebSocket frame header from in.
//
// The header is read in two steps. The first step reads the 2 fixed bytes:
// FIN, RSV, opcode, the MASK bit and the 7-bit length. The second step, when
// needed, reads the "extra" bytes: the extended payload length (2 or 8
// bytes) and the masking key (4 bytes). r.tmp is reused as scratch space for
// both steps, so no allocation is made per frame.
//
// io.EOF is returned only when in ends cleanly before the first header
// byte. If in ends after any part of the header has been consumed,
// io.ErrUnexpectedEOF is returned, so a truncated header is never mistaken
// for a clean end of stream. ws.ErrHeaderLengthMSB is returned when a 64-bit
// payload length has its most significant bit set.
func (r *Reader) readHeader(in io.Reader) (h ws.Header, err error) {
	// The maximum header size is 14 bytes, but the first read consumes the
	// 2 fixed bytes, so the second read never needs more than 14 - 2 = 12
	// bytes, which is the size of r.tmp.
	bts := r.tmp[:2]

	// Read the first 2 bytes to learn how many more bytes follow.
	if _, err = io.ReadFull(in, bts); err != nil {
		return h, err
	}

	const bit0 = 0x80

	h.Fin = bts[0]&bit0 != 0
	h.Rsv = (bts[0] & 0x70) >> 4
	h.OpCode = ws.OpCode(bts[0] & 0x0f)

	var extra int

	if bts[1]&bit0 != 0 {
		h.Masked = true
		extra += 4
	}

	// The 7-bit length is always in [0, 127]. The values 126 and 127 mean
	// that a 16-bit or 64-bit extended payload length follows.
	length := bts[1] & 0x7f
	switch length {
	case 126:
		extra += 2
	case 127:
		extra += 8
	default:
		h.Length = int64(length)
	}

	if extra == 0 {
		return h, nil
	}

	// Grow bts to the number of extra bytes needed, overwriting the 2 bytes
	// decoded above (the capacity of r.tmp allows it).
	bts = bts[:extra]
	if _, err = io.ReadFull(in, bts); err != nil {
		if err == io.EOF {
			// Part of the header has already been consumed, so running out
			// of input here means a truncated frame, not a clean close.
			err = io.ErrUnexpectedEOF
		}
		return h, err
	}

	switch length {
	case 126:
		h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
		bts = bts[2:]
	case 127:
		// RFC 6455 section 5.2: the most significant bit of a 64-bit
		// payload length MUST be 0.
		if bts[0]&0x80 != 0 {
			return h, ws.ErrHeaderLengthMSB
		}
		h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
		bts = bts[8:]
	}

	// Whatever remains in bts is the 4-byte masking key.
	if h.Masked {
		copy(h.Mask[:], bts)
	}

	return h, nil
}
// NextReader prepares reading of the next message from r, using state s. It
// returns the header of the message's initial frame and an io.Reader over
// the message payload. The error is non-nil when the initial frame could not
// be read; in that case the returned reader is nil.
//
// A subsequent NextReader() call on the same r should only be made after
// all bytes of the previously returned io.Reader have been read. For a more
// efficient way to skip a message, use Reader and its Discard() method.
//
// Note that the returned reader does not handle "intermediate" control
// frames, which may legally arrive between the fragments of a text or binary
// message. For example, a peer may send a text frame with FIN unset, then a
// ping frame, and then the final continuation frame. With NextReader() that
// ping frame is dropped without notice. To handle this rare but possible
// case (or when you do not know exactly which frames the peer may send), use
// Reader with its OnIntermediate field set.
func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
	rd := NewReader(r, s)
	hdr, err := rd.NextFrame()
	if err != nil {
		return hdr, nil, err
	}
	return hdr, rd, nil
}