well-goknown/vendor/github.com/gobwas/ws/wsutil/reader.go

374 lines
9.5 KiB
Go
Raw Normal View History

package wsutil
import (
"encoding/binary"
"errors"
"io"
"io/ioutil"
"github.com/gobwas/ws"
)
// Sentinel errors reported by Reader.
var (
	// ErrNoFrameAdvance is returned by Reader.Read when it is called without
	// a preceding NextFrame() call.
	ErrNoFrameAdvance = errors.New("no frame advance")

	// ErrFrameTooLarge is returned by Reader.NextFrame when the payload
	// length of an incoming frame exceeds Reader.MaxFrameSize.
	ErrFrameTooLarge = errors.New("frame too large")
)
// FrameHandlerFunc handles a parsed frame header together with the frame
// body, which is exposed as an io.Reader.
//
// Reader calls such handlers through its OnContinuation and OnIntermediate
// fields. The body reader passed in is already unmasked: for masked frames
// NextFrame wraps the raw payload in a CipherReader before invoking the
// handler.
type FrameHandlerFunc func(ws.Header, io.Reader) error
// Reader is a wrapper around a source io.Reader that represents a WebSocket
// connection. It holds the options used when reading messages from that
// source.
//
// Reader implements io.Reader: its Read() method reads the payload of
// incoming WebSocket frames. It also takes care of fragmented messages and
// of control frames that may arrive between their fragments.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
	// Source is the underlying connection that frames are read from.
	Source io.Reader
	// State is the endpoint state. NextFrame passes it to ws.CheckHeader
	// and maintains its StateFragmented bit.
	State ws.State

	// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
	SkipHeaderCheck bool

	// CheckUTF8 enables UTF-8 checks for the payload of text messages. If
	// the complete message is not a valid UTF-8 sequence, Read returns
	// ErrInvalidUTF8.
	CheckUTF8 bool

	// Extensions is a list of negotiated extensions for reader Source.
	// It is used to meet the specs and clear appropriate bits in the frame
	// header RSV segment.
	Extensions []RecvExtension

	// MaxFrameSize controls the maximum frame payload size in bytes that
	// can be read. NextFrame returns ErrFrameTooLarge for a frame whose
	// payload length exceeds that size.
	//
	// A zero (or negative) value means there is no limit.
	MaxFrameSize int64

	// OnContinuation, if non-nil, is called by NextFrame for every
	// continuation frame with the frame header and its unmasked payload
	// reader. Its error is returned from NextFrame.
	OnContinuation FrameHandlerFunc
	// OnIntermediate, if non-nil, is called by NextFrame for every control
	// frame that arrives while a fragmented message is in progress. If it
	// returns nil, any payload it left unread is discarded afterwards.
	OnIntermediate FrameHandlerFunc

	opCode ws.OpCode             // Op code of the message's first frame, kept across fragments.
	frame  io.Reader             // Reader of the current frame payload (unmasked, possibly UTF-8 checked).
	raw    io.LimitedReader      // Raw payload reader; used to discard frames without unmasking.
	utf8   UTF8Reader            // Used to check UTF-8 sequences if CheckUTF8 is true.
	tmp    [ws.MaxHeaderSize - 2]byte // Scratch buffer for reading headers.
	cr     *CipherReader         // Reused by NextFrame() to unmask frame payload.
}
// NewReader returns a frame reader that reads from r. The given state s is
// stored in the reader and used for protocol validity checks.
func NewReader(r io.Reader, s ws.State) *Reader {
	rd := new(Reader)
	rd.Source = r
	rd.State = s
	return rd
}
// NewClientSideReader returns a Reader for the client end of a connection.
// It is equivalent to NewReader(r, ws.StateClientSide).
func NewClientSideReader(r io.Reader) *Reader {
	return &Reader{Source: r, State: ws.StateClientSide}
}
// NewServerSideReader returns a Reader for the server end of a connection.
// It is equivalent to NewReader(r, ws.StateServerSide).
func NewServerSideReader(r io.Reader) *Reader {
	return &Reader{Source: r, State: ws.StateServerSide}
}
// Read implements io.Reader. It reads the next message payload into p.
// It takes care of fragmented messages. When a non-final frame has been
// exhausted, the following call fetches the next frame via NextFrame().
//
// The error is io.EOF only if all of the message bytes were read.
// If an io.EOF happens while reading some but not all of a frame's bytes,
// Read() returns io.ErrUnexpectedEOF.
//
// The error is ErrNoFrameAdvance if no NextFrame() call was made before
// reading next message bytes.
//
// If CheckUTF8 is set and the complete message is not valid UTF-8, the error
// is ErrInvalidUTF8 and n is set to r.utf8.Accepted().
//
// Read may return 0, nil when the frame it fetched was an intermediate
// control frame that NextFrame() consumed entirely.
func (r *Reader) Read(p []byte) (n int, err error) {
	if r.frame == nil {
		if !r.fragmented() {
			// Every new Read() must be preceded by NextFrame() call.
			return 0, ErrNoFrameAdvance
		}
		// Read next continuation or intermediate control frame.
		_, err := r.NextFrame()
		if err != nil {
			return 0, err
		}
		if r.frame == nil {
			// We handled intermediate control and now got nothing to read.
			return 0, nil
		}
	}
	n, err = r.frame.Read(p)
	if err != nil && err != io.EOF {
		return n, err
	}
	if err == nil && r.raw.N != 0 {
		// More payload remains in the current frame.
		return n, nil
	}
	// EOF condition (either err is io.EOF or r.raw.N is zero).
	switch {
	case r.raw.N != 0:
		// The source ended before the frame's declared length was read.
		err = io.ErrUnexpectedEOF
	case r.fragmented():
		// End of a non-final frame: return what was read without an error
		// and let the next Read() fetch the following fragment.
		err = nil
		r.resetFragment()
	case r.CheckUTF8 && !r.utf8.Valid():
		// NOTE: check utf8 only when full message received, since partial
		// reads may be invalid.
		n = r.utf8.Accepted()
		err = ErrInvalidUTF8
	default:
		// Whole message consumed; r is ready for the next NextFrame().
		r.reset()
		err = io.EOF
	}
	return n, err
}
// Discard discards the unread bytes of the current message, including all
// remaining frames of a fragmented message, and then resets r so that the
// next NextFrame() call starts a new message.
//
// It returns io.ErrUnexpectedEOF if the source ends before a frame's
// declared payload length has been consumed. Otherwise it returns any error
// reported by the source or by NextFrame().
func (r *Reader) Discard() (err error) {
	for {
		if _, err = io.Copy(ioutil.Discard, &r.raw); err != nil {
			break
		}
		// io.Copy reports the source's io.EOF as success, so a frame cut
		// short would otherwise go unnoticed. r.raw.N counts the frame's
		// declared bytes that were never received.
		if r.raw.N != 0 {
			err = io.ErrUnexpectedEOF
			break
		}
		if !r.fragmented() {
			break
		}
		// Move on to the next fragment (or intermediate control frame).
		if _, err = r.NextFrame(); err != nil {
			break
		}
	}
	r.reset()
	return err
}
// NextFrame reads the next frame header from r.Source and prepares r to read
// that frame's payload. It returns the received frame header and a non-nil
// error on failure.
//
// Unless SkipHeaderCheck is set, the header is validated with
// ws.CheckHeader against r.State. A frame whose payload length exceeds a
// positive MaxFrameSize is rejected with ErrFrameTooLarge. If the source
// ends before a header while a fragmented message is in progress, the error
// is io.ErrUnexpectedEOF instead of io.EOF.
//
// A control frame received while a fragmented message is in progress is
// passed to OnIntermediate (if set). Its remaining payload is then drained,
// and NextFrame returns without touching the fragmentation state. If the
// source ends inside such a control frame, the error is
// io.ErrUnexpectedEOF. Continuation frames are passed to OnContinuation
// (if set).
//
// Note that next NextFrame() call must be done after receiving or discarding
// all current message bytes.
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
	hdr, err = r.readHeader(r.Source)
	if err == io.EOF && r.fragmented() {
		// If we are in fragmented state EOF means that it was totally
		// unexpected.
		//
		// NOTE: This is necessary to prevent callers such as
		// ioutil.ReadAll from receiving some amount of bytes without an
		// error. ReadAll() ignores an io.EOF error, thus the caller may
		// think that the whole message was fetched, but actually only part
		// of it was.
		err = io.ErrUnexpectedEOF
	}
	if err == nil && !r.SkipHeaderCheck {
		err = ws.CheckHeader(hdr, r.State)
	}
	if err != nil {
		return hdr, err
	}

	if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
		return hdr, ErrFrameTooLarge
	}

	// Save raw reader to use it on discarding frame without ciphering and
	// other streaming checks.
	r.raw = io.LimitedReader{
		R: r.Source,
		N: hdr.Length,
	}

	// Unmask the payload when needed, reusing r.cr once it exists.
	frame := io.Reader(&r.raw)
	if hdr.Masked {
		if r.cr == nil {
			r.cr = NewCipherReader(frame, hdr.Mask)
		} else {
			r.cr.Reset(frame, hdr.Mask)
		}
		frame = r.cr
	}

	// Let each negotiated extension clear the RSV bits it accounts for.
	for _, x := range r.Extensions {
		hdr, err = x.UnsetBits(hdr)
		if err != nil {
			return hdr, err
		}
	}

	if r.fragmented() {
		if hdr.OpCode.IsControl() {
			// Control frame between fragments: hand it to the user
			// callback, then drain whatever payload it left unread.
			if cb := r.OnIntermediate; cb != nil {
				err = cb(hdr, frame)
			}
			if err == nil {
				// Ensure that src is empty.
				_, err = io.Copy(ioutil.Discard, &r.raw)
			}
			if err == nil && r.raw.N != 0 {
				// io.Copy reports io.EOF as success. Leftover bytes mean
				// the source ended in the middle of the control frame.
				err = io.ErrUnexpectedEOF
			}
			return hdr, err
		}
	} else {
		// First frame of a message: remember its op code for the
		// continuation frames that may follow.
		r.opCode = hdr.OpCode
	}
	if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
		r.utf8.Source = frame
		frame = &r.utf8
	}

	// Save reader with ciphering and other streaming checks.
	r.frame = frame

	if hdr.OpCode == ws.OpContinuation {
		if cb := r.OnContinuation; cb != nil {
			err = cb(hdr, frame)
		}
	}

	if hdr.Fin {
		r.State = r.State.Clear(ws.StateFragmented)
	} else {
		r.State = r.State.Set(ws.StateFragmented)
	}

	return hdr, err
}
// fragmented reports whether the StateFragmented bit is set in r.State.
// NextFrame sets that bit after a data frame with FIN clear and clears it
// after a final frame.
func (r *Reader) fragmented() bool {
	return r.State.Is(ws.StateFragmented)
}
// resetFragment forgets the current, fully read, non-final frame so that the
// next Read() fetches the following frame. Only the source of the UTF-8
// validator is detached; its accumulated state is kept, because validity is
// checked only once the whole message has been received.
func (r *Reader) resetFragment() {
	r.raw, r.frame, r.utf8.Source = io.LimitedReader{}, nil, nil
}
// reset returns r to the "between messages" state. It drops the current
// frame readers, the UTF-8 validator state and the saved message op code.
// It leaves r.cr and r.State untouched.
func (r *Reader) reset() {
	r.raw, r.frame = io.LimitedReader{}, nil
	r.utf8, r.opCode = UTF8Reader{}, 0
}
// readHeader reads a frame header from in, using r.tmp as scratch space.
//
// It returns io.EOF only when in is exhausted before the first header byte.
// A header that is cut short after that yields io.ErrUnexpectedEOF, so a
// truncated header is never mistaken for a clean end of stream. It returns
// ws.ErrHeaderLengthMSB if a 64-bit payload length has its most significant
// bit set. It does not perform any other protocol validation.
func (r *Reader) readHeader(in io.Reader) (h ws.Header, err error) {
	// The maximum header size is 14 bytes, but it is read in two hops. The
	// first hop reads the 2 fixed bytes; the second reads the extended
	// length and the mask key into the same buffer, overwriting those 2
	// bytes. So a 14 - 2 = 12 byte scratch buffer is enough.
	bts := r.tmp[:2]

	// Read the first 2 bytes; they determine the size of the second read.
	_, err = io.ReadFull(in, bts)
	if err != nil {
		return h, err
	}

	const bit0 = 0x80

	h.Fin = bts[0]&bit0 != 0
	h.Rsv = (bts[0] & 0x70) >> 4
	h.OpCode = ws.OpCode(bts[0] & 0x0f)

	var extra int

	if bts[1]&bit0 != 0 {
		h.Masked = true
		extra += 4
	}

	length := bts[1] & 0x7f
	switch {
	case length < 126:
		h.Length = int64(length)
	case length == 126:
		extra += 2
	case length == 127:
		extra += 8
	default:
		// Unreachable while length is masked to 7 bits; kept as a guard.
		err = ws.ErrHeaderLengthUnexpected
		return h, err
	}

	if extra == 0 {
		return h, nil
	}

	// Increase len of bts to the extra bytes that need to be read.
	// This overwrites the first 2 bytes read before.
	bts = bts[:extra]
	_, err = io.ReadFull(in, bts)
	if err == io.EOF {
		// Part of the header was already consumed, so running out of
		// input here is a truncated header, not a clean close.
		err = io.ErrUnexpectedEOF
	}
	if err != nil {
		return h, err
	}

	switch {
	case length == 126:
		h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
		bts = bts[2:]
	case length == 127:
		if bts[0]&0x80 != 0 {
			err = ws.ErrHeaderLengthMSB
			return h, err
		}
		h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
		bts = bts[8:]
	}

	// Whatever remains in bts is the 4-byte masking key.
	if h.Masked {
		copy(h.Mask[:], bts)
	}

	return h, nil
}
// NextReader prepares the next message read from r. It returns the header
// that describes the message and an io.Reader for the message's payload.
// It returns a non-nil error (and a nil io.Reader) when the message's
// initial frame cannot be read.
//
// The returned reader is a plain Reader with only Source and State set: no
// UTF-8 checks, no extensions and no frame size limit.
//
// Note that the next NextReader() call on the same r should be made only
// after all bytes of the previously returned io.Reader have been read. For a
// more performant way to discard a message, use Reader and its Discard()
// method.
//
// Note that it does not handle any "intermediate" frames that may arrive
// between the fragments of a text/binary message. Suppose the peer sends a
// text/binary frame with the fin flag unset, then a ping frame, and finally
// the remaining part of the message with fin set. The ping frame is then
// silently dropped. To handle this rare but possible situation (and if you
// do not know exactly which frames the peer could send), use Reader with the
// OnIntermediate field set.
func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
	rd := NewReader(r, s)
	hdr, err := rd.NextFrame()
	if err != nil {
		return hdr, nil, err
	}
	return hdr, rd, nil
}