well-goknown/vendor/github.com/valyala/fasthttp/zstd.go
2024-09-25 21:59:44 -04:00

187 lines
4.3 KiB
Go

package fasthttp
import (
"bytes"
"fmt"
"io"
"sync"
"github.com/klauspost/compress/zstd"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp/stackless"
)
// Zstd compression levels accepted by the *Zstd*Level functions in this file.
// When an encoder is created the level is converted unchanged to
// zstd.EncoderLevel, so these values must keep the 0..4 numbering below.
const (
	// CompressZstdSpeedNotSet is level 0, the "not set" value.
	CompressZstdSpeedNotSet = 0
	// CompressZstdBestSpeed is level 1.
	CompressZstdBestSpeed = 1
	// CompressZstdDefault is level 2. AppendZstdBytes uses it, and
	// normalizeZstdCompressLevel falls back to it for out-of-range input.
	CompressZstdDefault = 2
	// CompressZstdSpeedBetter is level 3.
	CompressZstdSpeedBetter = 3
	// CompressZstdBestCompression is level 4, the highest accepted level.
	CompressZstdBestCompression = 4
)
// zstdDecoderPool caches *zstd.Decoder values between acquireZstdReader and
// releaseZstdReader calls.
var zstdDecoderPool sync.Pool

// zstdEncoderPool caches *zstd.Encoder values for acquireZstdWriter. It is a
// single pool shared by every compression level.
var zstdEncoderPool sync.Pool

// realZstdWriterPoolMap holds pooled *zstd.Encoder values, indexed by the
// result of normalizeZstdCompressLevel.
var realZstdWriterPoolMap = newCompressWriterPoolMap()

// stacklessZstdWriterPoolMap holds pooled stackless.Writer wrappers, indexed
// by the result of normalizeZstdCompressLevel.
var stacklessZstdWriterPoolMap = newCompressWriterPoolMap()
// acquireZstdReader returns a zstd decoder that reads compressed data from r.
//
// If zstdDecoderPool has a cached decoder, it is Reset onto r and returned.
// If that Reset fails, the error is returned and the decoder is dropped
// instead of being put back. If the pool is empty, a new decoder is created
// with zstd.NewReader.
func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) {
	if cached := zstdDecoderPool.Get(); cached != nil {
		dec := cached.(*zstd.Decoder)
		if err := dec.Reset(r); err != nil {
			return nil, err
		}
		return dec, nil
	}
	return zstd.NewReader(r)
}
// releaseZstdReader puts zr back into zstdDecoderPool. A later
// acquireZstdReader call may hand it to another user after a Reset, so the
// caller should not touch zr after this call.
func releaseZstdReader(zr *zstd.Decoder) {
	pool := &zstdDecoderPool
	pool.Put(zr)
}
// acquireZstdWriter returns a zstd encoder that writes compressed output to w.
//
// A cached encoder from zstdEncoderPool is Reset onto w and reused. Otherwise
// a new encoder is created with level converted directly to zstd.EncoderLevel,
// and any construction error is returned.
//
// NOTE(review): zstdEncoderPool is not keyed by level. An encoder reused from
// it keeps whatever level it was created with, and `level` is ignored on that
// path.
func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) {
	if cached := zstdEncoderPool.Get(); cached != nil {
		enc := cached.(*zstd.Encoder)
		enc.Reset(w)
		return enc, nil
	}
	levelOpt := zstd.WithEncoderLevel(zstd.EncoderLevel(level))
	return zstd.NewWriter(w, levelOpt)
}
// releaseZstdWriter closes zw, discarding the Close error, and then caches it
// in zstdEncoderPool for reuse by acquireZstdWriter.
func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused
	_ = zw.Close()
	zstdEncoderPool.Put(zw)
}
// acquireStacklessZstdWriter returns a stackless.Writer that zstd-compresses
// everything written to it into w.
//
// compressLevel is normalized with normalizeZstdCompressLevel. The normalized
// value selects the pool in stacklessZstdWriterPoolMap, and it is also the
// level handed to the underlying encoder factory. That keeps the pool key and
// the encoder's level in agreement. Passing the raw value instead would let an
// out-of-range level (such as -1 or 10) reach zstd.EncoderLevel unchanged.
// Encoder construction would then fail, and acquireRealZstdWriter panics on
// that failure.
//
// A pooled writer, if available, is Reset onto w instead of building a new one.
func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer {
	nLevel := normalizeZstdCompressLevel(compressLevel)
	p := stacklessZstdWriterPoolMap[nLevel]
	v := p.Get()
	if v == nil {
		return stackless.NewWriter(w, func(dst io.Writer) stackless.Writer {
			return acquireRealZstdWriter(dst, nLevel)
		})
	}
	sw := v.(stackless.Writer)
	sw.Reset(w)
	return sw
}
// releaseStacklessZstdWriter closes zf, discarding the Close error. It then
// puts zf into the stacklessZstdWriterPoolMap pool for the normalized level.
// level should be the same value that was passed to acquireStacklessZstdWriter,
// so the writer returns to the pool it came from.
func releaseStacklessZstdWriter(zf stackless.Writer, level int) {
	_ = zf.Close()
	stacklessZstdWriterPoolMap[normalizeZstdCompressLevel(level)].Put(zf)
}
// acquireRealZstdWriter returns a *zstd.Encoder writing to w. It reuses an
// encoder from the realZstdWriterPoolMap pool for the normalized level when
// one is available.
//
// level is normalized with normalizeZstdCompressLevel. The normalized value is
// used both as the pool index and as the level of any newly created encoder.
// Building with the raw level would let out-of-range input pick the default
// pool while asking zstd for a level it rejects. Encoder creation would fail
// and this function would panic.
//
// It panics if an encoder cannot be created, because its callers (such as the
// stackless writer factory) have no way to return an error.
func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder {
	nLevel := normalizeZstdCompressLevel(level)
	p := realZstdWriterPoolMap[nLevel]
	if v := p.Get(); v != nil {
		zw := v.(*zstd.Encoder)
		zw.Reset(w)
		return zw
	}
	zw, err := acquireZstdWriter(w, nLevel)
	if err != nil {
		panic(err)
	}
	return zw
}
// releaseRealZstdWrter (sic) closes zw, discarding the Close error. It then
// returns zw to the realZstdWriterPoolMap pool for the normalized level, so a
// later acquireRealZstdWriter call with the same level can reuse it.
func releaseRealZstdWrter(zw *zstd.Encoder, level int) {
	_ = zw.Close()
	pool := realZstdWriterPoolMap[normalizeZstdCompressLevel(level)]
	pool.Put(zw)
}
// AppendZstdBytesLevel appends the zstd-compressed form of src to dst at the
// given compression level and returns the extended slice.
//
// level is normalized inside WriteZstdLevel. The write error is ignored: for
// the in-memory *byteSliceWriter used here, WriteZstdLevel always reports a
// nil error.
func AppendZstdBytesLevel(dst, src []byte, level int) []byte {
	out := &byteSliceWriter{b: dst}
	_, _ = WriteZstdLevel(out, src, level) //nolint:errcheck
	return out.b
}
// WriteZstdLevel zstd-compresses p and writes the result to w at the given
// compression level. It returns the number of bytes of p consumed.
//
// level is first normalized with normalizeZstdCompressLevel.
//
// For the in-memory writers (*byteSliceWriter, *bytes.Buffer,
// *bytebufferpool.ByteBuffer), the compression runs through a stackless
// function, and len(p), nil is always reported.
//
// Any other writer is wrapped in a pooled stackless zstd writer. The returned
// count and error come from its Write call; an error from closing it is not
// reported.
func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) {
	level = normalizeZstdCompressLevel(level)

	inMemory := false
	switch w.(type) {
	case *byteSliceWriter, *bytes.Buffer, *bytebufferpool.ByteBuffer:
		inMemory = true
	}

	if !inMemory {
		sw := acquireStacklessZstdWriter(w, level)
		written, err := sw.Write(p)
		releaseStacklessZstdWriter(sw, level)
		return written, err
	}

	stacklessWriteZstd(&compressCtx{
		w:     w,
		p:     p,
		level: level,
	})
	return len(p), nil
}
// stacklessWriteZstdOnce guards the one-time construction of
// stacklessWriteZstdFunc.
var stacklessWriteZstdOnce sync.Once

