package wsflate

import (
	"bytes"
	"compress/flate"
	"fmt"
	"io"

	"github.com/gobwas/ws"
)

// DefaultHelper is the package-level Helper whose compressor and decompressor
// are both provided by the standard library's `compress/flate` package.
//
// DefaultHelper methods assume that DefaultParameters were used for extension
// negotiation during the WebSocket handshake.
var DefaultHelper = Helper{
	Compressor: func(dst io.Writer) Compressor {
		// flate.NewWriter documents a nil error for any level in [-2, 9];
		// flate.BestCompression is 9, so the error is safely discarded.
		fw, _ := flate.NewWriter(dst, flate.BestCompression)
		return fw
	},
	Decompressor: func(src io.Reader) Decompressor {
		return flate.NewReader(src)
	},
}

// DefaultParameters are the deflate extension parameters that DefaultHelper
// expects to have been negotiated during the WebSocket handshake.
//
// Both no-context-takeover flags are enabled.
// NOTE(review): presumably this is because the helper builds a fresh
// compressor/decompressor for every call and so keeps no state between
// messages — confirm against the Parameters documentation.
var DefaultParameters = Parameters{
	ClientNoContextTakeover: true,
	ServerNoContextTakeover: true,
}

// CompressFrame compresses f using DefaultHelper; it is equivalent to calling
// DefaultHelper.CompressFrame(f).
//
// DefaultHelper methods assume that DefaultParameters were used for extension
// negotiation during the WebSocket handshake.
func CompressFrame(f ws.Frame) (ws.Frame, error) {
	compressed, err := DefaultHelper.CompressFrame(f)
	return compressed, err
}

// CompressFrameBuffer compresses f into buf using DefaultHelper; it is
// equivalent to calling DefaultHelper.CompressFrameBuffer(buf, f).
//
// DefaultHelper methods assume that DefaultParameters were used for extension
// negotiation during the WebSocket handshake.
func CompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
	helper := &DefaultHelper
	return helper.CompressFrameBuffer(buf, f)
}

// DecompressFrame decompresses f using DefaultHelper; it is equivalent to
// calling DefaultHelper.DecompressFrame(f).
//
// DefaultHelper methods assume that DefaultParameters were used for extension
// negotiation during the WebSocket handshake.
func DecompressFrame(f ws.Frame) (ws.Frame, error) {
	decompressed, err := DefaultHelper.DecompressFrame(f)
	return decompressed, err
}

// DecompressFrameBuffer decompresses f into buf using DefaultHelper; it is
// equivalent to calling DefaultHelper.DecompressFrameBuffer(buf, f).
//
// DefaultHelper methods assume that DefaultParameters were used for extension
// negotiation during the WebSocket handshake.
func DecompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
	helper := &DefaultHelper
	return helper.DecompressFrameBuffer(buf, f)
}

// Helper bundles the common code needed to compress and decompress raw bytes
// or whole WebSocket frames.
//
// It exists to cut down on boilerplate in WebSocket applications.
type Helper struct {
	// Compressor is the factory that CompressTo hands to NewWriter together
	// with the destination writer.
	Compressor func(w io.Writer) Compressor

	// Decompressor is the factory that DecompressTo hands to NewReader
	// together with the source reader.
	Decompressor func(r io.Reader) Decompressor
}

// Buffer describes a byte-buffering object that can be written to and then
// asked for its contents.
type Buffer interface {
	// Writer receives the output produced by the helper's
	// CompressFrameBuffer and DecompressFrameBuffer methods.
	io.Writer

	// Bytes returns the buffer's contents; the helper's *FrameBuffer methods
	// assign this slice directly to the resulting frame's payload.
	Bytes() []byte
}

// CompressFrame returns a compressed copy of the given frame.
//
// A new bytes.Buffer is allocated on every call to hold the compressed
// payload. Use CompressFrameBuffer() to supply the buffer yourself and keep
// those allocations under control.
func (h *Helper) CompressFrame(in ws.Frame) (f ws.Frame, err error) {
	scratch := new(bytes.Buffer)
	f, err = h.CompressFrameBuffer(scratch, in)
	return f, err
}

// DecompressFrame returns a decompressed copy of the given frame.
//
// A new bytes.Buffer is allocated on every call to hold the decompressed
// payload. Use DecompressFrameBuffer() to supply the buffer yourself and keep
// those allocations under control.
func (h *Helper) DecompressFrame(in ws.Frame) (f ws.Frame, err error) {
	scratch := new(bytes.Buffer)
	f, err = h.DecompressFrameBuffer(scratch, in)
	return f, err
}