// stacklessWriteZstdFunc is nonblockingWriteZstd wrapped by stackless.NewFunc.
// It stays nil until stacklessWriteZstd first runs.
var stacklessWriteZstdFunc func(ctx any) bool

// stacklessWriteZstd runs nonblockingWriteZstd(ctx) through the stackless
// wrapper, building that wrapper on first use. The wrapper's bool result is
// ignored.
func stacklessWriteZstd(ctx any) {
	initWrapper := func() {
		stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd)
	}
	stacklessWriteZstdOnce.Do(initWrapper)
	_ = stacklessWriteZstdFunc(ctx)
}
// nonblockingWriteZstd is the function body that stacklessWriteZstd runs. It
// compresses c.p into c.w at c.level using a pooled encoder and then returns
// the encoder to its pool.
//
// ctxv must be a *compressCtx; the type assertion panics otherwise. Write
// errors are ignored.
func nonblockingWriteZstd(ctxv any) {
	c := ctxv.(*compressCtx)
	enc := acquireRealZstdWriter(c.w, c.level)
	_, _ = enc.Write(c.p) //nolint:errcheck
	releaseRealZstdWrter(enc, c.level)
}
// AppendZstdBytes appends zstd src to dst and returns the resulting dst.
func AppendZstdBytes(dst, src []byte) []byte {
return AppendZstdBytesLevel(dst, src, CompressZstdDefault)
}
// WriteUnzstd writes unzstd p to w and returns the number of uncompressed
// bytes written to w.
func WriteUnzstd(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{b: p}
zr, err := acquireZstdReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseZstdReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data unzstd: %d", n)
}
return nn, err
}
// AppendUnzstdBytes decompresses the zstd data in src, appends the output to
// dst, and returns the extended slice. The slice is returned even when
// decompression fails, alongside the error.
func AppendUnzstdBytes(dst, src []byte) ([]byte, error) {
	sink := &byteSliceWriter{b: dst}
	if _, err := WriteUnzstd(sink, src); err != nil {
		return sink.b, err
	}
	return sink.b, nil
}
// normalizeZstdCompressLevel maps level onto a level the zstd encoder accepts.
// The result is used both as the encoder level and as the index into the
// *ZstdWriterPoolMap pools.
//
// Levels CompressZstdBestSpeed..CompressZstdBestCompression (1..4) are
// returned unchanged. Anything else becomes CompressZstdDefault, including
// CompressZstdSpeedNotSet (0). Level 0 must not pass through: encoders are
// built with zstd.EncoderLevel(level), and the zstd package rejects that
// "not set" level, so encoder construction would fail and the caller would
// panic.
func normalizeZstdCompressLevel(level int) int {
	if level <= CompressZstdSpeedNotSet || level > CompressZstdBestCompression {
		return CompressZstdDefault
	}
	return level
}