// CompressFrameBuffer compresses a frame, using buf as the destination for
// the compressed payload.
//
// Only final frames (Header.Fin set) are accepted; anything else is rejected
// with an error. On success the returned frame's payload is the slice
// reported by buf.Bytes(), its header length matches that payload, and its
// header has been passed through SetBit.
//
// On failure the frame is returned together with the error. If the failure
// comes from SetBit, the frame already carries the compressed payload and
// updated length.
func (h *Helper) CompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
	if !f.Header.Fin {
		return f, fmt.Errorf("wsflate: fragmented messages are not allowed")
	}

	err := h.CompressTo(buf, f.Payload)
	if err != nil {
		return f, err
	}

	payload := buf.Bytes()
	f.Payload = payload
	f.Header.Length = int64(len(payload))

	// The header is overwritten with whatever SetBit returns, error or not.
	f.Header, err = SetBit(f.Header)
	return f, err
}

// DecompressFrameBuffer decompresses a frame, using buf as the destination
// for the decompressed payload.
//
// Only final frames (Header.Fin set) are accepted; anything else is rejected
// with an error. The header is first passed through UnsetBit. If that reports
// the frame as not compressed, the frame is returned with the resulting
// header and buf is left untouched. Otherwise the returned frame's payload is
// the slice reported by buf.Bytes() and its header length matches that
// payload.
func (h *Helper) DecompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
	if !f.Header.Fin {
		return f, fmt.Errorf(
			"wsflate: fragmented messages are not supported by helper",
		)
	}

	var (
		wasCompressed bool
		err           error
	)
	f.Header, wasCompressed, err = UnsetBit(f.Header)
	switch {
	case err != nil:
		return f, err
	case !wasCompressed:
		// Nothing to inflate: hand the frame back as is.
		return f, nil
	}

	if err = h.DecompressTo(buf, f.Payload); err != nil {
		return f, err
	}

	inflated := buf.Bytes()
	f.Payload = inflated
	f.Header.Length = int64(len(inflated))
	return f, nil
}

// Compress returns the compressed form of p.
//
// A new bytes.Buffer is allocated on every call to collect the output. Use
// CompressTo() to write into a destination of your own and keep those
// allocations under control. On error the returned slice is nil.
func (h *Helper) Compress(p []byte) ([]byte, error) {
	out := new(bytes.Buffer)
	err := h.CompressTo(out, p)
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// Decompress returns the decompressed form of p.
//
// A new bytes.Buffer is allocated on every call to collect the output. Use
// DecompressTo() to write into a destination of your own and keep those
// allocations under control. On error the returned slice is nil.
func (h *Helper) Decompress(p []byte) ([]byte, error) {
	out := new(bytes.Buffer)
	err := h.DecompressTo(out, p)
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// CompressTo compresses p and writes the result to w.
//
// Every call builds a new writer from h.Compressor via NewWriter, writes p
// to it, flushes it and closes it. The writer is closed on every path,
// including after a failed Write or Flush, so that whatever Close releases is
// not leaked when compression fails part-way.
//
// The first error encountered is the one returned; an error from Close is
// reported only if nothing failed before it.
func (h *Helper) CompressTo(w io.Writer, p []byte) (err error) {
	c := NewWriter(w, h.Compressor)
	defer func() {
		// Keep an earlier Write/Flush error if there is one: it is the root
		// cause, while a Close failure after it is only a consequence.
		if cerr := c.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if _, err = c.Write(p); err != nil {
		return err
	}
	if ferr := c.Flush(); ferr != nil {
		return ferr
	}
	return nil
}

// DecompressTo decompresses p and writes the result to w.
//
// Every call builds a new reader over p from h.Decompressor via NewReader,
// copies everything it yields into w and closes it. The reader is closed on
// every path, including after a failed copy, so that whatever Close releases
// is not leaked when decompression fails part-way.
//
// The first error encountered is the one returned; an error from Close is
// reported only if the copy itself succeeded.
func (h *Helper) DecompressTo(w io.Writer, p []byte) (err error) {
	fr := NewReader(bytes.NewReader(p), h.Decompressor)
	defer func() {
		// Keep the copy error if there is one: it is the root cause, while a
		// Close failure after it is only a consequence.
		if cerr := fr.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if _, err = io.Copy(w, fr); err != nil {
		return err
	}
	return nil
}