Compare commits


No commits in common. "822c5f421adb128375439bbd7fec1d3efa222a3d" and "3bd11e47c0bd36445581a8480d28d21c2fbd99b9" have entirely different histories.

133 changed files with 9647 additions and 5855 deletions


@@ -1,278 +0,0 @@
package alby
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"git.devvul.com/asara/gologger"
"git.devvul.com/asara/well-goknown/config"
"github.com/nbd-wtf/go-nostr"
"github.com/nbd-wtf/go-nostr/nip04"
)
// check if event is valid
func checkEvent(n string) bool {
var zapEvent ZapEvent
l := gologger.Get(config.GetConfig().LogLevel).With().Caller().Logger()
err := json.Unmarshal([]byte(n), &zapEvent)
if err != nil {
l.Debug().Msgf("unable to unmarshal nwc value: %s", err.Error())
return false
}
evt := nostr.Event{
ID: zapEvent.Id,
PubKey: zapEvent.Pubkey,
CreatedAt: zapEvent.CreatedAt,
Kind: zapEvent.Kind,
Tags: zapEvent.Tags,
Content: zapEvent.Content,
Sig: zapEvent.Signature,
}
ok, err := evt.CheckSignature()
if err != nil || !ok {
l.Debug().Msgf("event is invalid: %v", err)
return false
}
return true
}
// background task to return a receipt when the payment is paid
func watchForReceipt(nEvent string, secret NWCSecret, invoice string) {
var zapEvent ZapEvent
l := gologger.Get(config.GetConfig().LogLevel).With().Caller().Logger()
_, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
ok := checkEvent(nEvent)
if !ok {
l.Debug().Msgf("nostr event is not valid")
return
}
err := json.Unmarshal([]byte(nEvent), &zapEvent)
if err != nil {
l.Debug().Msgf("unable to unmarshal nwc value: %s", err.Error())
return
}
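// poll the wallet every 30 seconds until the invoice is settled or the lookup fails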
ticker := time.NewTicker(30 * time.Second)
quit := make(chan struct{})
go func() {
defer ticker.Stop()
for {
select {
case <-quit:
return
case <-ticker.C:
paid, failed, result := checkInvoicePaid(invoice, secret, nEvent)
if failed {
close(quit)
return
}
if paid {
sendReceipt(secret, result, nEvent)
close(quit)
return
}
}
}
}()
}
func checkInvoicePaid(checkInvoice string, secret NWCSecret, nEvent string) (bool, bool, LookupInvoiceResponse) {
l := gologger.Get(config.GetConfig().LogLevel).With().Caller().Logger()
invoiceParams := LookupInvoiceParams{
Invoice: checkInvoice,
}
invoice := LookupInvoice{
Method: "lookup_invoice",
Params: invoiceParams,
}
invoiceJson, err := json.Marshal(invoice)
if err != nil {
l.Debug().Msgf("unable to marshal invoice: %s", err.Error())
return false, true, LookupInvoiceResponse{}
}
// generate nip-04 shared secret
sharedSecret, err := nip04.ComputeSharedSecret(secret.AppPubkey, secret.Secret)
if err != nil {
l.Debug().Msgf("unable to marshal invoice: %s", err.Error())
return false, true, LookupInvoiceResponse{}
}
// create the encrypted content payload
encryptedContent, err := nip04.Encrypt(string(invoiceJson), sharedSecret)
if err != nil {
l.Debug().Msgf("unable to marshal invoice: %s", err.Error())
return false, true, LookupInvoiceResponse{}
}
recipient := nostr.Tag{"p", secret.AppPubkey}
nwcEv := nostr.Event{
PubKey: secret.AppPubkey,
CreatedAt: nostr.Now(),
Kind: nostr.KindNWCWalletRequest,
Tags: nostr.Tags{recipient},
Content: encryptedContent,
}
// sign the message with the app token
nwcEv.Sign(secret.Secret)
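// filter for the wallet's response event that tags our request event id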
var filters nostr.Filters
t := make(map[string][]string)
t["e"] = []string{nwcEv.GetID()}
filters = []nostr.Filter{
{
Kinds: []int{
nostr.KindNWCWalletResponse,
},
Tags: t,
},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
relay, err := nostr.RelayConnect(ctx, secret.Relay)
if err != nil {
l.Debug().Msgf("unable to connect to relay: %s", err.Error())
return false, false, LookupInvoiceResponse{}
}
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer subCancel()
// subscribe to the filter
sub, err := relay.Subscribe(subCtx, filters)
if err != nil {
l.Debug().Msgf("unable to subscribe to relay: %s", err.Error())
return false, false, LookupInvoiceResponse{}
}
var wg sync.WaitGroup
wg.Add(1)
// watch for the invoice
evs := make([]nostr.Event, 0)
go func() {
defer wg.Done()
for {
select {
case ev, ok := <-sub.Events:
if !ok {
l.Debug().Msgf("subscription events channel is closed")
return
}
if ev.Kind != 0 {
evs = append(evs, *ev)
}
if len(evs) > 0 {
return
}
case <-sub.EndOfStoredEvents:
l.Trace().Msgf("end of stored events received")
case <-ctx.Done():
l.Debug().Msgf("subscription context cancelled or done: %v", ctx.Err())
return
}
}
}()
// publish the invoice request
if err := relay.Publish(ctx, nwcEv); err != nil {
l.Debug().Msgf("unable to publish event: %s", err.Error())
return false, false, LookupInvoiceResponse{}
}
// wait for the invoice to get returned
wg.Wait()
// decrypt the invoice
if len(evs) == 0 {
l.Debug().Msgf("no wallet response received before timeout")
return false, false, LookupInvoiceResponse{}
}
response, err := nip04.Decrypt(evs[0].Content, sharedSecret)
if err != nil {
l.Debug().Msgf("unable to decrypt invoice response: %s", err.Error())
return false, true, LookupInvoiceResponse{}
}
resStruct := LookupInvoiceResponse{}
err = json.Unmarshal([]byte(response), &resStruct)
if err != nil {
l.Debug().Msgf("unable to unmarshal invoice response: %s", err.Error())
return false, true, LookupInvoiceResponse{}
}
if settled := resStruct.Result.isSettled(); settled {
return true, false, resStruct
}
if expired := resStruct.Result.isExpired(); expired {
return false, true, LookupInvoiceResponse{}
}
return false, false, LookupInvoiceResponse{}
}
func sendReceipt(secret NWCSecret, result LookupInvoiceResponse, nEvent string) {
l := gologger.Get(config.GetConfig().LogLevel).With().Caller().Logger()
var zapRequestEvent ZapEvent
err := json.Unmarshal([]byte(nEvent), &zapRequestEvent)
if err != nil {
l.Debug().Msgf("unable to unmarshal zap request event: %s", err.Error())
return
}
zapReceipt := nostr.Event{
PubKey: secret.ClientPubkey,
CreatedAt: result.Result.SettledAt,
Kind: nostr.KindNWCWalletResponse,
Tags: zapRequestEvent.Tags,
Content: "",
}
// add context to zapReceipt
sender := nostr.Tag{"P", zapRequestEvent.Pubkey}
bolt11 := nostr.Tag{"bolt11", result.Result.Invoice}
preimage := nostr.Tag{"preimage", result.Result.Preimage}
description := nostr.Tag{"description", nEvent}
zapReceipt.Tags = zapReceipt.Tags.AppendUnique(sender)
zapReceipt.Tags = zapReceipt.Tags.AppendUnique(bolt11)
zapReceipt.Tags = zapReceipt.Tags.AppendUnique(preimage)
zapReceipt.Tags = zapReceipt.Tags.AppendUnique(description)
// remove unneeded values from tags
zapReceipt.Tags = zapReceipt.Tags.FilterOut([]string{"relays"})
zapReceipt.Tags = zapReceipt.Tags.FilterOut([]string{"alt"})
// sign the receipt
zapReceipt.Sign(secret.Secret)
// send it to the listed relays
ctx := context.Background()
relayTag := zapRequestEvent.Tags.GetFirst([]string{"relays"})
var report []string
for idx, url := range *relayTag {
if idx == 0 {
continue
}
relay, err := nostr.RelayConnect(ctx, url)
if err != nil {
report = append(report, fmt.Sprintf("error: unable to connect to relay (%s): %s", url, err.Error()))
continue
}
if err := relay.Publish(ctx, zapReceipt); err != nil {
report = append(report, fmt.Sprintf("error: unable to publish receipt to relay (%s): %s", url, err.Error()))
continue
}
report = append(report, fmt.Sprintf("success: sent receipt to %s", url))
}
l.Debug().Msgf("receipt report: %v", report)
return
}


@@ -2,8 +2,6 @@ package alby
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net"
@@ -27,9 +25,9 @@ var (
)
type AlbyApp struct {
Id int32 `json:"id"`
Name string `json:"name"`
AppPubkey string `json:"appPubkey"`
Id int32 `json:"id"`
Name string `json:"name"`
NostrPubkey string `json:"nostrPubkey"`
}
type AlbyApps []AlbyApp
@@ -51,21 +49,20 @@ type lnurlpError struct {
Reason string `json:"reason"`
}
type ZapEvent struct {
Id string `json:"id"`
Pubkey string `json:"pubkey"`
CreatedAt nostr.Timestamp `json:"created_at"`
Kind int `json:"kind"`
Tags nostr.Tags `json:"tags"`
Content string `json:"content"`
Signature string `json:"sig"`
type NWCReqNostr struct {
Id string `json:"id"`
Pubkey string `json:"pubkey"`
CreatedAt int64 `json:"created_at"`
Kind int32 `json:"kind"`
Tags [][]string `json:"tags"`
Content string `json:"content"`
Signature string `json:"sig"`
}
type NWCReq struct {
Nostr string `json:"nostr"`
Amount string `json:"amount"`
Comment string `json:"comment"`
LNUrl string `json:"lnurl"`
}
type NWCSecret struct {
@@ -91,50 +88,6 @@ func (s *NWCSecret) decodeSecret() {
}
type LookupInvoiceParams struct {
Invoice string `json:"invoice"`
}
type LookupInvoice struct {
Method string `json:"method"`
Params LookupInvoiceParams `json:"params"`
}
type LookupInvoiceResponseResult struct {
Type string `json:"type"`
State string `json:"state"`
Invoice string `json:"invoice"`
Description string `json:"description"`
DescriptionHash string `json:"description_hash"`
Preimage string `json:"preimage"`
PaymentHash string `json:"payment_hash"`
Amount int64 `json:"amount"`
FeesPaid int64 `json:"fees_paid"`
CreatedAt nostr.Timestamp `json:"created_at"`
ExpiresAt nostr.Timestamp `json:"expires_at"`
SettledAt nostr.Timestamp `json:"settled_at"`
Metadata string `json:"metadata"`
}
func (s *LookupInvoiceResponseResult) isExpired() bool {
if time.Now().Unix() > s.ExpiresAt.Time().Unix() {
return true
}
return false
}
func (s *LookupInvoiceResponseResult) isSettled() bool {
if s.SettledAt.Time().Unix() != 0 {
return true
}
return false
}
type LookupInvoiceResponse struct {
Result LookupInvoiceResponseResult `json:"result"`
ResultType string `json:"result_type"`
}
type MakeInvoiceParams struct {
Amount int64 `json:"amount"`
Description string `json:"description"`
@@ -218,12 +171,12 @@ func GetLnurlp(w http.ResponseWriter, r *http.Request) {
var npk string
for _, element := range albyApps {
if element.Name == name {
npk = element.AppPubkey
npk = element.NostrPubkey
}
}
if len(npk) == 0 {
l.Debug().Msgf("user doesn't exist in alby %s@%s", name, domain)
l.Debug().Msgf("user doesn't exist in alby %s@%s: %s", name, domain, err.Error())
lnurlpReturnError := &lnurlpError{Status: "ERROR", Reason: "user does not exist"}
retError, _ := json.Marshal(lnurlpReturnError)
w.WriteHeader(http.StatusNotFound)
@@ -254,7 +207,7 @@ func GetLnurlp(w http.ResponseWriter, r *http.Request) {
MaxSendable: 10000000,
Metadata: fmt.Sprintf("[[\"text/plain\", \"ln address payment to %s on the devvul server\"],[\"text/identifier\", \"%s@%s\"]]", name, name, domain),
AllowsNostr: true,
NostrPubkey: secret.ClientPubkey,
NostrPubkey: secret.AppPubkey,
}
ret, err := json.Marshal(lnurlpReturn)
@@ -276,7 +229,7 @@ func GetLnurlp(w http.ResponseWriter, r *http.Request) {
func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
l := gologger.Get(config.GetConfig().LogLevel).With().Caller().Logger()
var nwc NWCReq
var zapEvent ZapEvent
var nwcNostr NWCReqNostr
// normalize domain
domain, _, err := net.SplitHostPort(r.Host)
@@ -304,26 +257,14 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
}
secret.decodeSecret()
// if there is a nostr payload unmarshal it
if nwc.Nostr != "" {
err = json.Unmarshal([]byte(nwc.Nostr), &zapEvent)
if err != nil {
l.Debug().Msgf("unable to unmarshal nwc value: %s", err.Error())
lnurlpReturnError := &lnurlpError{Status: "ERROR", Reason: "unable to connect to relay"}
retError, _ := json.Marshal(lnurlpReturnError)
w.WriteHeader(http.StatusNotFound)
w.Write(retError)
return
}
ok := checkEvent(nwc.Nostr)
if !ok {
l.Debug().Msgf("nostr event is not valid", err.Error())
lnurlpReturnError := &lnurlpError{Status: "ERROR", Reason: "check your request and try again"}
retError, _ := json.Marshal(lnurlpReturnError)
w.WriteHeader(http.StatusNotFound)
w.Write(retError)
return
}
err = json.Unmarshal([]byte(nwc.Nostr), &nwcNostr)
if err != nil {
l.Debug().Msgf("unable to unmarshal nwc value: %s", err.Error())
lnurlpReturnError := &lnurlpError{Status: "ERROR", Reason: "unable to connect to relay"}
retError, _ := json.Marshal(lnurlpReturnError)
w.WriteHeader(http.StatusNotFound)
w.Write(retError)
return
}
// connect to the relay
@@ -350,20 +291,10 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
return
}
var hash string
if nwc.Nostr != "" {
sha := sha256.Sum256([]byte(nwc.Nostr))
hash = hex.EncodeToString(sha[:])
}
if nwc.Nostr == "" {
hash = ""
}
invoiceParams := MakeInvoiceParams{
Amount: amt,
Description: nwc.Nostr,
DescriptionHash: hash,
Expiry: 300,
Amount: amt,
Description: nwc.Comment,
Expiry: 300,
}
invoice := MakeInvoice{
@@ -371,6 +302,7 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
Params: invoiceParams,
}
// marshal the json
invoiceJson, err := json.Marshal(invoice)
if err != nil {
l.Debug().Msgf("unable to marshal invoice: %s", err.Error())
@@ -405,7 +337,7 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
recipient := nostr.Tag{"p", secret.AppPubkey}
nwcEv := nostr.Event{
PubKey: secret.ClientPubkey,
PubKey: nwcNostr.Pubkey,
CreatedAt: nostr.Now(),
Kind: nostr.KindNWCWalletRequest,
Tags: nostr.Tags{recipient},
@@ -427,7 +359,7 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
},
}
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
subCtx, subCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer subCancel()
sub, err := relay.Subscribe(subCtx, filters)
if err != nil {
@@ -442,7 +374,7 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
var wg sync.WaitGroup
wg.Add(1)
// watch for the invoice
// append new messages to slice
evs := make([]nostr.Event, 0, 1)
go func() {
defer wg.Done()
@@ -468,7 +400,6 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
}
}()
// publish the invoice request
if err := relay.Publish(relayCtx, nwcEv); err != nil {
l.Debug().Msgf("unable to marshal invoice: %s", err.Error())
lnurlpReturnError := &lnurlpError{Status: "ERROR", Reason: "unable to create an invoice"}
@@ -477,13 +408,11 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
w.Write(retError)
return
}
// wait for the invoice to get returned
wg.Wait()
// decrypt the invoice
response, err := nip04.Decrypt(evs[0].Content, sharedSecret)
resStruct := MakeInvoiceResponse{}
err = json.Unmarshal([]byte(response), &resStruct)
json.Unmarshal([]byte(response), &resStruct)
if err != nil {
l.Debug().Msgf("unable to create invoice: %s", err.Error())
lnurlpReturnError := &lnurlpError{Status: "ERROR", Reason: "unable to connect to relay"}
@@ -498,7 +427,6 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
Routes: []int{},
}
// return the invoice to the requester
ret, err := json.Marshal(retStruct)
if err != nil {
l.Error().Msgf("unable to marshal json for invoice: %s", err.Error())
@@ -508,10 +436,6 @@ func GetLnurlpCallback(w http.ResponseWriter, r *http.Request) {
w.Write(retError)
return
}
if nwc.Nostr != "" {
l.Debug().Msgf("starting background job for invoice")
go watchForReceipt(nwc.Nostr, secret, retStruct.Invoice)
}
l.Info().Msg("returning lnurl-p payload")
w.WriteHeader(http.StatusOK)

go.mod (14 changed lines)

@@ -1,15 +1,15 @@
module git.devvul.com/asara/well-goknown
go 1.23.5
go 1.23.3
require (
git.devvul.com/asara/gologger v0.9.0
github.com/fiatjaf/eventstore v0.16.0
github.com/fiatjaf/khatru v0.15.0
github.com/fiatjaf/eventstore v0.14.4
github.com/fiatjaf/khatru v0.14.0
github.com/gorilla/schema v1.4.1
github.com/jmoiron/sqlx v1.4.0
github.com/lib/pq v1.10.9
github.com/nbd-wtf/go-nostr v0.48.4
github.com/nbd-wtf/go-nostr v0.45.0
)
require (
@@ -18,10 +18,12 @@ require (
github.com/bep/debounce v1.2.1 // indirect
github.com/btcsuite/btcd/btcec/v2 v2.3.4 // indirect
github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect
github.com/coder/websocket v1.8.12 // indirect
github.com/decred/dcrd/crypto/blake256 v1.1.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
github.com/fasthttp/websocket v1.5.12 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gobwas/ws v1.4.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.11 // indirect
@@ -30,7 +32,7 @@
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.0 // indirect
github.com/puzpuzpuz/xsync/v3 v3.4.0 // indirect
github.com/rs/cors v1.11.1 // indirect
github.com/rs/zerolog v1.33.0 // indirect
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect

go.sum (35 changed lines)

@@ -1,5 +1,6 @@
fiatjaf.com/lib v0.2.0 h1:TgIJESbbND6GjOgGHxF5jsO6EMjuAxIzZHPo5DXYexs=
fiatjaf.com/lib v0.2.0/go.mod h1:Ycqq3+mJ9jAWu7XjbQI1cVr+OFgnHn79dQR5oTII47g=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
git.devvul.com/asara/gologger v0.9.0 h1:gijJpkPjvzI5S/dmAXgYoKJbp5uuaETAOBYWo7bJg6U=
git.devvul.com/asara/gologger v0.9.0/go.mod h1:APr1DdVYByFfPUGHqHtRMhxphQbj92/vT/t0iM40H/0=
@@ -11,10 +12,9 @@ github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurT
github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04=
github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 h1:59Kx4K6lzOW5w6nFlA0v5+lk/6sjybR934QNHSJZPTQ=
github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/crypto/blake256 v1.1.0 h1:zPMNGQCm0g4QTY27fOCorQW7EryeQ/U0x++OzVrdms8=
github.com/decred/dcrd/crypto/blake256 v1.1.0/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
@@ -22,11 +22,18 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnN
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE=
github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg=
github.com/fiatjaf/eventstore v0.16.0 h1:r26aJeOwJTCbEevU8RVqp9FlcAgzKKqUWFH//x+Y+7M=
github.com/fiatjaf/eventstore v0.16.0/go.mod h1:KAsld5BhkmSck48aF11Txu8X+OGNmoabw4TlYVWqInc=
github.com/fiatjaf/khatru v0.15.0 h1:0aLWiTrdzoKD4WmW35GWL/Jsn4dACCUw325JKZg/AmI=
github.com/fiatjaf/khatru v0.15.0/go.mod h1:GBQJXZpitDatXF9RookRXcWB5zCJclCE4ufDK3jk80g=
github.com/fiatjaf/eventstore v0.14.4 h1:bqJQit/M5E6vwbWwgrL4kTPoWCbt1Hb9H/AH4xf9uVQ=
github.com/fiatjaf/eventstore v0.14.4/go.mod h1:3Kkujc6A8KjpNvSKu1jNCcFjSgEEyCxaDJVgShHz0J8=
github.com/fiatjaf/khatru v0.14.0 h1:zpWlAA87XBpDKBPIDbAuNw/HpKXzyt5XHVDbSvUbmDo=
github.com/fiatjaf/khatru v0.14.0/go.mod h1:uxE5e8DBXPZqbHjr/gfatQas5bEJIMmsOCDcdF4LoRQ=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
@@ -51,17 +58,20 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nbd-wtf/go-nostr v0.48.4 h1:LYUWt+OSUjka/5mnZ7WFszHJ14UnSVxSp2108fRprEY=
github.com/nbd-wtf/go-nostr v0.48.4/go.mod h1:I7Ah6f7gPSzPFASZ1FmuaP0uQpycIeMEn9lKx7j+5GA=
github.com/nbd-wtf/go-nostr v0.45.0 h1:4WaMg0Yvda9gBcyRq9KtI32lPeFY8mbX0eFlfdnLrSE=
github.com/nbd-wtf/go-nostr v0.45.0/go.mod h1:m0ID2gSA2Oak/uaPnM1uN22JhDRZS4UVJG2c8jo19rg=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.5.0 h1:i+cMcpEDY1BkNm7lPDkCtE4oElsYLn+EKF8kAu2vXT4=
github.com/puzpuzpuz/xsync/v3 v3.5.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
@@ -71,6 +81,8 @@ github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -82,6 +94,7 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.58.0 h1:GGB2dWxSbEprU9j0iMJHgdKYJVDyjrOwF9RE59PbRuE=
github.com/valyala/fasthttp v1.58.0/go.mod h1:SYXvHHaFp7QZHGKSHmoMipInhrI5StHrhDTYVEjK/Kw=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
@@ -92,3 +105,5 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@@ -2,6 +2,7 @@ package nostr
import (
"context"
"time"
"git.devvul.com/asara/well-goknown/config"
"github.com/fiatjaf/eventstore/postgresql"
@@ -56,12 +57,12 @@ func NewRelay(version string) *khatru.Relay {
relay.QueryEvents = append(relay.QueryEvents, RelayDb.QueryEvents)
relay.CountEvents = append(relay.CountEvents, RelayDb.CountEvents)
relay.DeleteEvent = append(relay.DeleteEvent, RelayDb.DeleteEvent)
relay.ReplaceEvent = append(relay.ReplaceEvent, RelayDb.ReplaceEvent)
// apply policies
relay.RejectEvent = append(relay.RejectEvent,
RejectUnregisteredNpubs,
policies.ValidateKind,
policies.EventIPRateLimiter(25, time.Minute*1, 100),
)
relay.RejectFilter = append(relay.RejectFilter,
@@ -70,5 +71,9 @@ func NewRelay(version string) *khatru.Relay {
policies.NoComplexFilters,
)
relay.RejectConnection = append(relay.RejectConnection,
policies.ConnectionRateLimiter(50, time.Minute*5, 100),
)
return relay
}


@@ -1,13 +0,0 @@
Copyright (c) 2023 Anmol Sethi <hi@nhooyr.io>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


@@ -1,160 +0,0 @@
# websocket
[![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket)
[![Go Coverage](https://img.shields.io/badge/coverage-91%25-success)](https://github.com/coder/websocket/coverage.html)
websocket is a minimal and idiomatic WebSocket library for Go.
## Install
```sh
go get github.com/coder/websocket
```
> [!NOTE]
> Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket).
> We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from
> 2019 to 2024.
## Highlights
- Minimal and idiomatic API
- First class [context.Context](https://blog.golang.org/context) support
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports)
- JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- Zero alloc reads and writes
- Concurrent writes
- [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close)
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm)
## Roadmap
See the GitHub issues for smaller items; the major future enhancements are:
- [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217)
- [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340)
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)
## Examples
For a production quality example that demonstrates the complete API, see the
[echo example](./internal/examples/echo).
For a full stack example, see the [chat example](./internal/examples/chat).
### Server
```go
http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
	c, err := websocket.Accept(w, r, nil)
	if err != nil {
		// ...
	}
	defer c.CloseNow()

	ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
	defer cancel()

	var v interface{}
	err = wsjson.Read(ctx, c, &v)
	if err != nil {
		// ...
	}

	log.Printf("received: %v", v)

	c.Close(websocket.StatusNormalClosure, "")
})
```
### Client
```go
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
	// ...
}
defer c.CloseNow()

err = wsjson.Write(ctx, c, "hi")
if err != nil {
	// ...
}

c.Close(websocket.StatusNormalClosure, "")
```
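### Ping

A minimal sketch of the ping/pong API mentioned in the highlights, assuming a server at `ws://localhost:8080` as in the client example above:

```go
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
	// ...
}
defer c.CloseNow()

// CloseRead starts a goroutine that keeps reading from the connection,
// so control frames such as pongs are still processed.
readCtx := c.CloseRead(ctx)

// Ping blocks until the peer answers with a pong, which makes it usable
// as a simple latency or liveness probe.
if err := c.Ping(readCtx); err != nil {
	// ...
}

c.Close(websocket.StatusNormalClosure, "")
```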
## Comparison
### gorilla/websocket
Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
- Mature and widely used
- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
- No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection.
- Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411)
Advantages of github.com/coder/websocket:
- Minimal and idiomatic API
- Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
- Full [context.Context](https://blog.golang.org/context) support
- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
- Will enable easy HTTP/2 support in the future
- Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
- Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
- Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
#### golang.org/x/net/websocket
[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning
to github.com/coder/websocket.
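A minimal sketch of that transition path, assuming existing code that consumes a `net.Conn` and a server at `ws://localhost:8080`:

```go
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
	// ...
}
defer c.CloseNow()

// NetConn adapts the WebSocket connection to net.Conn; each Write is sent
// as a single binary message and Reads return message payloads.
conn := websocket.NetConn(ctx, c, websocket.MessageBinary)
fmt.Fprintln(conn, "hello over a binary message")
```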
#### gobwas/ws
[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
#### lesismal/nbio
[lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is
event driven for performance reasons.
However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.


@@ -1,352 +0,0 @@
//go:build !js
// +build !js
package websocket
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"path/filepath"
"strings"
"github.com/coder/websocket/internal/errd"
)
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
// reject it, close the connection when c.Subprotocol() == "".
Subprotocols []string
// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
//
// You probably want to use OriginPatterns instead.
InsecureSkipVerify bool
// OriginPatterns lists the host patterns for authorized origins.
// The request host is always authorized.
// Use this to enable cross origin WebSockets.
//
// i.e. JavaScript running on example.com wants to access a WebSocket server at chat.example.com.
// In such a case, example.com is the origin and chat.example.com is the request host.
// One would set this field to []string{"example.com"} to authorize example.com to connect.
//
// Each pattern is matched case insensitively against the request origin host
// with filepath.Match.
// See https://golang.org/pkg/path/filepath/#Match
//
// Please ensure you understand the ramifications of enabling this.
// If used incorrectly your WebSocket server will be open to CSRF attacks.
//
// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
// to bring attention to the danger of such a setting.
OriginPatterns []string
// CompressionMode controls the compression mode.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
// CompressionThreshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
}
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
var o AcceptOptions
if opts != nil {
o = *opts
}
return &o
}
// Accept accepts a WebSocket handshake from a client and upgrades the
// connection to a WebSocket.
//
// Accept will not allow cross origin requests by default.
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
//
// Accept will write a response to w on all errors.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return accept(w, r, opts)
}
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
defer errd.Wrap(&err, "failed to accept WebSocket connection")
errCode, err := verifyClientRequest(w, r)
if err != nil {
http.Error(w, err.Error(), errCode)
return nil, err
}
opts = opts.cloneWithDefaults()
if !opts.InsecureSkipVerify {
err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil {
if errors.Is(err, filepath.ErrBadPattern) {
log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden))
}
http.Error(w, err.Error(), http.StatusForbidden)
return nil, err
}
}
hj, ok := w.(http.Hijacker)
if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
return nil, err
}
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade")
key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
subproto := selectSubprotocol(r, opts.Subprotocols)
if subproto != "" {
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}
w.WriteHeader(http.StatusSwitchingProtocols)
// See https://github.com/nhooyr/websocket/issues/166
if ginWriter, ok := w.(interface {
WriteHeaderNow()
}); ok {
ginWriter.WriteHeaderNow()
}
netConn, brw, err := hj.Hijack()
if err != nil {
err = fmt.Errorf("failed to hijack connection: %w", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return nil, err
}
// https://github.com/golang/go/issues/32314
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
return newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
rwc: netConn,
client: false,
copts: copts,
flateThreshold: opts.CompressionThreshold,
br: brw.Reader,
bw: brw.Writer,
}), nil
}
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
if !r.ProtoAtLeast(1, 1) {
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
}
if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
}
if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
}
if r.Method != "GET" {
return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
}
if r.Header.Get("Sec-WebSocket-Version") != "13" {
w.Header().Set("Sec-WebSocket-Version", "13")
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
}
websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
if len(websocketSecKeys) == 0 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
}
if len(websocketSecKeys) > 1 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
}
// The RFC states to remove any leading or trailing whitespace.
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
}
return 0, nil
}
func authenticateOrigin(r *http.Request, originHosts []string) error {
origin := r.Header.Get("Origin")
if origin == "" {
return nil
}
u, err := url.Parse(origin)
if err != nil {
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
}
if strings.EqualFold(r.Host, u.Host) {
return nil
}
for _, hostPattern := range originHosts {
matched, err := match(hostPattern, u.Host)
if err != nil {
return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
}
if matched {
return nil
}
}
if u.Host == "" {
return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
}
return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
}
func match(pattern, s string) (bool, error) {
return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
}
func selectSubprotocol(r *http.Request, subprotocols []string) string {
cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
for _, sp := range subprotocols {
for _, cp := range cps {
if strings.EqualFold(sp, cp) {
return cp
}
}
}
return ""
}
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, false
}
for _, ext := range extensions {
switch ext.name {
// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate":
copts, ok := acceptDeflate(ext, mode)
if ok {
return copts, true
}
}
}
return nil, false
}
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts()
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, false
}
return copts, true
}
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
for _, t := range headerTokens(h, key) {
if strings.EqualFold(t, token) {
return true
}
}
return false
}
type websocketExtension struct {
name string
params []string
}
func websocketExtensions(h http.Header) []websocketExtension {
var exts []websocketExtension
extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
for _, extStr := range extStrs {
if extStr == "" {
continue
}
vals := strings.Split(extStr, ";")
for i := range vals {
vals[i] = strings.TrimSpace(vals[i])
}
e := websocketExtension{
name: vals[0],
params: vals[1:],
}
exts = append(exts, e)
}
return exts
}
func headerTokens(h http.Header, key string) []string {
key = textproto.CanonicalMIMEHeaderKey(key)
var tokens []string
for _, v := range h[key] {
v = strings.TrimSpace(v)
for _, t := range strings.Split(v, ",") {
t = strings.TrimSpace(t)
tokens = append(tokens, t)
}
}
return tokens
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func secWebSocketAccept(secWebSocketKey string) string {
h := sha1.New()
h.Write([]byte(secWebSocketKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}


@@ -1,348 +0,0 @@
//go:build !js
// +build !js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"time"
"github.com/coder/websocket/internal/errd"
)
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// Close performs the WebSocket close handshake with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
//
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")
if !c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
err2 := c.close()
if err == nil && err2 != nil {
err = err2
}
err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")
if !c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.close()
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
var p []byte
var err error
if ce.Code != StatusNoStatusRcvd {
p, err = ce.bytes()
if err != nil {
return err
}
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = c.writeControl(ctx, opClose, p)
// If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return nil
}
func (c *Conn) waitCloseHandshake() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.readMu.lock(ctx)
if err != nil {
return err
}
defer c.readMu.unlock()
for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
for {
h, err := c.readLoop(ctx)
if err != nil {
return err
}
for i := int64(0); i < h.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
}
}
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Code: StatusNoStatusRcvd,
}, nil
}
if len(p) < 2 {
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}
ce := CloseError{
Code: StatusCode(binary.BigEndian.Uint16(p)),
Reason: string(p[2:]),
}
if !validWireCloseCode(ce.Code) {
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
}
return ce, nil
}
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
switch code {
case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return false
}
if code >= StatusNormalClosure && code <= StatusBadGateway {
return true
}
if code >= 3000 && code <= 4999 {
return true
}
return false
}
func (ce CloseError) bytes() ([]byte, error) {
p, err := ce.bytesErr()
if err != nil {
err = fmt.Errorf("failed to marshal close frame: %w", err)
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytesErr()
}
return p, err
}
const maxCloseReason = maxControlPayload - 2
func (ce CloseError) bytesErr() ([]byte, error) {
if len(ce.Reason) > maxCloseReason {
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
}
if !validWireCloseCode(ce.Code) {
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
}
buf := make([]byte, 2+len(ce.Reason))
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
copy(buf[2:], ce.Reason)
return buf, nil
}
func (c *Conn) casClosing() bool {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if !c.closing {
c.closing = true
return true
}
return false
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}


@@ -1,233 +0,0 @@
//go:build !js
// +build !js
package websocket
import (
"compress/flate"
"io"
"sync"
)
// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int
const (
// CompressionDisabled disables the negotiation of the permessage-deflate extension.
//
// This is the default. Do not enable compression without benchmarking for your particular use case first.
CompressionDisabled CompressionMode = iota
// CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
// previous messages, i.e. compression context across messages is preserved.
//
// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
//
// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
// that are used when reading and then returned.
//
// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
//
// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
CompressionContextTakeover
// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
// a sync.Pool.
//
// This means less efficient compression as the sliding window from previous messages will not be used but the
// memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
// Especially if the connections are long lived and seldom written to.
//
// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
//
// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
CompressionNoContextTakeover
)
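// A minimal sketch of selecting a mode (for illustration; see
// AcceptOptions.CompressionMode in accept.go for where it is applied):
//
//	websocket.Accept(w, r, &websocket.AcceptOptions{
//		CompressionMode: websocket.CompressionContextTakeover,
//	})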
func (m CompressionMode) opts() *compressionOptions {
return &compressionOptions{
clientNoContextTakeover: m == CompressionNoContextTakeover,
serverNoContextTakeover: m == CompressionNoContextTakeover,
}
}
type compressionOptions struct {
clientNoContextTakeover bool
serverNoContextTakeover bool
}
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
return s
}
// These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead as
// WebSocket framing tells when the message has ended but then
// we need to add them back otherwise flate.Reader keeps
// trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"
type trimLastFourBytesWriter struct {
w io.Writer
tail []byte
}
func (tw *trimLastFourBytesWriter) reset() {
if tw != nil && tw.tail != nil {
tw.tail = tw.tail[:0]
}
}
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
if tw.tail == nil {
tw.tail = make([]byte, 0, 4)
}
extra := len(tw.tail) + len(p) - 4
if extra <= 0 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Now we need to write as many extra bytes as we can from the previous tail.
if extra > len(tw.tail) {
extra = len(tw.tail)
}
if extra > 0 {
_, err := tw.w.Write(tw.tail[:extra])
if err != nil {
return 0, err
}
// Shift remaining bytes in tail over.
n := copy(tw.tail, tw.tail[extra:])
tw.tail = tw.tail[:n]
}
// If p is less than or equal to 4 bytes,
// all of it is part of the tail.
if len(p) <= 4 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Otherwise, only the last 4 bytes are.
tw.tail = append(tw.tail, p[len(p)-4:]...)
p = p[:len(p)-4]
n, err := tw.w.Write(p)
return n + 4, err
}
var flateReaderPool sync.Pool
func getFlateReader(r io.Reader, dict []byte) io.Reader {
fr, ok := flateReaderPool.Get().(io.Reader)
if !ok {
return flate.NewReaderDict(r, dict)
}
fr.(flate.Resetter).Reset(r, dict)
return fr
}
func putFlateReader(fr io.Reader) {
flateReaderPool.Put(fr)
}
var flateWriterPool sync.Pool
func getFlateWriter(w io.Writer) *flate.Writer {
fw, ok := flateWriterPool.Get().(*flate.Writer)
if !ok {
fw, _ = flate.NewWriter(w, flate.BestSpeed)
return fw
}
fw.Reset(w)
return fw
}
func putFlateWriter(w *flate.Writer) {
flateWriterPool.Put(w)
}
type slidingWindow struct {
buf []byte
}
var swPoolMu sync.RWMutex
var swPool = map[int]*sync.Pool{}
func slidingWindowPool(n int) *sync.Pool {
swPoolMu.RLock()
p, ok := swPool[n]
swPoolMu.RUnlock()
if ok {
return p
}
p = &sync.Pool{}
swPoolMu.Lock()
swPool[n] = p
swPoolMu.Unlock()
return p
}
func (sw *slidingWindow) init(n int) {
if sw.buf != nil {
return
}
if n == 0 {
n = 32768
}
p := slidingWindowPool(n)
sw2, ok := p.Get().(*slidingWindow)
if ok {
*sw = *sw2
} else {
sw.buf = make([]byte, 0, n)
}
}
func (sw *slidingWindow) close() {
sw.buf = sw.buf[:0]
swPoolMu.Lock()
swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock()
}
func (sw *slidingWindow) write(p []byte) {
if len(p) >= cap(sw.buf) {
sw.buf = sw.buf[:cap(sw.buf)]
p = p[len(p)-cap(sw.buf):]
copy(sw.buf, p)
return
}
left := cap(sw.buf) - len(sw.buf)
if left < len(p) {
// We need to shift spaceNeeded bytes from the end to make room for p at the end.
spaceNeeded := len(p) - left
copy(sw.buf, sw.buf[spaceNeeded:])
sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
}
sw.buf = append(sw.buf, p...)
}


@@ -1,295 +0,0 @@
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
//
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
noCopy noCopy
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closed chan struct{}
closeMu sync.Mutex
closing bool
pingCounter int32
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
}
type connConfig struct {
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
}
func newConn(cfg connConfig) *Conn {
c := &Conn{
subprotocol: cfg.subprotocol,
rwc: cfg.rwc,
client: cfg.client,
copts: cfg.copts,
flateThreshold: cfg.flateThreshold,
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
}
c.readMu = newMu(c)
c.writeFrameMu = newMu(c)
c.msgReader = newMsgReader(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close()
})
go c.timeoutLoop()
return c
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}
func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background()
writeCtx := context.Background()
for {
select {
case <-c.closed:
return
case writeCtx = <-c.writeTimeout:
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.close()
return
case <-writeCtx.Done():
c.close()
return
}
}
}
func (c *Conn) flate() bool {
return c.copts != nil
}
// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as it does
// not read from the connection but instead waits for a Reader call
// to read the pong.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
p := atomic.AddInt32(&c.pingCounter, 1)
err := c.ping(ctx, strconv.Itoa(int(p)))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
return nil
}
func (c *Conn) ping(ctx context.Context, p string) error {
pong := make(chan struct{}, 1)
c.activePingsMu.Lock()
c.activePings[p] = pong
c.activePingsMu.Unlock()
defer func() {
c.activePingsMu.Lock()
delete(c.activePings, p)
c.activePingsMu.Unlock()
}()
err := c.writeControl(ctx, opPing, []byte(p))
if err != nil {
return err
}
select {
case <-c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
}
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
// over the receive on closed.
select {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return net.ErrClosed
default:
}
return nil
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
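
As the Ping documentation above notes, Ping only completes once a concurrent Reader call consumes the pong. A hedged sketch of a keepalive loop built on that contract; keepalive and its interval parameter are illustrative, not part of the library:

package example

import (
	"context"
	"time"

	"github.com/coder/websocket"
)

// keepalive pings the peer every interval until ctx is cancelled or a ping fails.
// It relies on another goroutine reading from c (via Reader or CloseRead); without
// one, the pong is never read and each Ping only returns when its context expires.
func keepalive(ctx context.Context, c *websocket.Conn, interval time.Duration) error {
	t := time.NewTicker(interval)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-t.C:
			pctx, cancel := context.WithTimeout(ctx, 10*time.Second)
			err := c.Ping(pctx)
			cancel()
			if err != nil {
				return err
			}
		}
	}
}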


@ -1,330 +0,0 @@
//go:build !js
// +build !js
package websocket
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/coder/websocket/internal/errd"
)
// DialOptions represents Dial's options.
type DialOptions struct {
// HTTPClient is used for the connection.
// Its Transport must return writable bodies for WebSocket handshakes.
// http.Transport does beginning with Go 1.12.
HTTPClient *http.Client
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header
// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Host string
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
Subprotocols []string
// CompressionMode controls the compression mode.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
// CompressionThreshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
}
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
var cancel context.CancelFunc
var o DialOptions
if opts != nil {
o = *opts
}
if o.HTTPClient == nil {
o.HTTPClient = http.DefaultClient
}
if o.HTTPClient.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout)
newClient := *o.HTTPClient
newClient.Timeout = 0
o.HTTPClient = &newClient
}
if o.HTTPHeader == nil {
o.HTTPHeader = http.Header{}
}
newClient := *o.HTTPClient
oldCheckRedirect := o.HTTPClient.CheckRedirect
newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if oldCheckRedirect != nil {
return oldCheckRedirect(req, via)
}
return nil
}
o.HTTPClient = &newClient
return ctx, cancel, &o
}
// Dial performs a WebSocket handshake on url.
//
// The response is the WebSocket handshake response from the server.
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non-nil.
// However, you can only read the first 1024 bytes of the body.
//
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// URLs with http/https schemes will work and are interpreted as ws/wss.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
return dial(ctx, u, opts, nil)
}
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
defer errd.Wrap(&err, "failed to WebSocket dial")
var cancel context.CancelFunc
ctx, cancel, opts = opts.cloneWithDefaults(ctx)
if cancel != nil {
defer cancel()
}
secWebSocketKey, err := secWebSocketKey(rand)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}
var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
}
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
if err != nil {
return nil, resp, err
}
respBody := resp.Body
resp.Body = nil
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(respBody, 1024)
timer := time.AfterFunc(time.Second*3, func() {
respBody.Close()
})
defer timer.Stop()
b, _ := io.ReadAll(r)
respBody.Close()
resp.Body = io.NopCloser(bytes.NewReader(b))
}
}()
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
rwc, ok := respBody.(io.ReadWriteCloser)
if !ok {
return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
}
return newConn(connConfig{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
rwc: rwc,
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
}
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
u, err := url.Parse(urls)
if err != nil {
return nil, fmt.Errorf("failed to parse url: %w", err)
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
case "http", "https":
default:
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
}
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create new http request: %w", err)
}
if len(opts.Host) > 0 {
req.Host = opts.Host
}
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}
resp, err := opts.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send handshake request: %w", err)
}
return resp, nil
}
func secWebSocketKey(rr io.Reader) (string, error) {
if rr == nil {
rr = rand.Reader
}
b := make([]byte, 16)
_, err := io.ReadFull(rr, b)
if err != nil {
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
}
err := verifySubprotocol(opts.Subprotocols, resp)
if err != nil {
return nil, err
}
return verifyServerExtensions(copts, resp.Header)
}
func verifySubprotocol(subprotos []string, resp *http.Response) error {
proto := resp.Header.Get("Sec-WebSocket-Protocol")
if proto == "" {
return nil
}
for _, sp2 := range subprotos {
if strings.EqualFold(sp2, proto) {
return nil
}
}
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil, nil
}
ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}
_copts := *copts
copts = &_copts
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
return copts, nil
}
var bufioReaderPool sync.Pool
func getBufioReader(r io.Reader) *bufio.Reader {
br, ok := bufioReaderPool.Get().(*bufio.Reader)
if !ok {
return bufio.NewReader(r)
}
br.Reset(r)
return br
}
func putBufioReader(br *bufio.Reader) {
bufioReaderPool.Put(br)
}
var bufioWriterPool sync.Pool
func getBufioWriter(w io.Writer) *bufio.Writer {
bw, ok := bufioWriterPool.Get().(*bufio.Writer)
if !ok {
return bufio.NewWriter(w)
}
bw.Reset(w)
return bw
}
func putBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}
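
A hedged client-side sketch of Dial with the options defined above; the URL and subprotocol are placeholders:

package main

import (
	"context"
	"log"
	"time"

	"github.com/coder/websocket"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	c, _, err := websocket.Dial(ctx, "wss://example.com/ws", &websocket.DialOptions{
		Subprotocols:    []string{"chat"}, // placeholder subprotocol
		CompressionMode: websocket.CompressionContextTakeover,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close(websocket.StatusInternalError, "dial example exited")

	log.Println("negotiated subprotocol:", c.Subprotocol())
	c.Close(websocket.StatusNormalClosure, "")
}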


@ -1,34 +0,0 @@
//go:build !js
// +build !js
// Package websocket implements the RFC 6455 WebSocket protocol.
//
// https://tools.ietf.org/html/rfc6455
//
// Use Dial to dial a WebSocket server.
//
// Use Accept to accept a WebSocket client.
//
// Conn represents the resulting WebSocket connection.
//
// The examples are the best way to understand how to correctly use the library.
//
// The wsjson subpackage contains helpers for JSON and protobuf messages.
//
// More documentation at https://github.com/coder/websocket.
//
// # Wasm
//
// The client side supports compiling to Wasm.
// It wraps the WebSocket browser API.
//
// See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket
//
// Some important caveats to be aware of:
//
// - Accept always errors out
// - Conn.Ping is a no-op
// - Conn.CloseNow is Close(StatusGoingAway, "")
// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-ops
// - *http.Response from Dial is &http.Response{} with a 101 status code on success
package websocket // import "github.com/coder/websocket"
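
The package doc above points at Accept for the server side; Accept and AcceptOptions live in accept.go, which is not part of this listing, so the sketch below assumes their usual signatures. A minimal echo handler:

package main

import (
	"context"
	"log"
	"net/http"
	"time"

	"github.com/coder/websocket"
)

func main() {
	http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
		c, err := websocket.Accept(w, r, nil) // nil AcceptOptions: library defaults
		if err != nil {
			return
		}
		defer c.Close(websocket.StatusInternalError, "handler exited")

		ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
		defer cancel()

		typ, msg, err := c.Read(ctx)
		if err != nil {
			return
		}
		if err := c.Write(ctx, typ, msg); err != nil {
			return
		}
		c.Close(websocket.StatusNormalClosure, "")
	})
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}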


@ -1,173 +0,0 @@
//go:build !js
package websocket
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"math"
"github.com/coder/websocket/internal/errd"
)
// opcode represents a WebSocket opcode.
type opcode int
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
	// 11-15 are reserved for further control frames.
)
// header represents a WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
type header struct {
fin bool
rsv1 bool
rsv2 bool
rsv3 bool
opcode opcode
payloadLength int64
masked bool
maskKey uint32
}
// readFrameHeader reads a header from the reader.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
defer errd.Wrap(&err, "failed to read frame header")
b, err := r.ReadByte()
if err != nil {
return header{}, err
}
h.fin = b&(1<<7) != 0
h.rsv1 = b&(1<<6) != 0
h.rsv2 = b&(1<<5) != 0
h.rsv3 = b&(1<<4) != 0
h.opcode = opcode(b & 0xf)
b, err = r.ReadByte()
if err != nil {
return header{}, err
}
h.masked = b&(1<<7) != 0
payloadLength := b &^ (1 << 7)
switch {
case payloadLength < 126:
h.payloadLength = int64(payloadLength)
case payloadLength == 126:
_, err = io.ReadFull(r, readBuf[:2])
h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
case payloadLength == 127:
_, err = io.ReadFull(r, readBuf)
h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
}
if err != nil {
return header{}, err
}
if h.payloadLength < 0 {
return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
}
if h.masked {
_, err = io.ReadFull(r, readBuf[:4])
if err != nil {
return header{}, err
}
h.maskKey = binary.LittleEndian.Uint32(readBuf)
}
return h, nil
}
// maxControlPayload is the maximum length of a control frame payload.
// See https://tools.ietf.org/html/rfc6455#section-5.5.
const maxControlPayload = 125
// writeFrameHeader writes the bytes of the header to w.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
defer errd.Wrap(&err, "failed to write frame header")
var b byte
if h.fin {
b |= 1 << 7
}
if h.rsv1 {
b |= 1 << 6
}
if h.rsv2 {
b |= 1 << 5
}
if h.rsv3 {
b |= 1 << 4
}
b |= byte(h.opcode)
err = w.WriteByte(b)
if err != nil {
return err
}
lengthByte := byte(0)
if h.masked {
lengthByte |= 1 << 7
}
switch {
case h.payloadLength > math.MaxUint16:
lengthByte |= 127
case h.payloadLength > 125:
lengthByte |= 126
case h.payloadLength >= 0:
lengthByte |= byte(h.payloadLength)
}
err = w.WriteByte(lengthByte)
if err != nil {
return err
}
switch {
case h.payloadLength > math.MaxUint16:
binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
_, err = w.Write(buf)
case h.payloadLength > 125:
binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
_, err = w.Write(buf[:2])
}
if err != nil {
return err
}
if h.masked {
binary.LittleEndian.PutUint32(buf, h.maskKey)
_, err = w.Write(buf[:4])
if err != nil {
return err
}
}
return nil
}
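
To make the bit layout above concrete, a standalone sketch that decodes the two fixed header bytes the same way readFrameHeader does; the byte values are illustrative:

package main

import "fmt"

func main() {
	// 0x81 0x85: FIN=1, opcode=1 (text), MASK=1, payload length 5,
	// i.e. the start of a masked 5-byte text frame such as "hello" sent by a client.
	b0, b1 := byte(0x81), byte(0x85)

	fin := b0&(1<<7) != 0
	opcode := b0 & 0xf
	masked := b1&(1<<7) != 0
	payloadLen := b1 &^ (1 << 7)

	fmt.Println(fin, opcode, masked, payloadLen) // true 1 true 5
}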


@ -1,24 +0,0 @@
package bpool
import (
"bytes"
"sync"
)
var bpool sync.Pool
// Get returns a buffer from the pool or creates a new one if
// the pool is empty.
func Get() *bytes.Buffer {
b := bpool.Get()
if b == nil {
return &bytes.Buffer{}
}
return b.(*bytes.Buffer)
}
// Put returns a buffer into the pool.
func Put(b *bytes.Buffer) {
b.Reset()
bpool.Put(b)
}
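
A minimal sketch of the borrow/return pattern this pool supports; the helper is hypothetical and, like any internal package, bpool is only importable from within this module:

package example

import "github.com/coder/websocket/internal/bpool"

// concatCopy joins two byte slices using a pooled scratch buffer and copies
// the result out before Put resets and recycles the buffer.
func concatCopy(a, b []byte) []byte {
	buf := bpool.Get()
	defer bpool.Put(buf)
	buf.Write(a)
	buf.Write(b)
	return append([]byte(nil), buf.Bytes()...)
}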


@ -1,14 +0,0 @@
package errd
import (
"fmt"
)
// Wrap wraps err with fmt.Errorf if err is non-nil.
// Intended for use with defer and a named error return.
// Inspired by https://github.com/golang/go/issues/32676.
func Wrap(err *error, f string, v ...interface{}) {
if *err != nil {
*err = fmt.Errorf(f+": %w", append(v, *err)...)
}
}
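
Wrap is meant to be deferred with a named error return, as the rest of this diff does throughout; a hedged sketch of the pattern with a hypothetical helper:

package example

import (
	"os"

	"github.com/coder/websocket/internal/errd"
)

// readConfig returns any error as "failed to read config: <original error>",
// without repeating the prefix at every return site.
func readConfig(path string) (_ []byte, err error) {
	defer errd.Wrap(&err, "failed to read config")
	return os.ReadFile(path)
}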


@ -1,15 +0,0 @@
package util
// WriterFunc is used to implement one off io.Writers.
type WriterFunc func(p []byte) (int, error)
func (f WriterFunc) Write(p []byte) (int, error) {
return f(p)
}
// ReaderFunc is used to implement one off io.Readers.
type ReaderFunc func(p []byte) (int, error)
func (f ReaderFunc) Read(p []byte) (int, error) {
return f(p)
}


@ -1,169 +0,0 @@
//go:build js
// +build js
// Package wsjs implements typed access to the browser javascript WebSocket API.
//
// https://developer.mozilla.org/en-US/docs/Web/API/WebSocket
package wsjs
import (
"syscall/js"
)
func handleJSError(err *error, onErr func()) {
r := recover()
if jsErr, ok := r.(js.Error); ok {
*err = jsErr
if onErr != nil {
onErr()
}
return
}
if r != nil {
panic(r)
}
}
// New is a wrapper around the javascript WebSocket constructor.
func New(url string, protocols []string) (c WebSocket, err error) {
defer handleJSError(&err, func() {
c = WebSocket{}
})
jsProtocols := make([]interface{}, len(protocols))
for i, p := range protocols {
jsProtocols[i] = p
}
c = WebSocket{
v: js.Global().Get("WebSocket").New(url, jsProtocols),
}
c.setBinaryType("arraybuffer")
return c, nil
}
// WebSocket is a wrapper around a javascript WebSocket object.
type WebSocket struct {
v js.Value
}
func (c WebSocket) setBinaryType(typ string) {
c.v.Set("binaryType", string(typ))
}
func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() {
f := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
fn(args[0])
return nil
})
c.v.Call("addEventListener", eventType, f)
return func() {
c.v.Call("removeEventListener", eventType, f)
f.Release()
}
}
// CloseEvent is the type passed to a WebSocket close handler.
type CloseEvent struct {
Code uint16
Reason string
WasClean bool
}
// OnClose registers a function to be called when the WebSocket is closed.
func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) {
return c.addEventListener("close", func(e js.Value) {
ce := CloseEvent{
Code: uint16(e.Get("code").Int()),
Reason: e.Get("reason").String(),
WasClean: e.Get("wasClean").Bool(),
}
fn(ce)
})
}
// OnError registers a function to be called when there is an error
// with the WebSocket.
func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) {
return c.addEventListener("error", fn)
}
// MessageEvent is the type passed to a message handler.
type MessageEvent struct {
// string or []byte.
Data interface{}
// There are more fields to the interface but we don't use them.
// See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent
}
// OnMessage registers a function to be called when the WebSocket receives a message.
func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) {
return c.addEventListener("message", func(e js.Value) {
var data interface{}
arrayBuffer := e.Get("data")
if arrayBuffer.Type() == js.TypeString {
data = arrayBuffer.String()
} else {
data = extractArrayBuffer(arrayBuffer)
}
me := MessageEvent{
Data: data,
}
fn(me)
})
}
// Subprotocol returns the WebSocket subprotocol in use.
func (c WebSocket) Subprotocol() string {
return c.v.Get("protocol").String()
}
// OnOpen registers a function to be called when the WebSocket is opened.
func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) {
return c.addEventListener("open", fn)
}
// Close closes the WebSocket with the given code and reason.
func (c WebSocket) Close(code int, reason string) (err error) {
defer handleJSError(&err, nil)
c.v.Call("close", code, reason)
return err
}
// SendText sends the given string as a text message
// on the WebSocket.
func (c WebSocket) SendText(v string) (err error) {
defer handleJSError(&err, nil)
c.v.Call("send", v)
return err
}
// SendBytes sends the given message as a binary message
// on the WebSocket.
func (c WebSocket) SendBytes(v []byte) (err error) {
defer handleJSError(&err, nil)
c.v.Call("send", uint8Array(v))
return err
}
func extractArrayBuffer(arrayBuffer js.Value) []byte {
uint8Array := js.Global().Get("Uint8Array").New(arrayBuffer)
dst := make([]byte, uint8Array.Length())
js.CopyBytesToGo(dst, uint8Array)
return dst
}
func uint8Array(src []byte) js.Value {
uint8Array := js.Global().Get("Uint8Array").New(len(src))
js.CopyBytesToJS(uint8Array, src)
return uint8Array
}


@ -1,26 +0,0 @@
package xsync
import (
"fmt"
"runtime/debug"
)
// Go allows running a function in another goroutine
// and waiting for its error.
func Go(fn func() error) <-chan error {
errs := make(chan error, 1)
go func() {
defer func() {
r := recover()
if r != nil {
select {
case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()):
default:
}
}
}()
errs <- fn()
}()
return errs
}
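
A hedged sketch of how the returned channel is consumed; waitForWorker is illustrative, and a panic inside work comes back as an error rather than crashing the caller:

package example

import (
	"context"

	"github.com/coder/websocket/internal/xsync"
)

// waitForWorker runs work on another goroutine via xsync.Go and waits for
// either its result or cancellation of ctx.
func waitForWorker(ctx context.Context, work func() error) error {
	errs := xsync.Go(work)
	select {
	case err := <-errs:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}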


@ -1,23 +0,0 @@
package xsync
import (
"sync/atomic"
)
// Int64 represents an atomic int64.
type Int64 struct {
// We do not use atomic.Load/StoreInt64 since it does not
// work on 32 bit computers but we need 64 bit integers.
i atomic.Value
}
// Load loads the int64.
func (v *Int64) Load() int64 {
i, _ := v.i.Load().(int64)
return i
}
// Store stores the int64.
func (v *Int64) Store(i int64) {
v.i.Store(i)
}


@ -1,12 +0,0 @@
#!/bin/sh
set -eu
cd -- "$(dirname "$0")"
echo "=== fmt.sh"
./ci/fmt.sh
echo "=== lint.sh"
./ci/lint.sh
echo "=== test.sh"
./ci/test.sh "$@"
echo "=== bench.sh"
./ci/bench.sh


@ -1,128 +0,0 @@
package websocket
import (
"encoding/binary"
"math/bits"
)
// maskGo applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func maskGo(b []byte, key uint32) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}
// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}
// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
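
Masking is a plain XOR keyed off a rotating 32-bit value, so applying it twice with the same key restores the input. A standalone sketch of that property using only the byte-at-a-time tail of the algorithm above:

package main

import (
	"fmt"
	"math/bits"
)

// xorMask is the simple per-byte form of the loop at the end of maskGo:
// XOR each byte with the low key byte, then rotate the key.
func xorMask(b []byte, key uint32) uint32 {
	for i := range b {
		b[i] ^= byte(key)
		key = bits.RotateLeft32(key, -8)
	}
	return key
}

func main() {
	const key = 0xefbeadde // little-endian view of the wire bytes de ad be ef
	msg := []byte("hello websocket")

	next := xorMask(msg, key) // mask; next would mask a continuation frame
	fmt.Printf("masked: %x (next key %08x)\n", msg, next)

	xorMask(msg, key) // XOR again with the original key to unmask
	fmt.Printf("unmasked: %s\n", msg)
}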


@ -1,127 +0,0 @@
#include "textflag.h"
// func maskAsm(b *byte, len int, key uint32)
TEXT ·maskAsm(SB), NOSPLIT, $0-28
// AX = b
// CX = len (left length)
// SI = key (uint32)
// DI = uint64(SI) | uint64(SI)<<32
MOVQ b+0(FP), AX
MOVQ len+8(FP), CX
MOVL key+16(FP), SI
// calculate the DI
// DI = SI<<32 | SI
MOVL SI, DI
MOVQ DI, DX
SHLQ $32, DI
ORQ DX, DI
CMPQ CX, $15
JLE less_than_16
CMPQ CX, $63
JLE less_than_64
CMPQ CX, $128
JLE sse
TESTQ $31, AX
JNZ unaligned
unaligned_loop_1byte:
XORB SI, (AX)
INCQ AX
DECQ CX
ROLL $24, SI
TESTQ $7, AX
JNZ unaligned_loop_1byte
// calculate DI again since SI was modified
// DI = SI<<32 | SI
MOVL SI, DI
MOVQ DI, DX
SHLQ $32, DI
ORQ DX, DI
TESTQ $31, AX
JZ sse
unaligned:
TESTQ $7, AX // AND $7 & len, if not zero jump to loop_1b.
JNZ unaligned_loop_1byte
unaligned_loop:
// we don't need to check the CX since we know it's above 128
XORQ DI, (AX)
ADDQ $8, AX
SUBQ $8, CX
TESTQ $31, AX
JNZ unaligned_loop
JMP sse
sse:
CMPQ CX, $0x40
JL less_than_64
MOVQ DI, X0
PUNPCKLQDQ X0, X0
sse_loop:
MOVOU 0*16(AX), X1
MOVOU 1*16(AX), X2
MOVOU 2*16(AX), X3
MOVOU 3*16(AX), X4
PXOR X0, X1
PXOR X0, X2
PXOR X0, X3
PXOR X0, X4
MOVOU X1, 0*16(AX)
MOVOU X2, 1*16(AX)
MOVOU X3, 2*16(AX)
MOVOU X4, 3*16(AX)
ADDQ $0x40, AX
SUBQ $0x40, CX
CMPQ CX, $0x40
JAE sse_loop
less_than_64:
TESTQ $32, CX
JZ less_than_32
XORQ DI, (AX)
XORQ DI, 8(AX)
XORQ DI, 16(AX)
XORQ DI, 24(AX)
ADDQ $32, AX
less_than_32:
TESTQ $16, CX
JZ less_than_16
XORQ DI, (AX)
XORQ DI, 8(AX)
ADDQ $16, AX
less_than_16:
TESTQ $8, CX
JZ less_than_8
XORQ DI, (AX)
ADDQ $8, AX
less_than_8:
TESTQ $4, CX
JZ less_than_4
XORL SI, (AX)
ADDQ $4, AX
less_than_4:
TESTQ $2, CX
JZ less_than_2
XORW SI, (AX)
ROLL $16, SI
ADDQ $2, AX
less_than_2:
TESTQ $1, CX
JZ done
XORB SI, (AX)
ROLL $24, SI
done:
MOVL SI, ret+24(FP)
RET


@ -1,72 +0,0 @@
#include "textflag.h"
// func maskAsm(b *byte, len int, key uint32)
TEXT ·maskAsm(SB), NOSPLIT, $0-28
// R0 = b
// R1 = len
// R3 = key (uint32)
// R2 = uint64(key)<<32 | uint64(key)
MOVD b_ptr+0(FP), R0
MOVD b_len+8(FP), R1
MOVWU key+16(FP), R3
MOVD R3, R2
ORR R2<<32, R2, R2
VDUP R2, V0.D2
CMP $64, R1
BLT less_than_64
loop_64:
VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16]
VEOR V1.B16, V0.B16, V1.B16
VEOR V2.B16, V0.B16, V2.B16
VEOR V3.B16, V0.B16, V3.B16
VEOR V4.B16, V0.B16, V4.B16
VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0)
SUBS $64, R1
CMP $64, R1
BGE loop_64
less_than_64:
CBZ R1, end
TBZ $5, R1, less_than_32
VLD1 (R0), [V1.B16, V2.B16]
VEOR V1.B16, V0.B16, V1.B16
VEOR V2.B16, V0.B16, V2.B16
VST1.P [V1.B16, V2.B16], 32(R0)
less_than_32:
TBZ $4, R1, less_than_16
LDP (R0), (R11, R12)
EOR R11, R2, R11
EOR R12, R2, R12
STP.P (R11, R12), 16(R0)
less_than_16:
TBZ $3, R1, less_than_8
MOVD (R0), R11
EOR R2, R11, R11
MOVD.P R11, 8(R0)
less_than_8:
TBZ $2, R1, less_than_4
MOVWU (R0), R11
EORW R2, R11, R11
MOVWU.P R11, 4(R0)
less_than_4:
TBZ $1, R1, less_than_2
MOVHU (R0), R11
EORW R3, R11, R11
MOVHU.P R11, 2(R0)
RORW $16, R3
less_than_2:
TBZ $0, R1, end
MOVBU (R0), R11
EORW R3, R11, R11
MOVBU.P R11, 1(R0)
RORW $8, R3
end:
MOVWU R3, ret+24(FP)
RET


@ -1,26 +0,0 @@
//go:build amd64 || arm64
package websocket
func mask(b []byte, key uint32) uint32 {
// TODO: Will enable in v1.9.0.
return maskGo(b, key)
/*
if len(b) > 0 {
return maskAsm(&b[0], len(b), key)
}
return key
*/
}
// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this
// function are perfect. There are almost certainly missing optimizations or
// opportunities for simplification. I'm confident there are no bugs though.
// For example, the arm64 implementation doesn't align memory like the amd64.
// Or the amd64 implementation could use AVX512 instead of just AVX2.
// The AVX2 code I had to disable anyway as it wasn't performing as expected.
// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049
//
//go:noescape
//lint:ignore U1000 disabled till v1.9.0
func maskAsm(b *byte, len int, key uint32) uint32


@ -1,7 +0,0 @@
//go:build !amd64 && !arm64 && !js
package websocket
func mask(b []byte, key uint32) uint32 {
return maskGo(b, key)
}


@ -1,237 +0,0 @@
package websocket
import (
"context"
"fmt"
"io"
"math"
"net"
"sync/atomic"
"time"
)
// NetConn converts a *websocket.Conn into a net.Conn.
//
// It's for tunneling arbitrary protocols over WebSockets.
// Few users of the library will need this, but it's tricky to implement
// correctly, so it is provided in the library.
// See https://github.com/nhooyr/websocket/issues/100.
//
// Every Write to the net.Conn will correspond to a message write of
// the given type on *websocket.Conn.
//
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
// all reads and writes on the net.Conn will be cancelled.
//
// If a message is read that is not of the correct type, the connection
// will be closed with StatusUnsupportedData and an error will be returned.
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit and there is an active read or write goroutine, the
// connection will be closed. This is different from most net.Conn implementations
// where only the reading/writing goroutines are interrupted but the connection
// is kept alive.
//
// The Addr methods will return the real addresses for connections obtained
// from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
// will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
// String(). This is because websocket.Dial only exposes an io.ReadWriteCloser instead of the
// full net.Conn to us.
//
// When running as WASM, the Addr methods will always return the mock address described above.
//
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
//
// Furthermore, the ReadLimit is set to -1 to disable it.
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
c.SetReadLimit(-1)
nc := &netConn{
c: c,
msgType: msgType,
readMu: newMu(c),
writeMu: newMu(c),
}
nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
nc.readCtx, nc.readCancel = context.WithCancel(ctx)
nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.writeMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active write goroutine and so we should cancel the context.
nc.writeCancel()
return
}
defer nc.writeMu.unlock()
// Prevents future writes from writing until the deadline is reset.
atomic.StoreInt64(&nc.writeExpired, 1)
})
if !nc.writeTimer.Stop() {
<-nc.writeTimer.C
}
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.readMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active read goroutine and so we should cancel the context.
nc.readCancel()
return
}
defer nc.readMu.unlock()
// Prevents future reads from reading until the deadline is reset.
atomic.StoreInt64(&nc.readExpired, 1)
})
if !nc.readTimer.Stop() {
<-nc.readTimer.C
}
return nc
}
type netConn struct {
// These must be first to be aligned on 32 bit platforms.
// https://github.com/nhooyr/websocket/pull/438
readExpired int64
writeExpired int64
c *Conn
msgType MessageType
writeTimer *time.Timer
writeMu *mu
writeCtx context.Context
writeCancel context.CancelFunc
readTimer *time.Timer
readMu *mu
readCtx context.Context
readCancel context.CancelFunc
readEOFed bool
reader io.Reader
}
var _ net.Conn = &netConn{}
func (nc *netConn) Close() error {
nc.writeTimer.Stop()
nc.writeCancel()
nc.readTimer.Stop()
nc.readCancel()
return nc.c.Close(StatusNormalClosure, "")
}
func (nc *netConn) Write(p []byte) (int, error) {
nc.writeMu.forceLock()
defer nc.writeMu.unlock()
if atomic.LoadInt64(&nc.writeExpired) == 1 {
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
}
err := nc.c.Write(nc.writeCtx, nc.msgType, p)
if err != nil {
return 0, err
}
return len(p), nil
}
func (nc *netConn) Read(p []byte) (int, error) {
nc.readMu.forceLock()
defer nc.readMu.unlock()
for {
n, err := nc.read(p)
if err != nil {
return n, err
}
if n == 0 {
continue
}
return n, nil
}
}
func (nc *netConn) read(p []byte) (int, error) {
if atomic.LoadInt64(&nc.readExpired) == 1 {
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
}
if nc.readEOFed {
return 0, io.EOF
}
if nc.reader == nil {
typ, r, err := nc.c.Reader(nc.readCtx)
if err != nil {
switch CloseStatus(err) {
case StatusNormalClosure, StatusGoingAway:
nc.readEOFed = true
return 0, io.EOF
}
return 0, err
}
if typ != nc.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
nc.c.Close(StatusUnsupportedData, err.Error())
return 0, err
}
nc.reader = r
}
n, err := nc.reader.Read(p)
if err == io.EOF {
nc.reader = nil
err = nil
}
return n, err
}
type websocketAddr struct {
}
func (a websocketAddr) Network() string {
return "websocket"
}
func (a websocketAddr) String() string {
return "websocket/unknown-addr"
}
func (nc *netConn) SetDeadline(t time.Time) error {
nc.SetWriteDeadline(t)
nc.SetReadDeadline(t)
return nil
}
func (nc *netConn) SetWriteDeadline(t time.Time) error {
atomic.StoreInt64(&nc.writeExpired, 0)
if t.IsZero() {
nc.writeTimer.Stop()
} else {
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.writeTimer.Reset(dur)
}
return nil
}
func (nc *netConn) SetReadDeadline(t time.Time) error {
atomic.StoreInt64(&nc.readExpired, 0)
if t.IsZero() {
nc.readTimer.Stop()
} else {
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.readTimer.Reset(dur)
}
return nil
}
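
A hedged sketch of the tunneling use case described in the NetConn doc; lineEcho is illustrative and c is assumed to be an established *websocket.Conn:

package example

import (
	"bufio"
	"context"
	"fmt"
	"time"

	"github.com/coder/websocket"
)

// lineEcho treats the WebSocket as a plain net.Conn carrying a text protocol:
// each Write becomes one text message, and it echoes the first line it reads.
func lineEcho(ctx context.Context, c *websocket.Conn) error {
	nc := websocket.NetConn(ctx, c, websocket.MessageText)
	defer nc.Close() // closes the WebSocket with StatusNormalClosure

	nc.SetDeadline(time.Now().Add(10 * time.Second))

	line, err := bufio.NewReader(nc).ReadString('\n')
	if err != nil {
		return err
	}
	_, err = fmt.Fprint(nc, line)
	return err
}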


@ -1,11 +0,0 @@
package websocket
import "net"
func (nc *netConn) RemoteAddr() net.Addr {
return websocketAddr{}
}
func (nc *netConn) LocalAddr() net.Addr {
return websocketAddr{}
}


@ -1,20 +0,0 @@
//go:build !js
// +build !js
package websocket
import "net"
func (nc *netConn) RemoteAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.RemoteAddr()
}
return websocketAddr{}
}
func (nc *netConn) LocalAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.LocalAddr()
}
return websocketAddr{}
}


@ -1,506 +0,0 @@
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
"github.com/coder/websocket/internal/xsync"
)
// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and the Read itself,
// use time.AfterFunc to cancel the context passed in.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
return c.reader(ctx)
}
// Read is a convenience method around Reader to read a single message
// from the connection.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
typ, r, err := c.Reader(ctx)
if err != nil {
return 0, nil, err
}
b, err := io.ReadAll(r)
return typ, b, err
}
// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to. This means c.Ping and c.Close will still work as expected.
//
// This function is idempotent.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadDone = make(chan struct{})
c.closeReadMu.Unlock()
go func() {
defer close(c.closeReadDone)
defer cancel()
defer c.close()
_, _, err := c.Reader(ctx)
if err == nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
//
// Set to -1 to disable.
func (c *Conn) SetReadLimit(n int64) {
if n >= 0 {
// We read one more byte than the limit in case
// there is a fin frame that needs to be read.
n++
}
c.msgReader.limitReader.limit.Store(n)
}
const defaultReadLimit = 32768
func newMsgReader(c *Conn) *msgReader {
mr := &msgReader{
c: c,
fin: true,
}
mr.readFunc = mr.read
mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
return mr
}
func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
if mr.dict == nil {
mr.dict = &slidingWindow{}
}
mr.dict.init(32768)
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
}
if mr.flateContextTakeover() {
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
} else {
mr.flateReader = getFlateReader(mr.flateBufio, nil)
}
mr.limitReader.r = mr.flateReader
mr.flateTail.Reset(deflateMessageTail)
}
func (mr *msgReader) putFlateReader() {
if mr.flateReader != nil {
putFlateReader(mr.flateReader)
mr.flateReader = nil
}
}
func (mr *msgReader) close() {
mr.c.readMu.forceLock()
mr.putFlateReader()
if mr.dict != nil {
mr.dict.close()
mr.dict = nil
}
if mr.flateBufio != nil {
putBufioReader(mr.flateBufio)
}
if mr.c.client {
putBufioReader(mr.c.br)
mr.c.br = nil
}
}
func (mr *msgReader) flateContextTakeover() bool {
if mr.c.client {
return !mr.c.copts.serverNoContextTakeover
}
return !mr.c.copts.clientNoContextTakeover
}
func (c *Conn) readRSV1Illegal(h header) bool {
// If compression is disabled, rsv1 is illegal.
if !c.flate() {
return true
}
// rsv1 is only allowed on data frames beginning messages.
if h.opcode != opText && h.opcode != opBinary {
return true
}
return false
}
func (c *Conn) readLoop(ctx context.Context) (header, error) {
for {
h, err := c.readFrameHeader(ctx)
if err != nil {
return header{}, err
}
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
c.writeError(StatusProtocolError, err)
return header{}, err
}
if !c.client && !h.masked {
return header{}, errors.New("received unmasked frame from client")
}
switch h.opcode {
case opClose, opPing, opPong:
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
case opContinuation, opText, opBinary:
return h, nil
default:
err := fmt.Errorf("received unknown opcode %v", h.opcode)
c.writeError(StatusProtocolError, err)
return header{}, err
}
}
}
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
select {
case <-c.closed:
return header{}, net.ErrClosed
case c.readTimeout <- ctx:
}
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
select {
case <-c.closed:
return header{}, net.ErrClosed
case <-ctx.Done():
return header{}, ctx.Err()
default:
return header{}, err
}
}
select {
case <-c.closed:
return header{}, net.ErrClosed
case c.readTimeout <- context.Background():
}
return h, nil
}
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
select {
case <-c.closed:
return 0, net.ErrClosed
case c.readTimeout <- ctx:
}
n, err := io.ReadFull(c.br, p)
if err != nil {
select {
case <-c.closed:
return n, net.ErrClosed
case <-ctx.Done():
return n, ctx.Err()
default:
return n, fmt.Errorf("failed to read frame payload: %w", err)
}
}
select {
case <-c.closed:
return n, net.ErrClosed
case c.readTimeout <- context.Background():
}
return n, err
}
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
c.writeError(StatusProtocolError, err)
return err
}
if !h.fin {
err := errors.New("received fragmented control frame")
c.writeError(StatusProtocolError, err)
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := c.readControlBuf[:h.payloadLength]
_, err = c.readFramePayload(ctx, b)
if err != nil {
return err
}
if h.masked {
mask(b, h.maskKey)
}
switch h.opcode {
case opPing:
return c.writeControl(ctx, opPong, b)
case opPong:
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
if ok {
select {
case pong <- struct{}{}:
default:
}
}
return nil
}
// opClose
ce, err := parseClosePayload(b)
if err != nil {
err = fmt.Errorf("received invalid close payload: %w", err)
c.writeError(StatusProtocolError, err)
return err
}
err = fmt.Errorf("received close frame: %w", ce)
c.writeClose(ce.Code, ce.Reason)
c.readMu.unlock()
c.close()
return err
}
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
defer errd.Wrap(&err, "failed to get reader")
err = c.readMu.lock(ctx)
if err != nil {
return 0, nil, err
}
defer c.readMu.unlock()
if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
}
h, err := c.readLoop(ctx)
if err != nil {
return 0, nil, err
}
if h.opcode == opContinuation {
err := errors.New("received continuation frame without text or binary frame")
c.writeError(StatusProtocolError, err)
return 0, nil, err
}
c.msgReader.reset(ctx, h)
return MessageType(h.opcode), c.msgReader, nil
}
type msgReader struct {
c *Conn
ctx context.Context
flate bool
flateReader io.Reader
flateBufio *bufio.Reader
flateTail strings.Reader
limitReader *limitReader
dict *slidingWindow
fin bool
payloadLength int64
maskKey uint32
// util.ReaderFunc(mr.Read) to avoid continuous allocations.
readFunc util.ReaderFunc
}
func (mr *msgReader) reset(ctx context.Context, h header) {
mr.ctx = ctx
mr.flate = h.rsv1
mr.limitReader.reset(mr.readFunc)
if mr.flate {
mr.resetFlate()
}
mr.setFrame(h)
}
func (mr *msgReader) setFrame(h header) {
mr.fin = h.fin
mr.payloadLength = h.payloadLength
mr.maskKey = h.maskKey
}
func (mr *msgReader) Read(p []byte) (n int, err error) {
err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()
n, err = mr.limitReader.Read(p)
if mr.flate && mr.flateContextTakeover() {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
return n, fmt.Errorf("failed to read: %w", err)
}
return n, nil
}
func (mr *msgReader) read(p []byte) (int, error) {
for {
if mr.payloadLength == 0 {
if mr.fin {
if mr.flate {
return mr.flateTail.Read(p)
}
return 0, io.EOF
}
h, err := mr.c.readLoop(mr.ctx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := errors.New("received new data message without finishing the previous message")
mr.c.writeError(StatusProtocolError, err)
return 0, err
}
mr.setFrame(h)
continue
}
if int64(len(p)) > mr.payloadLength {
p = p[:mr.payloadLength]
}
n, err := mr.c.readFramePayload(mr.ctx, p)
if err != nil {
return n, err
}
mr.payloadLength -= int64(n)
if !mr.c.client {
mr.maskKey = mask(p, mr.maskKey)
}
return n, nil
}
}
type limitReader struct {
c *Conn
r io.Reader
limit xsync.Int64
n int64
}
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
lr := &limitReader{
c: c,
}
lr.limit.Store(limit)
lr.reset(r)
return lr
}
func (lr *limitReader) reset(r io.Reader) {
lr.n = lr.limit.Load()
lr.r = r
}
func (lr *limitReader) Read(p []byte) (int, error) {
if lr.n < 0 {
return lr.r.Read(p)
}
if lr.n == 0 {
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
lr.c.writeError(StatusMessageTooBig, err)
return 0, err
}
if int64(len(p)) > lr.n {
p = p[:lr.n]
}
n, err := lr.r.Read(p)
lr.n -= int64(n)
if lr.n < 0 {
lr.n = 0
}
return n, err
}
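
A hedged sketch of the read side documented above: raise the read limit, then stream one message with Reader instead of buffering it with Read; dumpOneMessage is illustrative:

package example

import (
	"context"
	"io"
	"os"

	"github.com/coder/websocket"
)

// dumpOneMessage streams a single message to stdout without holding it all
// in memory, which is the point of Reader (Read would buffer the whole thing).
func dumpOneMessage(ctx context.Context, c *websocket.Conn) error {
	c.SetReadLimit(1 << 20) // allow up to 1 MiB instead of the 32768-byte default

	_, r, err := c.Reader(ctx)
	if err != nil {
		return err
	}
	// Reading to EOF matters: the doc above warns the connection hangs otherwise.
	_, err = io.Copy(os.Stdout, r)
	return err
}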


@ -1,91 +0,0 @@
// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT.
package websocket
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[opContinuation-0]
_ = x[opText-1]
_ = x[opBinary-2]
_ = x[opClose-8]
_ = x[opPing-9]
_ = x[opPong-10]
}
const (
_opcode_name_0 = "opContinuationopTextopBinary"
_opcode_name_1 = "opCloseopPingopPong"
)
var (
_opcode_index_0 = [...]uint8{0, 14, 20, 28}
_opcode_index_1 = [...]uint8{0, 7, 13, 19}
)
func (i opcode) String() string {
switch {
case 0 <= i && i <= 2:
return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]]
case 8 <= i && i <= 10:
i -= 8
return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]]
default:
return "opcode(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[MessageText-1]
_ = x[MessageBinary-2]
}
const _MessageType_name = "MessageTextMessageBinary"
var _MessageType_index = [...]uint8{0, 11, 24}
func (i MessageType) String() string {
i -= 1
if i < 0 || i >= MessageType(len(_MessageType_index)-1) {
return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")"
}
return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]]
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[StatusNormalClosure-1000]
_ = x[StatusGoingAway-1001]
_ = x[StatusProtocolError-1002]
_ = x[StatusUnsupportedData-1003]
_ = x[statusReserved-1004]
_ = x[StatusNoStatusRcvd-1005]
_ = x[StatusAbnormalClosure-1006]
_ = x[StatusInvalidFramePayloadData-1007]
_ = x[StatusPolicyViolation-1008]
_ = x[StatusMessageTooBig-1009]
_ = x[StatusMandatoryExtension-1010]
_ = x[StatusInternalError-1011]
_ = x[StatusServiceRestart-1012]
_ = x[StatusTryAgainLater-1013]
_ = x[StatusBadGateway-1014]
_ = x[StatusTLSHandshake-1015]
}
const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake"
var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312}
func (i StatusCode) String() string {
i -= 1000
if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) {
return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")"
}
return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]]
}


@ -1,376 +0,0 @@
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"compress/flate"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
//
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
w, err := c.writer(ctx, typ)
if err != nil {
return nil, fmt.Errorf("failed to get writer: %w", err)
}
return w, nil
}
// Write writes a message to the connection.
//
// See the Writer method if you want to stream a message.
//
// If compression is disabled or the compression threshold is not met, then it
// will write the message in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.write(ctx, typ, p)
if err != nil {
return fmt.Errorf("failed to write msg: %w", err)
}
return nil
}
type msgWriter struct {
c *Conn
mu *mu
writeMu *mu
closed bool
ctx context.Context
opcode opcode
flate bool
trimWriter *trimLastFourBytesWriter
flateWriter *flate.Writer
}
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: util.WriterFunc(mw.write),
}
}
if mw.flateWriter == nil {
mw.flateWriter = getFlateWriter(mw.trimWriter)
}
mw.flate = true
}
func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
}
return !mw.c.copts.serverNoContextTakeover
}
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return c.msgWriter, nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
if err != nil {
return 0, err
}
if !c.flate() {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}
n, err := mw.Write(p)
if err != nil {
return n, err
}
err = mw.Close()
return n, err
}
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx)
if err != nil {
return err
}
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
mw.closed = false
mw.trimWriter.reset()
return nil
}
func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter)
mw.flateWriter = nil
}
}
// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
}
}()
if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
mw.ensureFlate()
}
}
if mw.flate {
return mw.flateWriter.Write(p)
}
return mw.write(p)
}
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
mw.opcode = opContinuation
return n, nil
}
// Close flushes the frame to the connection.
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()
if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true
if mw.flate {
err = mw.flateWriter.Flush()
if err != nil {
return fmt.Errorf("failed to flush flate: %w", err)
}
}
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
if mw.flate && !mw.flateContextTakeover() {
mw.putFlateWriter()
}
mw.mu.unlock()
return nil
}
func (mw *msgWriter) close() {
if mw.c.client {
mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw)
}
mw.writeMu.forceLock()
mw.putFlateWriter()
}
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
_, err := c.writeFrame(ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
return nil
}
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
defer c.writeFrameMu.unlock()
select {
case <-c.closed:
return 0, net.ErrClosed
case c.writeTimeout <- ctx:
}
defer func() {
if err != nil {
select {
case <-c.closed:
err = net.ErrClosed
case <-ctx.Done():
err = ctx.Err()
default:
}
err = fmt.Errorf("failed to write frame: %w", err)
}
}()
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
if c.client {
c.writeHeader.masked = true
_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
}
c.writeHeader.rsv1 = false
if flate && (opcode == opText || opcode == opBinary) {
c.writeHeader.rsv1 = true
}
err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
if err != nil {
return 0, err
}
n, err := c.writeFramePayload(p)
if err != nil {
return n, err
}
if c.writeHeader.fin {
err = c.bw.Flush()
if err != nil {
return n, fmt.Errorf("failed to flush: %w", err)
}
}
select {
case <-c.closed:
if opcode == opClose {
return n, nil
}
return n, net.ErrClosed
case c.writeTimeout <- context.Background():
}
return n, nil
}
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
defer errd.Wrap(&err, "failed to write frame payload")
if !c.writeHeader.masked {
return c.bw.Write(p)
}
maskKey := c.writeHeader.maskKey
for len(p) > 0 {
// If the buffer is full, we need to flush.
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
// Start of next write in the buffer.
i := c.bw.Buffered()
j := len(p)
if j > c.bw.Available() {
j = c.bw.Available()
}
_, err := c.bw.Write(p[:j])
if err != nil {
return n, err
}
maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
p = p[j:]
n += j
}
return n, nil
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte
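// Trick: temporarily point bw at a WriterFunc that records whatever slice
// bufio hands it. Writing one byte and flushing forces bufio to push its
// internal buffer through that WriterFunc; re-slicing to cap(p2) recovers the
// whole backing array. Afterwards bw is reset back onto the real writer w,
// still backed by that same buffer, which is how writeFramePayload above can
// mask bytes directly in the buffer.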
bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
bw.WriteByte(0)
bw.Flush()
bw.Reset(w)
return writeBuf
}
func (c *Conn) writeError(code StatusCode, err error) {
c.writeClose(code, err.Error())
}

View file

@ -1,598 +0,0 @@
package websocket // import "github.com/coder/websocket"
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"runtime"
"strings"
"sync"
"syscall/js"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/wsjs"
"github.com/coder/websocket/internal/xsync"
)
// opcode represents a WebSocket opcode.
type opcode int
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
// 11 - 15 are reserved for further control frames.
)
// Conn provides a wrapper around the browser WebSocket API.
type Conn struct {
noCopy noCopy
ws wsjs.WebSocket
// read limit for a message in bytes.
msgReadLimit xsync.Int64
closeReadMu sync.Mutex
closeReadCtx context.Context
closingMu sync.Mutex
closeOnce sync.Once
closed chan struct{}
closeErrOnce sync.Once
closeErr error
closeWasClean bool
releaseOnClose func()
releaseOnError func()
releaseOnMessage func()
readSignal chan struct{}
readBufMu sync.Mutex
readBuf []wsjs.MessageEvent
}
func (c *Conn) close(err error, wasClean bool) {
c.closeOnce.Do(func() {
runtime.SetFinalizer(c, nil)
if !wasClean {
err = fmt.Errorf("unclean connection close: %w", err)
}
c.setCloseErr(err)
c.closeWasClean = wasClean
close(c.closed)
})
}
func (c *Conn) init() {
c.closed = make(chan struct{})
c.readSignal = make(chan struct{}, 1)
c.msgReadLimit.Store(32768)
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
err := CloseError{
Code: StatusCode(e.Code),
Reason: e.Reason,
}
// We do not know if we sent or received this close as
// it's possible the browser triggered it without us
// explicitly sending it.
c.close(err, e.WasClean)
c.releaseOnClose()
c.releaseOnError()
c.releaseOnMessage()
})
c.releaseOnError = c.ws.OnError(func(v js.Value) {
c.setCloseErr(errors.New(v.Get("message").String()))
c.closeWithInternal()
})
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
c.readBufMu.Lock()
defer c.readBufMu.Unlock()
c.readBuf = append(c.readBuf, e)
// Lets the read goroutine know there is definitely something in readBuf.
select {
case c.readSignal <- struct{}{}:
default:
}
})
runtime.SetFinalizer(c, func(c *Conn) {
c.setCloseErr(errors.New("connection garbage collected"))
c.closeWithInternal()
})
}
func (c *Conn) closeWithInternal() {
c.Close(StatusInternalError, "something went wrong")
}
// Read attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
c.closeReadMu.Lock()
closedRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closedRead {
return 0, nil, errors.New("WebSocket connection read closed")
}
typ, p, err := c.read(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to read: %w", err)
}
readLimit := c.msgReadLimit.Load()
if readLimit >= 0 && int64(len(p)) > readLimit {
err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
c.Close(StatusMessageTooBig, err.Error())
return 0, nil, err
}
return typ, p, nil
}
func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "read timed out")
return 0, nil, ctx.Err()
case <-c.readSignal:
case <-c.closed:
return 0, nil, net.ErrClosed
}
c.readBufMu.Lock()
defer c.readBufMu.Unlock()
me := c.readBuf[0]
// We copy the messages forward and decrease the size
// of the slice to avoid reallocating.
copy(c.readBuf, c.readBuf[1:])
c.readBuf = c.readBuf[:len(c.readBuf)-1]
if len(c.readBuf) > 0 {
// Next time we read, we'll grab the message.
select {
case c.readSignal <- struct{}{}:
default:
}
}
switch p := me.Data.(type) {
case string:
return MessageText, []byte(p), nil
case []byte:
return MessageBinary, p, nil
default:
panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
}
}
// Ping is mocked out for Wasm.
func (c *Conn) Ping(ctx context.Context) error {
return nil
}
// Write writes a message of the given type to the connection.
// Always non-blocking.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
err := c.write(ctx, typ, p)
if err != nil {
// Have to ensure the WebSocket is closed after a write error
// to match the Go API. It can only error if the message type
// is unexpected or the passed bytes contain invalid UTF-8 for
// MessageText.
err := fmt.Errorf("failed to write: %w", err)
c.setCloseErr(err)
c.closeWithInternal()
return err
}
return nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
if c.isClosed() {
return net.ErrClosed
}
switch typ {
case MessageBinary:
return c.ws.SendBytes(p)
case MessageText:
return c.ws.SendText(string(p))
default:
return fmt.Errorf("unexpected message type: %v", typ)
}
}
// Close closes the WebSocket with the given code and reason.
// It will wait until the peer responds with a close frame
// or the connection is closed.
// It thus performs the full WebSocket close handshake.
func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason)
if err != nil {
return fmt.Errorf("failed to close WebSocket: %w", err)
}
return nil
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
//
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
// a WebSocket without the close handshake.
func (c *Conn) CloseNow() error {
return c.Close(StatusGoingAway, "")
}
func (c *Conn) exportedClose(code StatusCode, reason string) error {
c.closingMu.Lock()
defer c.closingMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
ce := fmt.Errorf("sent close: %w", CloseError{
Code: code,
Reason: reason,
})
c.setCloseErr(ce)
err := c.ws.Close(int(code), reason)
if err != nil {
return err
}
<-c.closed
if !c.closeWasClean {
return c.closeErr
}
return nil
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.ws.Subprotocol()
}
// DialOptions represents the options available to pass to Dial.
type DialOptions struct {
// Subprotocols lists the subprotocols to negotiate with the server.
Subprotocols []string
}
// Dial creates a new WebSocket connection to the given url with the given options.
// The passed context bounds the maximum time spent waiting for the connection to open.
// The returned *http.Response is always nil or a mock. It's only in the signature
// to match the core API.
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
c, resp, err := dial(ctx, url, opts)
if err != nil {
return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
}
return c, resp, nil
}
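// Illustrative sketch (not part of the library): a minimal client round trip
// using the Wasm Dial. Only the API defined in this file is assumed.
//
//	c, _, err := Dial(ctx, "wss://relay.example.com", nil)
//	if err != nil {
//		return err
//	}
//	defer c.Close(StatusNormalClosure, "done")
//
//	if err := c.Write(ctx, MessageText, []byte(`["REQ","sub",{}]`)); err != nil {
//		return err
//	}
//	_, msg, err := c.Read(ctx)
//	if err != nil {
//		return err
//	}
//	_ = msg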
func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
if opts == nil {
opts = &DialOptions{}
}
url = strings.Replace(url, "http://", "ws://", 1)
url = strings.Replace(url, "https://", "wss://", 1)
ws, err := wsjs.New(url, opts.Subprotocols)
if err != nil {
return nil, nil, err
}
c := &Conn{
ws: ws,
}
c.init()
opench := make(chan struct{})
releaseOpen := ws.OnOpen(func(e js.Value) {
close(opench)
})
defer releaseOpen()
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "dial timed out")
return nil, nil, ctx.Err()
case <-opench:
return c, &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}, nil
case <-c.closed:
return nil, nil, net.ErrClosed
}
}
// Reader attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
typ, p, err := c.Read(ctx)
if err != nil {
return 0, nil, err
}
return typ, bytes.NewReader(p), nil
}
// Writer returns a writer to write a WebSocket data message to the connection.
// It buffers the entire message in memory and then sends it when the writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
return &writer{
c: c,
ctx: ctx,
typ: typ,
b: bpool.Get(),
}, nil
}
type writer struct {
closed bool
c *Conn
ctx context.Context
typ MessageType
b *bytes.Buffer
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, errors.New("cannot write to closed writer")
}
n, err := w.b.Write(p)
if err != nil {
return n, fmt.Errorf("failed to write message: %w", err)
}
return n, nil
}
func (w *writer) Close() error {
if w.closed {
return errors.New("cannot close closed writer")
}
w.closed = true
defer bpool.Put(w.b)
err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
if err != nil {
return fmt.Errorf("failed to close writer: %w", err)
}
return nil
}
// CloseRead implements *Conn.CloseRead for wasm.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadMu.Unlock()
go func() {
defer cancel()
defer c.CloseNow()
_, _, err := c.read(ctx)
if err != nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit implements *Conn.SetReadLimit for wasm.
func (c *Conn) SetReadLimit(n int64) {
c.msgReadLimit.Store(n)
}
func (c *Conn) setCloseErr(err error) {
c.closeErrOnce.Do(func() {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
})
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
Subprotocols []string
InsecureSkipVerify bool
OriginPatterns []string
CompressionMode CompressionMode
CompressionThreshold int
}
// Accept is stubbed out for Wasm.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return nil, errors.New("unimplemented")
}
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
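// Illustrative sketch (not part of the library): reading until the peer closes
// and distinguishing a clean close from an error with CloseStatus.
//
//	for {
//		_, _, err := c.Read(ctx)
//		if err != nil {
//			if CloseStatus(err) == StatusNormalClosure {
//				return nil // peer closed cleanly
//			}
//			return err // abnormal close or other failure
//		}
//	}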
// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int
const (
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
// for every message. This applies to both server and client side.
//
// This means less efficient compression as the sliding window from previous messages
// will not be used but the memory overhead will be lower if the connections
// are long lived and seldom used.
//
// The message will only be compressed if greater than 512 bytes.
CompressionNoContextTakeover CompressionMode = iota
// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
// This enables reusing the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover.
//
// If the peer negotiates NoContextTakeover on the client or server side, it will be
// used instead as this is required by the RFC.
CompressionContextTakeover
// CompressionDisabled disables the deflate extension.
//
// Use this if you are using a predominantly binary protocol with very
// little duplication in between messages or CPU and memory are more
// important than bandwidth.
CompressionDisabled
)
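// Illustrative sketch (not part of the library): picking a compression mode.
// Accept is stubbed out on Wasm (see above), but on other builds these options
// are assumed to control the permessage-deflate negotiation.
//
//	opts := &AcceptOptions{
//		CompressionMode:      CompressionContextTakeover,
//		CompressionThreshold: 512, // minimum message size to compress, in bytes
//	}
//	_ = opts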
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}

View file

@ -1 +0,0 @@
knowledge.md

View file

@ -1,13 +1,10 @@
package postgresql
import (
"sync"
"github.com/jmoiron/sqlx"
)
type PostgresBackend struct {
sync.Mutex
*sqlx.DB
DatabaseURL string
QueryLimit int

View file

@ -1,44 +0,0 @@
package postgresql
import (
"context"
"fmt"
"github.com/fiatjaf/eventstore"
"github.com/fiatjaf/eventstore/internal"
"github.com/nbd-wtf/go-nostr"
)
func (b *PostgresBackend) ReplaceEvent(ctx context.Context, evt *nostr.Event) error {
b.Lock()
defer b.Unlock()
filter := nostr.Filter{Limit: 1, Kinds: []int{evt.Kind}, Authors: []string{evt.PubKey}}
if nostr.IsAddressableKind(evt.Kind) {
filter.Tags = nostr.TagMap{"d": []string{evt.Tags.GetD()}}
}
ch, err := b.QueryEvents(ctx, filter)
if err != nil {
return fmt.Errorf("failed to query before replacing: %w", err)
}
shouldStore := true
for previous := range ch {
if internal.IsOlder(previous, evt) {
if err := b.DeleteEvent(ctx, previous); err != nil {
return fmt.Errorf("failed to delete event for replacing: %w", err)
}
} else {
shouldStore = false
}
}
if shouldStore {
if err := b.SaveEvent(ctx, evt); err != nil && err != eventstore.ErrDupEvent {
return fmt.Errorf("failed to save: %w", err)
}
}
return nil
}

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/fiatjaf/eventstore/internal"
"github.com/nbd-wtf/go-nostr"
)
@ -30,8 +31,41 @@ func (w RelayWrapper) Publish(ctx context.Context, evt nostr.Event) error {
return nil
}
// others are replaced
w.Store.ReplaceEvent(ctx, &evt)
// from now on we know they are replaceable or addressable
if replacer, ok := w.Store.(Replacer); ok {
// use the replacer interface to potentially reduce queries and race conditions
replacer.Replace(ctx, &evt)
} else {
// otherwise do it the manual way
filter := nostr.Filter{Limit: 1, Kinds: []int{evt.Kind}, Authors: []string{evt.PubKey}}
if nostr.IsAddressableKind(evt.Kind) {
// when addressable, add the "d" tag to the filter
filter.Tags = nostr.TagMap{"d": []string{evt.Tags.GetD()}}
}
// now we fetch the past events, whatever they are, delete them and then save the new
ch, err := w.Store.QueryEvents(ctx, filter)
if err != nil {
return fmt.Errorf("failed to query before replacing: %w", err)
}
shouldStore := true
for previous := range ch {
if internal.IsOlder(previous, &evt) {
if err := w.Store.DeleteEvent(ctx, previous); err != nil {
return fmt.Errorf("failed to delete event for replacing: %w", err)
}
} else {
// there is a newer event already stored, so we won't store this
shouldStore = false
}
}
if shouldStore {
if err := w.SaveEvent(ctx, &evt); err != nil && err != ErrDupEvent {
return fmt.Errorf("failed to save: %w", err)
}
}
}
return nil
}
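// Illustrative sketch (not part of this change): a store opts into the fast
// path above by implementing the optional Replacer interface. The type below
// is hypothetical.
//
//	type txStore struct{ /* ... */ }
//
//	// Replace queries, deletes older versions and saves the new replaceable
//	// or addressable event in one atomic step (e.g. a single transaction),
//	// instead of the separate QueryEvents/DeleteEvent/SaveEvent calls.
//	func (s *txStore) Replace(ctx context.Context, evt *nostr.Event) error {
//		// begin tx; delete older matching events; insert evt; commit
//		return nil
//	}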

View file

@ -22,9 +22,10 @@ type Store interface {
DeleteEvent(context.Context, *nostr.Event) error
// SaveEvent just saves an event, no side-effects.
SaveEvent(context.Context, *nostr.Event) error
// ReplaceEvent atomically replaces a replaceable or addressable event.
// Conceptually it is like a Query->Delete->Save, but streamlined.
ReplaceEvent(context.Context, *nostr.Event) error
}
type Replacer interface {
Replace(context.Context, *nostr.Event) error
}
type Counter interface {

View file

@ -1,2 +1,2 @@
*.env
knowledge.md
rss-bridge

View file

@ -127,7 +127,6 @@ Fear no more. Using the https://github.com/fiatjaf/eventstore module you get a b
relay.QueryEvents = append(relay.QueryEvents, db.QueryEvents)
relay.CountEvents = append(relay.CountEvents, db.CountEvents)
relay.DeleteEvent = append(relay.DeleteEvent, db.DeleteEvent)
relay.ReplaceEvent = append(relay.ReplaceEvent, db.ReplaceEvent)
```
### But I don't want to write a bunch of custom policies!

View file

@ -106,9 +106,6 @@ func (rl *Relay) AddEvent(ctx context.Context, evt *nostr.Event) (skipBroadcast
for _, ons := range rl.OnEventSaved {
ons(ctx, evt)
}
// track event expiration if applicable
rl.expirationManager.trackEvent(evt)
}
return false, nil

View file

@ -1,135 +0,0 @@
package khatru
import (
"container/heap"
"context"
"sync"
"time"
"github.com/nbd-wtf/go-nostr"
"github.com/nbd-wtf/go-nostr/nip40"
)
type expiringEvent struct {
id string
expiresAt nostr.Timestamp
}
type expiringEventHeap []expiringEvent
func (h expiringEventHeap) Len() int { return len(h) }
func (h expiringEventHeap) Less(i, j int) bool { return h[i].expiresAt < h[j].expiresAt }
func (h expiringEventHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *expiringEventHeap) Push(x interface{}) {
*h = append(*h, x.(expiringEvent))
}
func (h *expiringEventHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
type expirationManager struct {
events expiringEventHeap
mu sync.Mutex
relay *Relay
interval time.Duration
initialScanDone bool
}
func newExpirationManager(relay *Relay) *expirationManager {
return &expirationManager{
events: make(expiringEventHeap, 0),
relay: relay,
interval: time.Hour,
}
}
func (em *expirationManager) start(ctx context.Context) {
ticker := time.NewTicker(em.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if !em.initialScanDone {
em.initialScan(ctx)
em.initialScanDone = true
}
em.checkExpiredEvents(ctx)
}
}
}
func (em *expirationManager) initialScan(ctx context.Context) {
em.mu.Lock()
defer em.mu.Unlock()
// query all events
for _, query := range em.relay.QueryEvents {
ch, err := query(ctx, nostr.Filter{})
if err != nil {
continue
}
for evt := range ch {
if expiresAt := nip40.GetExpiration(evt.Tags); expiresAt != -1 {
heap.Push(&em.events, expiringEvent{
id: evt.ID,
expiresAt: expiresAt,
})
}
}
}
heap.Init(&em.events)
}
func (em *expirationManager) checkExpiredEvents(ctx context.Context) {
em.mu.Lock()
defer em.mu.Unlock()
now := nostr.Now()
// keep deleting events from the heap as long as they're expired
for em.events.Len() > 0 {
next := em.events[0]
if now < next.expiresAt {
break
}
heap.Pop(&em.events)
for _, query := range em.relay.QueryEvents {
ch, err := query(ctx, nostr.Filter{IDs: []string{next.id}})
if err != nil {
continue
}
if evt := <-ch; evt != nil {
for _, del := range em.relay.DeleteEvent {
del(ctx, evt)
}
}
break
}
}
}
func (em *expirationManager) trackEvent(evt *nostr.Event) {
if expiresAt := nip40.GetExpiration(evt.Tags); expiresAt != -1 {
em.mu.Lock()
heap.Push(&em.events, expiringEvent{
id: evt.ID,
expiresAt: expiresAt,
})
em.mu.Unlock()
}
}
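// Illustrative sketch (not part of the library): an event the manager tracks
// carries a NIP-40 "expiration" tag holding a unix timestamp, e.g.
//
//	evt := &nostr.Event{
//		Kind:      1,
//		CreatedAt: nostr.Now(),
//		Tags:      nostr.Tags{{"expiration", "1700000000"}},
//	}
//	em.trackEvent(evt) // pushed onto the heap, deleted once nostr.Now() passes it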

View file

@ -24,6 +24,10 @@ import (
// ServeHTTP implements http.Handler interface.
func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if rl.ServiceURL == "" {
rl.ServiceURL = getServiceBaseURL(r)
}
corsMiddleware := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
@ -315,7 +319,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
id := string(*env)
rl.removeListenerId(ws, id)
case *nostr.AuthEnvelope:
wsBaseUrl := strings.Replace(rl.getBaseURL(r), "http", "ws", 1)
wsBaseUrl := strings.Replace(rl.ServiceURL, "http", "ws", 1)
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok {
ws.AuthedPublicKey = pubkey
ws.authLock.Lock()

View file

@ -3,6 +3,7 @@ package khatru
import (
"net"
"net/http"
"strconv"
"strings"
"github.com/nbd-wtf/go-nostr"
@ -13,6 +14,28 @@ func isOlder(previous, next *nostr.Event) bool {
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
}
func getServiceBaseURL(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto == "" {
if host == "localhost" {
proto = "http"
} else if strings.Index(host, ":") != -1 {
// has a port number
proto = "http"
} else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil {
// it's a naked IP
proto = "http"
} else {
proto = "https"
}
}
return proto + "://" + host
}
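// Rough behaviour of the heuristic above (illustrative, not exhaustive):
//
//	X-Forwarded-Proto set              -> that scheme is used as-is
//	host "localhost"                   -> "http://localhost"
//	host with a port, "1.2.3.4:8080"   -> "http://1.2.3.4:8080"
//	naked IP, "1.2.3.4"                -> "http://1.2.3.4"
//	anything else, "relay.example.com" -> "https://relay.example.com"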
var privateMasks = func() []net.IPNet {
privateCIDRs := []string{
"127.0.0.0/8",

View file

@ -80,7 +80,7 @@ func (rl *Relay) HandleNIP86(w http.ResponseWriter, r *http.Request) {
goto respond
}
if uTag := evt.Tags.GetFirst([]string{"u", ""}); uTag == nil || rl.getBaseURL(r) != (*uTag)[1] {
if uTag := evt.Tags.GetFirst([]string{"u", ""}); uTag == nil || rl.ServiceURL != (*uTag)[1] {
resp.Error = "invalid 'u' tag"
goto respond
} else if pht := evt.Tags.GetFirst([]string{"payload", hex.EncodeToString(payloadHash[:])}); pht == nil {

View file

@ -9,7 +9,7 @@ import (
func ApplySaneDefaults(relay *khatru.Relay) {
relay.RejectEvent = append(relay.RejectEvent,
RejectEventsWithBase64Media,
EventIPRateLimiter(2, time.Minute*3, 10),
EventIPRateLimiter(2, time.Minute*3, 5),
)
relay.RejectFilter = append(relay.RejectFilter,
@ -18,6 +18,6 @@ func ApplySaneDefaults(relay *khatru.Relay) {
)
relay.RejectConnection = append(relay.RejectConnection,
ConnectionRateLimiter(1, time.Minute*5, 100),
ConnectionRateLimiter(1, time.Minute*5, 10),
)
}

View file

@ -5,8 +5,6 @@ import (
"log"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
@ -17,15 +15,13 @@ import (
)
func NewRelay() *Relay {
ctx := context.Background()
rl := &Relay{
Log: log.New(os.Stderr, "[khatru-relay] ", log.LstdFlags),
Info: &nip11.RelayInformationDocument{
Software: "https://github.com/fiatjaf/khatru",
Version: "n/a",
SupportedNIPs: []any{1, 11, 40, 42, 70, 86},
SupportedNIPs: []any{1, 11, 42, 70, 86},
},
upgrader: websocket.Upgrader{
@ -45,14 +41,10 @@ func NewRelay() *Relay {
MaxMessageSize: 512000,
}
rl.expirationManager = newExpirationManager(rl)
go rl.expirationManager.start(ctx)
return rl
}
type Relay struct {
// setting this variable overwrites the hackish workaround we do to try to figure out our own base URL
ServiceURL string
// hooks that will be called at various times
@ -113,33 +105,4 @@ type Relay struct {
PongWait time.Duration // Time allowed to read the next pong message from the peer.
PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait.
MaxMessageSize int64 // Maximum message size allowed from peer.
// NIP-40 expiration manager
expirationManager *expirationManager
}
func (rl *Relay) getBaseURL(r *http.Request) string {
if rl.ServiceURL != "" {
return rl.ServiceURL
}
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto == "" {
if host == "localhost" {
proto = "http"
} else if strings.Index(host, ":") != -1 {
// has a port number
proto = "http"
} else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil {
// it's a naked IP
proto = "http"
} else {
proto = "https"
}
}
return proto + "://" + host
}

21
vendor/github.com/gobwas/httphead/LICENSE generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017 Sergey Kamardin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

63
vendor/github.com/gobwas/httphead/README.md generated vendored Normal file
View file

@ -0,0 +1,63 @@
# httphead.[go](https://golang.org)
[![GoDoc][godoc-image]][godoc-url]
> Tiny HTTP header value parsing library in go.
## Overview
This library contains low-level functions for scanning HTTP RFC2616 compatible header value grammars.
## Install
```shell
go get github.com/gobwas/httphead
```
## Example
The example below shows how a multiple-choice HTTP header value can be parsed with this library:
```go
options, ok := httphead.ParseOptions([]byte(`foo;bar=1,baz`), nil)
fmt.Println(options, ok)
// Output: [{foo map[bar:1]} {baz map[]}] true
```
The low-level example below shows how to optimize keys skipping and selection
of some key:
```go
// The right part of a full header line like:
// X-My-Header: key;foo=bar;baz,key;baz
header := []byte(`foo;a=0,foo;a=1,foo;a=2,foo;a=3`)
// We want to search for key "foo" with an "a" parameter equal to "2".
var (
foo = []byte(`foo`)
a = []byte(`a`)
v = []byte(`2`)
)
var found bool
httphead.ScanOptions(header, func(i int, key, param, value []byte) Control {
if !bytes.Equal(key, foo) {
return ControlSkip
}
if bytes.Equal(param, a) {
if bytes.Equal(value, v) {
// Found it!
found = true
return ControlBreak
}
return ControlSkip
}
return ControlContinue
})
```
For more usage examples please see [docs][godoc-url] or package tests.
[godoc-image]: https://godoc.org/github.com/gobwas/httphead?status.svg
[godoc-url]: https://godoc.org/github.com/gobwas/httphead
[travis-image]: https://travis-ci.org/gobwas/httphead.svg?branch=master
[travis-url]: https://travis-ci.org/gobwas/httphead

200
vendor/github.com/gobwas/httphead/cookie.go generated vendored Normal file
View file

@ -0,0 +1,200 @@
package httphead
import (
"bytes"
)
// ScanCookie scans cookie pairs from data using DefaultCookieScanner.Scan()
// method.
func ScanCookie(data []byte, it func(key, value []byte) bool) bool {
return DefaultCookieScanner.Scan(data, it)
}
// DefaultCookieScanner is a CookieScanner which is used by ScanCookie().
// Note that it is intended to have the same behavior as http.Request.Cookies()
// has.
var DefaultCookieScanner = CookieScanner{}
// CookieScanner contains options for scanning cookie pairs.
// See https://tools.ietf.org/html/rfc6265#section-4.1.1
type CookieScanner struct {
// DisableNameValidation disables name validation of a cookie. If false,
// only RFC2616 "tokens" are accepted.
DisableNameValidation bool
// DisableValueValidation disables value validation of a cookie. If false,
// only RFC6265 "cookie-octet" characters are accepted.
//
// Note that Strict option also affects validation of a value.
//
// If Strict is false, the scanner allows space and comma
// characters inside the value for better compatibility with non-standard
// cookie implementations.
DisableValueValidation bool
// BreakOnPairError sets scanner to immediately return after first pair syntax
// validation error.
// If false, scanner will try to skip invalid pair bytes and go ahead.
BreakOnPairError bool
// Strict enables strict RFC6265 mode scanning. It affects name and value
// validation, as also some other rules.
// If false, it is intended to bring the same behavior as
// http.Request.Cookies().
Strict bool
}
// Scan maps data to name and value pairs. Usually data represents value of the
// Cookie header.
func (c CookieScanner) Scan(data []byte, it func(name, value []byte) bool) bool {
lexer := &Scanner{data: data}
const (
statePair = iota
stateBefore
)
state := statePair
for lexer.Buffered() > 0 {
switch state {
case stateBefore:
// Pairs separated by ";" and space, according to the RFC6265:
// cookie-pair *( ";" SP cookie-pair )
//
// Cookie pairs MUST be separated by (";" SP). So our only option
// here is to fail as syntax error.
a, b := lexer.Peek2()
if a != ';' {
return false
}
state = statePair
advance := 1
if b == ' ' {
advance++
} else if c.Strict {
return false
}
lexer.Advance(advance)
case statePair:
if !lexer.FetchUntil(';') {
return false
}
var value []byte
name := lexer.Bytes()
if i := bytes.IndexByte(name, '='); i != -1 {
value = name[i+1:]
name = name[:i]
} else if c.Strict {
if !c.BreakOnPairError {
goto nextPair
}
return false
}
if !c.Strict {
name = trimLeft(name)
}
if !c.DisableNameValidation && !ValidCookieName(name) {
if !c.BreakOnPairError {
goto nextPair
}
return false
}
if !c.Strict {
value = trimRight(value)
}
value = stripQuotes(value)
if !c.DisableValueValidation && !ValidCookieValue(value, c.Strict) {
if !c.BreakOnPairError {
goto nextPair
}
return false
}
if !it(name, value) {
return true
}
nextPair:
state = stateBefore
}
}
return true
}
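// Illustrative sketch (not part of the library): iterating the pairs of a
// Cookie header value with ScanCookie (fmt is assumed to be imported).
//
//	header := []byte(`session=abc123; theme=dark`)
//	ok := ScanCookie(header, func(name, value []byte) bool {
//		fmt.Printf("%s=%s\n", name, value)
//		return true // keep scanning
//	})
//	// ok reports whether the whole header value was well-formed.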
// ValidCookieValue reports whether given value is a valid RFC6265
// "cookie-octet" bytes.
//
// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
// ; US-ASCII characters excluding CTLs,
// ; whitespace DQUOTE, comma, semicolon,
// ; and backslash
//
// Note that the false strict parameter disables errors on space 0x20 and comma
// 0x2c. This could be useful to bring some compatibility with non-compliant
// clients/servers in the real world.
// It acts the same as the standard library cookie parser if strict is false.
func ValidCookieValue(value []byte, strict bool) bool {
if len(value) == 0 {
return true
}
for _, c := range value {
switch c {
case '"', ';', '\\':
return false
case ',', ' ':
if strict {
return false
}
default:
if c <= 0x20 {
return false
}
if c >= 0x7f {
return false
}
}
}
return true
}
// ValidCookieName reports whether the given bytes are valid RFC2616 "token" bytes.
func ValidCookieName(name []byte) bool {
for _, c := range name {
if !OctetTypes[c].IsToken() {
return false
}
}
return true
}
func stripQuotes(bts []byte) []byte {
if last := len(bts) - 1; last > 0 && bts[0] == '"' && bts[last] == '"' {
return bts[1:last]
}
return bts
}
func trimLeft(p []byte) []byte {
var i int
for i < len(p) && OctetTypes[p[i]].IsSpace() {
i++
}
return p[i:]
}
func trimRight(p []byte) []byte {
j := len(p)
for j > 0 && OctetTypes[p[j-1]].IsSpace() {
j--
}
return p[:j]
}

275
vendor/github.com/gobwas/httphead/head.go generated vendored Normal file
View file

@ -0,0 +1,275 @@
package httphead
import (
"bufio"
"bytes"
)
// Version contains protocol major and minor version.
type Version struct {
Major int
Minor int
}
// RequestLine contains parameters parsed from the first request line.
type RequestLine struct {
Method []byte
URI []byte
Version Version
}
// ResponseLine contains parameters parsed from the first response line.
type ResponseLine struct {
Version Version
Status int
Reason []byte
}
// SplitRequestLine splits given slice of bytes into three chunks without
// parsing.
func SplitRequestLine(line []byte) (method, uri, version []byte) {
return split3(line, ' ')
}
// ParseRequestLine parses http request line like "GET / HTTP/1.0".
func ParseRequestLine(line []byte) (r RequestLine, ok bool) {
var i int
for i = 0; i < len(line); i++ {
c := line[i]
if !OctetTypes[c].IsToken() {
if i > 0 && c == ' ' {
break
}
return
}
}
if i == len(line) {
return
}
var proto []byte
r.Method = line[:i]
r.URI, proto = split2(line[i+1:], ' ')
if len(r.URI) == 0 {
return
}
if major, minor, ok := ParseVersion(proto); ok {
r.Version.Major = major
r.Version.Minor = minor
return r, true
}
return r, false
}
// SplitResponseLine splits given slice of bytes into three chunks without
// parsing.
func SplitResponseLine(line []byte) (version, status, reason []byte) {
return split3(line, ' ')
}
// ParseResponseLine parses first response line into ResponseLine struct.
func ParseResponseLine(line []byte) (r ResponseLine, ok bool) {
var (
proto []byte
status []byte
)
proto, status, r.Reason = split3(line, ' ')
if major, minor, ok := ParseVersion(proto); ok {
r.Version.Major = major
r.Version.Minor = minor
} else {
return r, false
}
if n, ok := IntFromASCII(status); ok {
r.Status = n
} else {
return r, false
}
// TODO(gobwas): parse r.Reason here for the TEXT rule:
// TEXT = <any OCTET except CTLs,
// but including LWS>
return r, true
}
var (
httpVersion10 = []byte("HTTP/1.0")
httpVersion11 = []byte("HTTP/1.1")
httpVersionPrefix = []byte("HTTP/")
)
// ParseVersion parses major and minor version of HTTP protocol.
// It returns parsed values and true if parse is ok.
func ParseVersion(bts []byte) (major, minor int, ok bool) {
switch {
case bytes.Equal(bts, httpVersion11):
return 1, 1, true
case bytes.Equal(bts, httpVersion10):
return 1, 0, true
case len(bts) < 8:
return
case !bytes.Equal(bts[:5], httpVersionPrefix):
return
}
bts = bts[5:]
dot := bytes.IndexByte(bts, '.')
if dot == -1 {
return
}
major, ok = IntFromASCII(bts[:dot])
if !ok {
return
}
minor, ok = IntFromASCII(bts[dot+1:])
if !ok {
return
}
return major, minor, true
}
// ReadLine reads line from br. It reads until '\n' and returns bytes without
// '\n' or '\r\n' at the end.
// It returns err if and only if the line does not end in '\n'. Note that the
// read bytes are returned in any case of error.
//
// It is much like the textproto/Reader.ReadLine() except the thing that it
// returns raw bytes, instead of string. That is, it avoids copying bytes read
// from br.
//
// textproto/Reader.ReadLineBytes() also makes a copy of the resulting bytes to be
// safe with future I/O operations on br.
//
// Since we control I/O operations on br, we do not need to make an additional
// copy for safety.
func ReadLine(br *bufio.Reader) ([]byte, error) {
var line []byte
for {
bts, err := br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
// Copy bytes because next read will discard them.
line = append(line, bts...)
continue
}
// Avoid copy of single read.
if line == nil {
line = bts
} else {
line = append(line, bts...)
}
if err != nil {
return line, err
}
// Size of line is at least 1;
// otherwise bufio.ReadSlice() would have returned an error.
n := len(line)
// Cut '\n' or '\r\n'.
if n > 1 && line[n-2] == '\r' {
line = line[:n-2]
} else {
line = line[:n-1]
}
return line, nil
}
}
// ParseHeaderLine parses HTTP header as key-value pair. It returns parsed
// values and true if parse is ok.
func ParseHeaderLine(line []byte) (k, v []byte, ok bool) {
colon := bytes.IndexByte(line, ':')
if colon == -1 {
return
}
k = trim(line[:colon])
for _, c := range k {
if !OctetTypes[c].IsToken() {
return nil, nil, false
}
}
v = trim(line[colon+1:])
return k, v, true
}
// IntFromASCII converts ascii encoded decimal numeric value from HTTP entities
// to an integer.
func IntFromASCII(bts []byte) (ret int, ok bool) {
// ASCII numbers all start with the high-order bits 0011.
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
// bits and interpret them directly as an integer.
var n int
if n = len(bts); n < 1 {
return 0, false
}
for i := 0; i < n; i++ {
if bts[i]&0xf0 != 0x30 {
return 0, false
}
ret += int(bts[i]&0xf) * pow(10, n-i-1)
}
return ret, true
}
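// Worked example (illustrative): for bts = "254", n = 3 and the loop
// accumulates 2*10^2 + 5*10^1 + 4*10^0 = 254. Any byte whose high nibble is
// not 0x3 makes the function return ok = false.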
const (
toLower = 'a' - 'A' // for use with OR.
toUpper = ^byte(toLower) // for use with AND.
)
// CanonicalizeHeaderKey is like the standard textproto/CanonicalMIMEHeaderKey,
// except that it operates on a slice of bytes and modifies it in place without
// copying.
func CanonicalizeHeaderKey(k []byte) {
upper := true
for i, c := range k {
if upper && 'a' <= c && c <= 'z' {
k[i] &= toUpper
} else if !upper && 'A' <= c && c <= 'Z' {
k[i] |= toLower
}
upper = c == '-'
}
}
// pow for integers implementation.
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
func pow(a, b int) int {
p := 1
for b > 0 {
if b&1 != 0 {
p *= a
}
b >>= 1
a *= a
}
return p
}
func split3(p []byte, sep byte) (p1, p2, p3 []byte) {
a := bytes.IndexByte(p, sep)
b := bytes.IndexByte(p[a+1:], sep)
if a == -1 || b == -1 {
return p, nil, nil
}
b += a + 1
return p[:a], p[a+1 : b], p[b+1:]
}
func split2(p []byte, sep byte) (p1, p2 []byte) {
i := bytes.IndexByte(p, sep)
if i == -1 {
return p, nil
}
return p[:i], p[i+1:]
}
func trim(p []byte) []byte {
var i, j int
for i = 0; i < len(p) && (p[i] == ' ' || p[i] == '\t'); {
i++
}
for j = len(p); j > i && (p[j-1] == ' ' || p[j-1] == '\t'); {
j--
}
return p[i:j]
}

331
vendor/github.com/gobwas/httphead/httphead.go generated vendored Normal file
View file

@ -0,0 +1,331 @@
// Package httphead contains utils for parsing HTTP and HTTP-grammar compatible
// text protocol headers.
//
// That is, this package's first aim is to make it easy to parse
// constructions described here: https://tools.ietf.org/html/rfc2616#section-2
package httphead
import (
"bytes"
"strings"
)
// ScanTokens parses data in this form:
//
// list = 1#token
//
// It returns false if data is malformed.
func ScanTokens(data []byte, it func([]byte) bool) bool {
lexer := &Scanner{data: data}
var ok bool
for lexer.Next() {
switch lexer.Type() {
case ItemToken:
ok = true
if !it(lexer.Bytes()) {
return true
}
case ItemSeparator:
if !isComma(lexer.Bytes()) {
return false
}
default:
return false
}
}
return ok && !lexer.err
}
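// Illustrative sketch (not part of the library): collecting the tokens of a
// comma-separated list header such as "Connection: keep-alive, Upgrade".
//
//	var tokens []string
//	ok := ScanTokens([]byte("keep-alive, Upgrade"), func(t []byte) bool {
//		tokens = append(tokens, string(t))
//		return true // keep scanning
//	})
//	// ok is false if the list is malformed.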
// ParseOptions parses all header options and appends them to the given slice of
// Option. It returns a flag reporting successful (well-formed input) parsing.
//
// Note that the appended options all consist of subslices of data. That is,
// mutation of data will mutate the appended options.
func ParseOptions(data []byte, options []Option) ([]Option, bool) {
var i int
index := -1
return options, ScanOptions(data, func(idx int, name, attr, val []byte) Control {
if idx != index {
index = idx
i = len(options)
options = append(options, Option{Name: name})
}
if attr != nil {
options[i].Parameters.Set(attr, val)
}
return ControlContinue
})
}
// SelectFlag encodes way of options selection.
type SelectFlag byte
// String represents the flag as a string.
func (f SelectFlag) String() string {
var flags [2]string
var n int
if f&SelectCopy != 0 {
flags[n] = "copy"
n++
}
if f&SelectUnique != 0 {
flags[n] = "unique"
n++
}
return "[" + strings.Join(flags[:n], "|") + "]"
}
const (
// SelectCopy causes selector to copy selected option before appending it
// to resulting slice.
// If SelectCopy flag is not passed to selector, then appended options will
// contain sub-slices of the initial data.
SelectCopy SelectFlag = 1 << iota
// SelectUnique causes selector to append only not yet existing option to
// resulting slice. Unique is checked by comparing option names.
SelectUnique
)
// OptionSelector contains configuration for selecting Options from header value.
type OptionSelector struct {
// Check is a filter function that is applied to every Option that could
// possibly be selected.
// If Check is nil, all options will be selected.
Check func(Option) bool
// Flags contains flags for options selection.
Flags SelectFlag
// Alloc is used to allocate a slice of bytes when the selector is configured with
// the SelectCopy flag. It will be called with the number of bytes needed for a copy
// of a single Option.
// If Alloc is nil, make is used.
Alloc func(n int) []byte
}
// Select parses header data and appends it to the given slice of Option.
// It also returns a flag reporting successful (well-formed input) parsing.
func (s OptionSelector) Select(data []byte, options []Option) ([]Option, bool) {
var current Option
var has bool
index := -1
alloc := s.Alloc
if alloc == nil {
alloc = defaultAlloc
}
check := s.Check
if check == nil {
check = defaultCheck
}
ok := ScanOptions(data, func(idx int, name, attr, val []byte) Control {
if idx != index {
if has && check(current) {
if s.Flags&SelectCopy != 0 {
current = current.Copy(alloc(current.Size()))
}
options = append(options, current)
has = false
}
if s.Flags&SelectUnique != 0 {
for i := len(options) - 1; i >= 0; i-- {
if bytes.Equal(options[i].Name, name) {
return ControlSkip
}
}
}
index = idx
current = Option{Name: name}
has = true
}
if attr != nil {
current.Parameters.Set(attr, val)
}
return ControlContinue
})
if has && check(current) {
if s.Flags&SelectCopy != 0 {
current = current.Copy(alloc(current.Size()))
}
options = append(options, current)
}
return options, ok
}
func defaultAlloc(n int) []byte { return make([]byte, n) }
func defaultCheck(Option) bool { return true }
// Control represents operation that scanner should perform.
type Control byte
const (
// ControlContinue causes scanner to continue scan tokens.
ControlContinue Control = iota
// ControlBreak causes scanner to stop scan tokens.
ControlBreak
// ControlSkip causes scanner to skip current entity.
ControlSkip
)
// ScanOptions parses data in this form:
//
// values = 1#value
// value = token *( ";" param )
// param = token [ "=" (token | quoted-string) ]
//
// It calls the given callback with the index of the option, the option itself and its
// parameter (attribute and its value, both of which could be nil). The index is useful when
// a header contains multiple choices for the same named option.
//
// Given callback should return one of the defined Control* values.
// ControlSkip means that the passed key is not of interest to the caller. That is, all
// parameters of that key will be skipped.
// ControlBreak means that no more keys and parameters should be parsed. That
// is, parsing must stop immediately.
// ControlContinue means that the caller wants to receive the next parameter and its
// value, or the next key.
//
// It returns false if data is malformed.
func ScanOptions(data []byte, it func(index int, option, attribute, value []byte) Control) bool {
lexer := &Scanner{data: data}
var ok bool
var state int
const (
stateKey = iota
stateParamBeforeName
stateParamName
stateParamBeforeValue
stateParamValue
)
var (
index int
key, param, value []byte
mustCall bool
)
for lexer.Next() {
var (
call bool
growIndex int
)
t := lexer.Type()
v := lexer.Bytes()
switch t {
case ItemToken:
switch state {
case stateKey, stateParamBeforeName:
key = v
state = stateParamBeforeName
mustCall = true
case stateParamName:
param = v
state = stateParamBeforeValue
mustCall = true
case stateParamValue:
value = v
state = stateParamBeforeName
call = true
default:
return false
}
case ItemString:
if state != stateParamValue {
return false
}
value = v
state = stateParamBeforeName
call = true
case ItemSeparator:
switch {
case isComma(v) && state == stateKey:
// Nothing to do.
case isComma(v) && state == stateParamBeforeName:
state = stateKey
// Make call only if we have not called this key yet.
call = mustCall
if !call {
// If we have already called callback with the key
// that just ended.
index++
} else {
// Else grow the index after calling callback.
growIndex = 1
}
case isComma(v) && state == stateParamBeforeValue:
state = stateKey
growIndex = 1
call = true
case isSemicolon(v) && state == stateParamBeforeName:
state = stateParamName
case isSemicolon(v) && state == stateParamBeforeValue:
state = stateParamName
call = true
case isEquality(v) && state == stateParamBeforeValue:
state = stateParamValue
default:
return false
}
default:
return false
}
if call {
switch it(index, key, param, value) {
case ControlBreak:
// User wants to stop parsing parameters.
return true
case ControlSkip:
// User wants to skip the current param.
state = stateKey
lexer.SkipEscaped(',')
case ControlContinue:
// User is interested in the rest of the parameters.
// Nothing to do.
default:
panic("unexpected control value")
}
ok = true
param = nil
value = nil
mustCall = false
index += growIndex
}
}
if mustCall {
ok = true
it(index, key, param, value)
}
return ok && !lexer.err
}
func isComma(b []byte) bool {
return len(b) == 1 && b[0] == ','
}
func isSemicolon(b []byte) bool {
return len(b) == 1 && b[0] == ';'
}
func isEquality(b []byte) bool {
return len(b) == 1 && b[0] == '='
}

360
vendor/github.com/gobwas/httphead/lexer.go generated vendored Normal file
View file

@ -0,0 +1,360 @@
package httphead
import (
"bytes"
)
// ItemType encodes type of the lexing token.
type ItemType int
const (
// ItemUndef reports that token is undefined.
ItemUndef ItemType = iota
// ItemToken reports that token is RFC2616 token.
ItemToken
// ItemSeparator reports that token is RFC2616 separator.
ItemSeparator
// ItemString reports that token is an RFC2616 quoted string.
ItemString
// ItemComment reports that token is RFC2616 comment.
ItemComment
// ItemOctet reports that token is octet slice.
ItemOctet
)
// Scanner represents header tokens scanner.
// See https://tools.ietf.org/html/rfc2616#section-2
type Scanner struct {
data []byte
pos int
itemType ItemType
itemBytes []byte
err bool
}
// NewScanner creates new RFC2616 data scanner.
func NewScanner(data []byte) *Scanner {
return &Scanner{data: data}
}
// Next scans for next token. It returns true on successful scanning, and false
// on error or EOF.
func (l *Scanner) Next() bool {
c, ok := l.nextChar()
if !ok {
return false
}
switch c {
case '"': // quoted-string;
return l.fetchQuotedString()
case '(': // comment;
return l.fetchComment()
case '\\', ')': // unexpected chars;
l.err = true
return false
default:
return l.fetchToken()
}
}
// FetchUntil fetches ItemOctet from the current scanner position to the first
// occurrence of c or to the end of the underlying data.
func (l *Scanner) FetchUntil(c byte) bool {
l.resetItem()
if l.pos == len(l.data) {
return false
}
return l.fetchOctet(c)
}
// Peek reads byte at current position without advancing it. On end of data it
// returns 0.
func (l *Scanner) Peek() byte {
if l.pos == len(l.data) {
return 0
}
return l.data[l.pos]
}
// Peek2 reads the first two bytes at the current position without advancing it.
// If there is not enough data it returns 0.
func (l *Scanner) Peek2() (a, b byte) {
if l.pos == len(l.data) {
return 0, 0
}
if l.pos+1 == len(l.data) {
return l.data[l.pos], 0
}
return l.data[l.pos], l.data[l.pos+1]
}
// Buffered reports how many bytes are left to scan.
func (l *Scanner) Buffered() int {
return len(l.data) - l.pos
}
// Advance moves the current position index by n bytes. It returns true on a
// successful move.
func (l *Scanner) Advance(n int) bool {
l.pos += n
if l.pos > len(l.data) {
l.pos = len(l.data)
return false
}
return true
}
// Skip skips all bytes until the first occurrence of c.
func (l *Scanner) Skip(c byte) {
if l.err {
return
}
// Reset scanner state.
l.resetItem()
if i := bytes.IndexByte(l.data[l.pos:], c); i == -1 {
// Reached the end of data.
l.pos = len(l.data)
} else {
l.pos += i + 1
}
}
// SkipEscaped skips all bytes until the first occurrence of a non-escaped c.
func (l *Scanner) SkipEscaped(c byte) {
if l.err {
return
}
// Reset scanner state.
l.resetItem()
if i := ScanUntil(l.data[l.pos:], c); i == -1 {
// Reached the end of data.
l.pos = len(l.data)
} else {
l.pos += i + 1
}
}
// Type reports current token type.
func (l *Scanner) Type() ItemType {
return l.itemType
}
// Bytes returns current token bytes.
func (l *Scanner) Bytes() []byte {
return l.itemBytes
}
func (l *Scanner) nextChar() (byte, bool) {
// Reset scanner state.
l.resetItem()
if l.err {
return 0, false
}
l.pos += SkipSpace(l.data[l.pos:])
if l.pos == len(l.data) {
return 0, false
}
return l.data[l.pos], true
}
func (l *Scanner) resetItem() {
l.itemType = ItemUndef
l.itemBytes = nil
}
func (l *Scanner) fetchOctet(c byte) bool {
i := l.pos
if j := bytes.IndexByte(l.data[l.pos:], c); j == -1 {
// Reached the end of data.
l.pos = len(l.data)
} else {
l.pos += j
}
l.itemType = ItemOctet
l.itemBytes = l.data[i:l.pos]
return true
}
func (l *Scanner) fetchToken() bool {
n, t := ScanToken(l.data[l.pos:])
if n == -1 {
l.err = true
return false
}
l.itemType = t
l.itemBytes = l.data[l.pos : l.pos+n]
l.pos += n
return true
}
func (l *Scanner) fetchQuotedString() (ok bool) {
l.pos++
n := ScanUntil(l.data[l.pos:], '"')
if n == -1 {
l.err = true
return false
}
l.itemType = ItemString
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
l.pos += n + 1
return true
}
func (l *Scanner) fetchComment() (ok bool) {
l.pos++
n := ScanPairGreedy(l.data[l.pos:], '(', ')')
if n == -1 {
l.err = true
return false
}
l.itemType = ItemComment
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
l.pos += n + 1
return true
}
// ScanUntil scans for the first non-escaped character c in the given data.
// It returns the index of the matched c, or -1 if c is not found.
func ScanUntil(data []byte, c byte) (n int) {
for {
i := bytes.IndexByte(data[n:], c)
if i == -1 {
return -1
}
n += i
if n == 0 || data[n-1] != '\\' {
break
}
n++
}
return
}
// ScanPairGreedy scans for a complete pair of opening and closing chars in a greedy manner.
// Note that the first opening byte must not be present in data.
func ScanPairGreedy(data []byte, open, close byte) (n int) {
var m int
opened := 1
for {
i := bytes.IndexByte(data[n:], close)
if i == -1 {
return -1
}
n += i
// If found index is not escaped then it is the end.
if n == 0 || data[n-1] != '\\' {
opened--
}
for m < i {
j := bytes.IndexByte(data[m:i], open)
if j == -1 {
break
}
m += j + 1
opened++
}
if opened == 0 {
break
}
n++
m = n
}
return
}
// RemoveByte returns data without c. If c is not present in data it returns
// the same slice. Otherwise, it copies data without c.
func RemoveByte(data []byte, c byte) []byte {
j := bytes.IndexByte(data, c)
if j == -1 {
return data
}
n := len(data) - 1
// If the character is present, then allocate a slice with n-1 capacity. That is,
// the resulting bytes can be at most n-1 long.
result := make([]byte, n)
k := copy(result, data[:j])
for i := j + 1; i < n; {
j = bytes.IndexByte(data[i:], c)
if j != -1 {
k += copy(result[k:], data[i:i+j])
i = i + j + 1
} else {
k += copy(result[k:], data[i:])
break
}
}
return result[:k]
}
// SkipSpace skips spaces and LWS sequences from p.
// It returns the number of bytes skipped.
func SkipSpace(p []byte) (n int) {
for len(p) > 0 {
switch {
case len(p) >= 3 &&
p[0] == '\r' &&
p[1] == '\n' &&
OctetTypes[p[2]].IsSpace():
p = p[3:]
n += 3
case OctetTypes[p[0]].IsSpace():
p = p[1:]
n++
default:
return
}
}
return
}
// ScanToken scans for the next token in p. It returns the length of the token and its
// type. It does not trim p.
func ScanToken(p []byte) (n int, t ItemType) {
if len(p) == 0 {
return 0, ItemUndef
}
c := p[0]
switch {
case OctetTypes[c].IsSeparator():
return 1, ItemSeparator
case OctetTypes[c].IsToken():
for n = 1; n < len(p); n++ {
c := p[n]
if !OctetTypes[c].IsToken() {
break
}
}
return n, ItemToken
default:
return -1, ItemUndef
}
}

83
vendor/github.com/gobwas/httphead/octet.go generated vendored Normal file
View file

@ -0,0 +1,83 @@
package httphead
// OctetType describes a character type.
//
// From the "Basic Rules" chapter of RFC2616
// See https://tools.ietf.org/html/rfc2616#section-2.2
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// UPALPHA = <any US-ASCII uppercase letter "A".."Z">
// LOALPHA = <any US-ASCII lowercase letter "a".."z">
// ALPHA = UPALPHA | LOALPHA
// DIGIT = <any US-ASCII digit "0".."9">
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
//
// Many HTTP/1.1 header field values consist of words separated by LWS
// or special characters. These special characters MUST be in a quoted
// string to be used within a parameter value (as defined in section
// 3.6).
//
// token = 1*<any CHAR except CTLs or separators>
// separators = "(" | ")" | "<" | ">" | "@"
// | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "="
// | "{" | "}" | SP | HT
type OctetType byte
// IsChar reports whether octet is CHAR.
func (t OctetType) IsChar() bool { return t&octetChar != 0 }
// IsControl reports whether octet is CTL.
func (t OctetType) IsControl() bool { return t&octetControl != 0 }
// IsSeparator reports whether octet is separator.
func (t OctetType) IsSeparator() bool { return t&octetSeparator != 0 }
// IsSpace reports whether octet is space (SP or HT).
func (t OctetType) IsSpace() bool { return t&octetSpace != 0 }
// IsToken reports whether octet is token.
func (t OctetType) IsToken() bool { return t&octetToken != 0 }
const (
octetChar OctetType = 1 << iota
octetControl
octetSpace
octetSeparator
octetToken
)
// OctetTypes is a table of octets.
var OctetTypes [256]OctetType
func init() {
for c := 32; c < 256; c++ {
var t OctetType
if c <= 127 {
t |= octetChar
}
if 0 <= c && c <= 31 || c == 127 {
t |= octetControl
}
switch c {
case '(', ')', '<', '>', '@', ',', ';', ':', '"', '/', '[', ']', '?', '=', '{', '}', '\\':
t |= octetSeparator
case ' ', '\t':
t |= octetSpace | octetSeparator
}
if t.IsChar() && !t.IsControl() && !t.IsSeparator() && !t.IsSpace() {
t |= octetToken
}
OctetTypes[c] = t
}
}

193
vendor/github.com/gobwas/httphead/option.go generated vendored Normal file
View file

@ -0,0 +1,193 @@
package httphead
import (
"bytes"
"sort"
)
// Option represents a header option.
type Option struct {
Name []byte
Parameters Parameters
}
// Size returns the number of bytes that need to be allocated for use in opt.Copy.
func (opt Option) Size() int {
return len(opt.Name) + opt.Parameters.bytes
}
// Copy copies all underlying []byte slices into p and returns new Option.
// Note that p must be at least of opt.Size() length.
func (opt Option) Copy(p []byte) Option {
n := copy(p, opt.Name)
opt.Name = p[:n]
opt.Parameters, p = opt.Parameters.Copy(p[n:])
return opt
}
// Clone is shorthand for allocating a slice of opt.Size() bytes and calling
// Copy() with it.
func (opt Option) Clone() Option {
return opt.Copy(make([]byte, opt.Size()))
}
// String represents option as a string.
func (opt Option) String() string {
return "{" + string(opt.Name) + " " + opt.Parameters.String() + "}"
}
// NewOption creates named option with given parameters.
func NewOption(name string, params map[string]string) Option {
p := Parameters{}
for k, v := range params {
p.Set([]byte(k), []byte(v))
}
return Option{
Name: []byte(name),
Parameters: p,
}
}
// Equal reports whether option is equal to b.
func (opt Option) Equal(b Option) bool {
if bytes.Equal(opt.Name, b.Name) {
return opt.Parameters.Equal(b.Parameters)
}
return false
}
// Parameters represents option's parameters.
type Parameters struct {
pos int
bytes int
arr [8]pair
dyn []pair
}
// Equal reports whether a equal to b.
func (p Parameters) Equal(b Parameters) bool {
switch {
case p.dyn == nil && b.dyn == nil:
case p.dyn != nil && b.dyn != nil:
default:
return false
}
ad, bd := p.data(), b.data()
if len(ad) != len(bd) {
return false
}
sort.Sort(pairs(ad))
sort.Sort(pairs(bd))
for i := 0; i < len(ad); i++ {
av, bv := ad[i], bd[i]
if !bytes.Equal(av.key, bv.key) || !bytes.Equal(av.value, bv.value) {
return false
}
}
return true
}
// Size returns the number of bytes needed to copy p.
func (p *Parameters) Size() int {
return p.bytes
}
// Copy copies all underlying []byte slices into dst and returns new
// Parameters.
// Note that dst must be at least of p.Size() length.
func (p *Parameters) Copy(dst []byte) (Parameters, []byte) {
ret := Parameters{
pos: p.pos,
bytes: p.bytes,
}
if p.dyn != nil {
ret.dyn = make([]pair, len(p.dyn))
for i, v := range p.dyn {
ret.dyn[i], dst = v.copy(dst)
}
} else {
for i, p := range p.arr {
ret.arr[i], dst = p.copy(dst)
}
}
return ret, dst
}
// Get returns the value by key and a flag reporting whether such a value exists.
func (p *Parameters) Get(key string) (value []byte, ok bool) {
for _, v := range p.data() {
if string(v.key) == key {
return v.value, true
}
}
return nil, false
}
// Set sets value by key.
func (p *Parameters) Set(key, value []byte) {
p.bytes += len(key) + len(value)
if p.pos < len(p.arr) {
p.arr[p.pos] = pair{key, value}
p.pos++
return
}
if p.dyn == nil {
p.dyn = make([]pair, len(p.arr), len(p.arr)+1)
copy(p.dyn, p.arr[:])
}
p.dyn = append(p.dyn, pair{key, value})
}
// ForEach iterates over parameters key-value pairs and calls cb for each one.
func (p *Parameters) ForEach(cb func(k, v []byte) bool) {
for _, v := range p.data() {
if !cb(v.key, v.value) {
break
}
}
}
// String represents parameters as a string.
func (p *Parameters) String() (ret string) {
ret = "["
for i, v := range p.data() {
if i > 0 {
ret += " "
}
ret += string(v.key) + ":" + string(v.value)
}
return ret + "]"
}
func (p *Parameters) data() []pair {
if p.dyn != nil {
return p.dyn
}
return p.arr[:p.pos]
}
type pair struct {
key, value []byte
}
func (p pair) copy(dst []byte) (pair, []byte) {
n := copy(dst, p.key)
p.key = dst[:n]
m := n + copy(dst[n:], p.value)
p.value = dst[n:m]
dst = dst[m:]
return p, dst
}
type pairs []pair
func (p pairs) Len() int { return len(p) }
func (p pairs) Less(a, b int) bool { return bytes.Compare(p[a].key, p[b].key) == -1 }
func (p pairs) Swap(a, b int) { p[a], p[b] = p[b], p[a] }
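
A small, hypothetical usage sketch of the `Option` and `Parameters` types defined above; the option name and parameter are illustrative only:

```go
package main

import (
	"fmt"

	"github.com/gobwas/httphead"
)

func main() {
	// Build an option like `permessage-deflate; client_max_window_bits=10`.
	opt := httphead.NewOption("permessage-deflate", map[string]string{
		"client_max_window_bits": "10",
	})

	// Get looks a parameter up by key.
	if v, ok := opt.Parameters.Get("client_max_window_bits"); ok {
		fmt.Printf("%s: client_max_window_bits=%s\n", opt.Name, v)
	}

	// Equal compares names and parameter sets regardless of insertion order.
	other := httphead.NewOption("permessage-deflate", map[string]string{
		"client_max_window_bits": "10",
	})
	fmt.Println(opt.Equal(other)) // true
}
```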

101
vendor/github.com/gobwas/httphead/writer.go generated vendored Normal file
View file

@ -0,0 +1,101 @@
package httphead
import "io"
var (
comma = []byte{','}
equality = []byte{'='}
semicolon = []byte{';'}
quote = []byte{'"'}
escape = []byte{'\\'}
)
// WriteOptions writes the options list to dest.
// It uses the same form as {Scan,Parse}Options functions:
// values = 1#value
// value = token *( ";" param )
// param = token [ "=" (token | quoted-string) ]
//
// It wraps values into a quoted-string sequence if they contain any
// non-token characters.
func WriteOptions(dest io.Writer, options []Option) (n int, err error) {
w := writer{w: dest}
for i, opt := range options {
if i > 0 {
w.write(comma)
}
writeTokenSanitized(&w, opt.Name)
for _, p := range opt.Parameters.data() {
w.write(semicolon)
writeTokenSanitized(&w, p.key)
if len(p.value) != 0 {
w.write(equality)
writeTokenSanitized(&w, p.value)
}
}
}
return w.result()
}
// writeTokenSanitized writes the token as is, or as a quoted string if it
// contains non-token characters.
//
// Note that it does not expect LWS sequences in bts, because LWS is used only
// as header field continuation:
// "A CRLF is allowed in the definition of TEXT only as part of a header field
// continuation. It is expected that the folding LWS will be replaced with a
// single SP before interpretation of the TEXT value."
// See https://tools.ietf.org/html/rfc2616#section-2
//
// That is, we sanitize bts for writing so that there can be no header field
// continuation: any CRLF is escaped like any other control character not
// allowed in TEXT.
func writeTokenSanitized(bw *writer, bts []byte) {
var qt bool
var pos int
for i := 0; i < len(bts); i++ {
c := bts[i]
if !OctetTypes[c].IsToken() && !qt {
qt = true
bw.write(quote)
}
if OctetTypes[c].IsControl() || c == '"' {
if !qt {
qt = true
bw.write(quote)
}
bw.write(bts[pos:i])
bw.write(escape)
bw.write(bts[i : i+1])
pos = i + 1
}
}
if !qt {
bw.write(bts)
} else {
bw.write(bts[pos:])
bw.write(quote)
}
}
type writer struct {
w io.Writer
n int
err error
}
func (w *writer) write(p []byte) {
if w.err != nil {
return
}
var n int
n, w.err = w.w.Write(p)
w.n += n
return
}
func (w *writer) result() (int, error) {
return w.n, w.err
}
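
A hypothetical example of serializing a list of options with `WriteOptions`; the option names and parameters below are illustrative:

```go
package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/gobwas/httphead"
)

func main() {
	opts := []httphead.Option{
		httphead.NewOption("permessage-deflate", map[string]string{
			"server_no_context_takeover": "",
		}),
		httphead.NewOption("x-custom", map[string]string{
			// The SP forces the value to be written as a quoted-string.
			"note": "hello world",
		}),
	}

	var buf bytes.Buffer
	if _, err := httphead.WriteOptions(&buf, opts); err != nil {
		log.Fatal(err)
	}
	fmt.Println(buf.String())
	// permessage-deflate;server_no_context_takeover,x-custom;note="hello world"
}
```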

21
vendor/github.com/gobwas/pool/LICENSE generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017-2019 Sergey Kamardin <gobwas@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

107
vendor/github.com/gobwas/pool/README.md generated vendored Normal file
View file

@ -0,0 +1,107 @@
# pool
[![GoDoc][godoc-image]][godoc-url]
> Tiny memory reuse helpers for Go.
## generic
Without the use of subpackages, `pool` allows reusing any struct distinguishable
by size in a generic way:
```go
package main
import "github.com/gobwas/pool"
func main() {
x, n := pool.Get(100) // Returns object with size 128 or nil.
if x == nil {
// Create x somehow with knowledge that n is 128.
}
defer pool.Put(x, n)
// Work with x.
}
```
Pool allows you to pass specific options for constructing a custom pool:
```go
package main
import "github.com/gobwas/pool"
func main() {
p := pool.Custom(
pool.WithLogSizeMapping(), // Will ceil size n passed to Get(n) to nearest power of two.
pool.WithLogSizeRange(64, 512), // Will reuse objects in logarithmic range [64, 512].
pool.WithSize(65536), // Will reuse object with size 65536.
)
x, n := p.Get(1000) // Returns nil and 1000 because mapped size 1000 => 1024 is not reused by the pool.
defer p.Put(x, n) // Will not reuse x.
// Work with x.
}
```
Note that there are a few non-generic pooling implementations inside subpackages.
## pbytes
Subpackage `pbytes` is intended for `[]byte` reuse.
```go
package main
import "github.com/gobwas/pool/pbytes"
func main() {
bts := pbytes.GetCap(100) // Returns make([]byte, 0, 128).
defer pbytes.Put(bts)
// Work with bts.
}
```
You can also create your own range for pooling:
```go
package main
import "github.com/gobwas/pool/pbytes"
func main() {
// Reuse only slices whose capacity is 128, 256, 512 or 1024.
pool := pbytes.New(128, 1024)
bts := pool.GetCap(100) // Returns make([]byte, 0, 128).
defer pool.Put(bts)
// Work with bts.
}
```
## pbufio
Subpackage `pbufio` is intended for `*bufio.{Reader, Writer}` reuse.
```go
package main
import (
"os"
"github.com/gobwas/pool/pbufio"
)
func main() {
bw := pbufio.GetWriter(os.Stdout, 100) // Returns bufio.NewWriterSize(128).
defer pbufio.PutWriter(bw)
// Work with bw.
}
```
Like with `pbytes`, you can also create pool with custom reuse bounds.
[godoc-image]: https://godoc.org/github.com/gobwas/pool?status.svg
[godoc-url]: https://godoc.org/github.com/gobwas/pool

87
vendor/github.com/gobwas/pool/generic.go generated vendored Normal file
View file

@ -0,0 +1,87 @@
package pool
import (
"sync"
"github.com/gobwas/pool/internal/pmath"
)
var DefaultPool = New(128, 65536)
// Get pulls an object whose generic size is at least the given size. It also
// returns the real size of x for a later pass to Put(). It returns -1 as the
// real size for nil x. A size > -1 does not mean that x is non-nil, so checks
// must still be done.
//
// Note that size could be ceiled to the next power of two.
//
// Get is a wrapper around DefaultPool.Get().
func Get(size int) (interface{}, int) { return DefaultPool.Get(size) }
// Put takes x and its size for future reuse.
// Put is a wrapper around DefaultPool.Put().
func Put(x interface{}, size int) { DefaultPool.Put(x, size) }
// Pool contains logic of reusing objects distinguishable by size in generic
// way.
type Pool struct {
pool map[int]*sync.Pool
size func(int) int
}
// New creates a new Pool that reuses objects whose size is in the logarithmic range
// [min, max].
//
// Note that it is a shortcut for Custom() constructor with Options provided by
// WithLogSizeMapping() and WithLogSizeRange(min, max) calls.
func New(min, max int) *Pool {
return Custom(
WithLogSizeMapping(),
WithLogSizeRange(min, max),
)
}
// Custom creates new Pool with given options.
func Custom(opts ...Option) *Pool {
p := &Pool{
pool: make(map[int]*sync.Pool),
size: pmath.Identity,
}
c := (*poolConfig)(p)
for _, opt := range opts {
opt(c)
}
return p
}
// Get pulls an object whose generic size is at least the given size.
// It also returns the real size of x for a later pass to Put(), even if x is nil.
// Note that size could be ceiled to the next power of two.
func (p *Pool) Get(size int) (interface{}, int) {
n := p.size(size)
if pool := p.pool[n]; pool != nil {
return pool.Get(), n
}
return nil, size
}
// Put takes x and its size for future reuse.
func (p *Pool) Put(x interface{}, size int) {
if pool := p.pool[size]; pool != nil {
pool.Put(x)
}
}
type poolConfig Pool
// AddSize adds size n to the map.
func (p *poolConfig) AddSize(n int) {
p.pool[n] = new(sync.Pool)
}
// SetSizeMapping sets up incoming size mapping function.
func (p *poolConfig) SetSizeMapping(size func(int) int) {
p.size = size
}

65
vendor/github.com/gobwas/pool/internal/pmath/pmath.go generated vendored Normal file
View file

@ -0,0 +1,65 @@
package pmath
const (
bitsize = 32 << (^uint(0) >> 63)
maxint = int(1<<(bitsize-1) - 1)
maxintHeadBit = 1 << (bitsize - 2)
)
// LogarithmicRange iterates from min (ceiled to a power of two) up to max,
// calling cb on each iteration.
func LogarithmicRange(min, max int, cb func(int)) {
if min == 0 {
min = 1
}
for n := CeilToPowerOfTwo(min); n <= max; n <<= 1 {
cb(n)
}
}
// IsPowerOfTwo reports whether given integer is a power of two.
func IsPowerOfTwo(n int) bool {
return n&(n-1) == 0
}
// Identity is identity.
func Identity(n int) int {
return n
}
// CeilToPowerOfTwo returns the least power of two integer value greater than
// or equal to n.
func CeilToPowerOfTwo(n int) int {
if n&maxintHeadBit != 0 && n > maxintHeadBit {
panic("argument is too large")
}
if n <= 2 {
return n
}
n--
n = fillBits(n)
n++
return n
}
// FloorToPowerOfTwo returns the greatest power of two integer value less than
// or equal to n.
func FloorToPowerOfTwo(n int) int {
if n <= 2 {
return n
}
n = fillBits(n)
n >>= 1
n++
return n
}
func fillBits(n int) int {
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
return n
}
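
Since `pmath` is an internal package and cannot be imported from outside `gobwas/pool`, here is a self-contained sketch of the same ceil-to-power-of-two bit trick used by `CeilToPowerOfTwo` above, for illustration only (the overflow guard is omitted for brevity):

```go
package main

import "fmt"

// ceilToPowerOfTwo mirrors pmath.CeilToPowerOfTwo: it fills all bits below the
// highest set bit of n-1 and then adds one, yielding the least power of two >= n.
func ceilToPowerOfTwo(n int) int {
	if n <= 2 {
		return n
	}
	n--
	n |= n >> 1
	n |= n >> 2
	n |= n >> 4
	n |= n >> 8
	n |= n >> 16
	n |= n >> 32
	return n + 1
}

func main() {
	for _, n := range []int{3, 100, 1000, 4096} {
		fmt.Println(n, "->", ceilToPowerOfTwo(n)) // 3->4, 100->128, 1000->1024, 4096->4096
	}
}
```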

43
vendor/github.com/gobwas/pool/option.go generated vendored Normal file
View file

@ -0,0 +1,43 @@
package pool
import "github.com/gobwas/pool/internal/pmath"
// Option configures pool.
type Option func(Config)
// Config describes generic pool configuration.
type Config interface {
AddSize(n int)
SetSizeMapping(func(int) int)
}
// WithLogSizeRange returns an Option that will add a logarithmic range of
// pooling sizes covering [min, max] values.
func WithLogSizeRange(min, max int) Option {
return func(c Config) {
pmath.LogarithmicRange(min, max, func(n int) {
c.AddSize(n)
})
}
}
// WithSize returns an Option that will add given pooling size to the pool.
func WithSize(n int) Option {
return func(c Config) {
c.AddSize(n)
}
}
func WithSizeMapping(sz func(int) int) Option {
return func(c Config) {
c.SetSizeMapping(sz)
}
}
func WithLogSizeMapping() Option {
return WithSizeMapping(pmath.CeilToPowerOfTwo)
}
func WithIdentitySizeMapping() Option {
return WithSizeMapping(pmath.Identity)
}

106
vendor/github.com/gobwas/pool/pbufio/pbufio.go generated vendored Normal file
View file

@ -0,0 +1,106 @@
// Package pbufio contains tools for pooling bufio.Readers and bufio.Writers.
package pbufio
import (
"bufio"
"io"
"github.com/gobwas/pool"
)
var (
DefaultWriterPool = NewWriterPool(256, 65536)
DefaultReaderPool = NewReaderPool(256, 65536)
)
// GetWriter returns bufio.Writer whose buffer has at least size bytes.
// Note that size could be ceiled to the next power of two.
// GetWriter is a wrapper around DefaultWriterPool.Get().
func GetWriter(w io.Writer, size int) *bufio.Writer { return DefaultWriterPool.Get(w, size) }
// PutWriter takes bufio.Writer for future reuse.
// It does not reuse a bufio.Writer whose underlying buffer size is not a power
// of two or is out of the pool min/max range.
// PutWriter is a wrapper around DefaultWriterPool.Put().
func PutWriter(bw *bufio.Writer) { DefaultWriterPool.Put(bw) }
// GetReader returns a bufio.Reader whose buffer has at least size bytes.
// Note that size could be ceiled to the next power of two.
// GetReader is a wrapper around DefaultReaderPool.Get().
func GetReader(w io.Reader, size int) *bufio.Reader { return DefaultReaderPool.Get(w, size) }
// PutReader takes a bufio.Reader for future reuse.
// It does not reuse bufio.Reader if size is not power of two or is out of pool
// min/max range.
// PutReader is a wrapper around DefaultReaderPool.Put().
func PutReader(bw *bufio.Reader) { DefaultReaderPool.Put(bw) }
// WriterPool contains logic of *bufio.Writer reuse with various size.
type WriterPool struct {
pool *pool.Pool
}
// NewWriterPool creates a new WriterPool that reuses writers whose buffer size
// is in the logarithmic range [min, max].
func NewWriterPool(min, max int) *WriterPool {
return &WriterPool{pool.New(min, max)}
}
// CustomWriterPool creates new WriterPool with given options.
func CustomWriterPool(opts ...pool.Option) *WriterPool {
return &WriterPool{pool.Custom(opts...)}
}
// Get returns bufio.Writer whose buffer has at least size bytes.
func (wp *WriterPool) Get(w io.Writer, size int) *bufio.Writer {
v, n := wp.pool.Get(size)
if v != nil {
bw := v.(*bufio.Writer)
bw.Reset(w)
return bw
}
return bufio.NewWriterSize(w, n)
}
// Put takes ownership of bufio.Writer for further reuse.
func (wp *WriterPool) Put(bw *bufio.Writer) {
// Should reset even if we do Reset() inside Get().
// This is done to prevent locking underlying io.Writer from GC.
bw.Reset(nil)
wp.pool.Put(bw, writerSize(bw))
}
// ReaderPool contains logic of *bufio.Reader reuse with various size.
type ReaderPool struct {
pool *pool.Pool
}
// NewReaderPool creates a new ReaderPool that reuses readers whose buffer size
// is in the logarithmic range [min, max].
func NewReaderPool(min, max int) *ReaderPool {
return &ReaderPool{pool.New(min, max)}
}
// CustomReaderPool creates new ReaderPool with given options.
func CustomReaderPool(opts ...pool.Option) *ReaderPool {
return &ReaderPool{pool.Custom(opts...)}
}
// Get returns bufio.Reader whose buffer has at least size bytes.
func (rp *ReaderPool) Get(r io.Reader, size int) *bufio.Reader {
v, n := rp.pool.Get(size)
if v != nil {
br := v.(*bufio.Reader)
br.Reset(r)
return br
}
return bufio.NewReaderSize(r, n)
}
// Put takes ownership of bufio.Reader for further reuse.
func (rp *ReaderPool) Put(br *bufio.Reader) {
// Should reset even if we do Reset() inside Get().
// This is done to prevent locking underlying io.Reader from GC.
br.Reset(nil)
rp.pool.Put(br, readerSize(br))
}
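
The pbufio README earlier in this diff mentions that custom reuse bounds can be used without showing them; a hypothetical sketch with the `NewWriterPool` and `NewReaderPool` constructors defined above:

```go
package main

import (
	"os"

	"github.com/gobwas/pool/pbufio"
)

func main() {
	// Reuse writers whose buffer size falls in the logarithmic range
	// 1024, 2048, 4096.
	writers := pbufio.NewWriterPool(1024, 4096)

	bw := writers.Get(os.Stdout, 1500) // 1500 is ceiled to a 2048-byte buffer
	bw.WriteString("pooled bufio.Writer\n")
	bw.Flush()
	writers.Put(bw) // returned to the pool for a later Get

	// Readers are pooled the same way.
	readers := pbufio.NewReaderPool(1024, 4096)
	br := readers.Get(os.Stdin, 4096)
	readers.Put(br)
}
```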

13
vendor/github.com/gobwas/pool/pbufio/pbufio_go110.go generated vendored Normal file
View file

@ -0,0 +1,13 @@
// +build go1.10
package pbufio
import "bufio"
func writerSize(bw *bufio.Writer) int {
return bw.Size()
}
func readerSize(br *bufio.Reader) int {
return br.Size()
}

27
vendor/github.com/gobwas/pool/pbufio/pbufio_go19.go generated vendored Normal file
View file

@ -0,0 +1,27 @@
// +build !go1.10
package pbufio
import "bufio"
func writerSize(bw *bufio.Writer) int {
return bw.Available() + bw.Buffered()
}
// readerSize returns buffer size of the given buffered reader.
// NOTE: current workaround implementation resets underlying io.Reader.
func readerSize(br *bufio.Reader) int {
br.Reset(sizeReader)
br.ReadByte()
n := br.Buffered() + 1
br.Reset(nil)
return n
}
var sizeReader optimisticReader
type optimisticReader struct{}
func (optimisticReader) Read(p []byte) (int, error) {
return len(p), nil
}

24
vendor/github.com/gobwas/pool/pbytes/pbytes.go generated vendored Normal file
View file

@ -0,0 +1,24 @@
// Package pbytes contains tools for pooling byte slices.
// Note that by default it reuses slices with capacity from 128 to 65536 bytes.
package pbytes
// DefaultPool is used by package level functions.
var DefaultPool = New(128, 65536)
// Get returns probably reused slice of bytes with at least capacity of c and
// exactly len of n.
// Get is a wrapper around DefaultPool.Get().
func Get(n, c int) []byte { return DefaultPool.Get(n, c) }
// GetCap returns probably reused slice of bytes with at least capacity of n.
// GetCap is a wrapper around DefaultPool.GetCap().
func GetCap(c int) []byte { return DefaultPool.GetCap(c) }
// GetLen returns probably reused slice of bytes with at least capacity of n
// and exactly len of n.
// GetLen is a wrapper around DefaultPool.GetLen().
func GetLen(n int) []byte { return DefaultPool.GetLen(n) }
// Put returns given slice to reuse pool.
// Put is a wrapper around DefaultPool.Put().
func Put(p []byte) { DefaultPool.Put(p) }

59
vendor/github.com/gobwas/pool/pbytes/pool.go generated vendored Normal file
View file

@ -0,0 +1,59 @@
// +build !pool_sanitize
package pbytes
import "github.com/gobwas/pool"
// Pool contains logic of reusing byte slices of various size.
type Pool struct {
pool *pool.Pool
}
// New creates a new Pool that reuses slices whose size is in the logarithmic range
// [min, max].
//
// Note that it is a shortcut for Custom() constructor with Options provided by
// pool.WithLogSizeMapping() and pool.WithLogSizeRange(min, max) calls.
func New(min, max int) *Pool {
return &Pool{pool.New(min, max)}
}
// Custom creates a new Pool with the given options.
func Custom(opts ...pool.Option) *Pool {
return &Pool{pool.Custom(opts...)}
}
// Get returns probably reused slice of bytes with at least capacity of c and
// exactly len of n.
func (p *Pool) Get(n, c int) []byte {
if n > c {
panic("requested length is greater than capacity")
}
v, x := p.pool.Get(c)
if v != nil {
bts := v.([]byte)
bts = bts[:n]
return bts
}
return make([]byte, n, x)
}
// Put returns given slice to reuse pool.
// It does not reuse bytes whose size is not power of two or is out of pool
// min/max range.
func (p *Pool) Put(bts []byte) {
p.pool.Put(bts, cap(bts))
}
// GetCap returns probably reused slice of bytes with at least capacity of n.
func (p *Pool) GetCap(c int) []byte {
return p.Get(0, c)
}
// GetLen returns probably reused slice of bytes with at least capacity of n
// and exactly len of n.
func (p *Pool) GetLen(n int) []byte {
return p.Get(n, n)
}

121
vendor/github.com/gobwas/pool/pbytes/pool_sanitize.go generated vendored Normal file
View file

@ -0,0 +1,121 @@
// +build pool_sanitize
package pbytes
import (
"reflect"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const magic = uint64(0x777742)
type guard struct {
magic uint64
size int
owners int32
}
const guardSize = int(unsafe.Sizeof(guard{}))
type Pool struct {
min, max int
}
func New(min, max int) *Pool {
return &Pool{min, max}
}
// Get returns probably reused slice of bytes with at least capacity of c and
// exactly len of n.
func (p *Pool) Get(n, c int) []byte {
if n > c {
panic("requested length is greater than capacity")
}
pageSize := syscall.Getpagesize()
pages := (c+guardSize)/pageSize + 1
size := pages * pageSize
bts := alloc(size)
g := (*guard)(unsafe.Pointer(&bts[0]))
*g = guard{
magic: magic,
size: size,
owners: 1,
}
return bts[guardSize : guardSize+n]
}
func (p *Pool) GetCap(c int) []byte { return p.Get(0, c) }
func (p *Pool) GetLen(n int) []byte { return Get(n, n) }
// Put returns given slice to reuse pool.
func (p *Pool) Put(bts []byte) {
hdr := *(*reflect.SliceHeader)(unsafe.Pointer(&bts))
ptr := hdr.Data - uintptr(guardSize)
g := (*guard)(unsafe.Pointer(ptr))
if g.magic != magic {
panic("unknown slice returned to the pool")
}
if n := atomic.AddInt32(&g.owners, -1); n < 0 {
panic("multiple Put() detected")
}
// Disable read and write on bytes memory pages. This will cause panic on
// incorrect access to returned slice.
mprotect(ptr, false, false, g.size)
runtime.SetFinalizer(&bts, func(b *[]byte) {
mprotect(ptr, true, true, g.size)
free(*(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
Data: ptr,
Len: g.size,
Cap: g.size,
})))
})
}
func alloc(n int) []byte {
b, err := unix.Mmap(-1, 0, n, unix.PROT_READ|unix.PROT_WRITE|unix.PROT_EXEC, unix.MAP_SHARED|unix.MAP_ANONYMOUS)
if err != nil {
panic(err.Error())
}
return b
}
func free(b []byte) {
if err := unix.Munmap(b); err != nil {
panic(err.Error())
}
}
func mprotect(ptr uintptr, r, w bool, size int) {
// Need to avoid "EINVAL addr is not a valid pointer,
// or not a multiple of PAGESIZE."
start := ptr & ^(uintptr(syscall.Getpagesize() - 1))
prot := uintptr(syscall.PROT_EXEC)
switch {
case r && w:
prot |= syscall.PROT_READ | syscall.PROT_WRITE
case r:
prot |= syscall.PROT_READ
case w:
prot |= syscall.PROT_WRITE
}
_, _, err := syscall.Syscall(syscall.SYS_MPROTECT,
start, uintptr(size), prot,
)
if err != 0 {
panic(err.Error())
}
}

25
vendor/github.com/gobwas/pool/pool.go generated vendored Normal file
View file

@ -0,0 +1,25 @@
// Package pool contains helpers for pooling structures distinguishable by
// size.
//
// Quick example:
//
// import "github.com/gobwas/pool"
//
// func main() {
// // Reuse objects in logarithmic range from 0 to 64 (0,1,2,4,8,16,32,64).
// p := pool.New(0, 64)
//
// buf, n := p.Get(10) // Returns buffer with 16 capacity.
// if buf == nil {
// buf = bytes.NewBuffer(make([]byte, n))
// }
// defer p.Put(buf, n)
//
// // Work with buf.
// }
//
// There are non-generic implementations for pooling:
// - pool/pbytes for []byte reuse;
// - pool/pbufio for *bufio.Reader and *bufio.Writer reuse;
//
package pool

5
vendor/github.com/gobwas/ws/.gitignore generated vendored Normal file
View file

@ -0,0 +1,5 @@
bin/
reports/
cpu.out
mem.out
ws.test

21
vendor/github.com/gobwas/ws/LICENSE generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017-2021 Sergey Kamardin <gobwas@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

54
vendor/github.com/gobwas/ws/Makefile generated vendored Normal file
View file

@ -0,0 +1,54 @@
BENCH ?=.
BENCH_BASE?=master
clean:
rm -f bin/reporter
rm -fr autobahn/report/*
bin/reporter:
go build -o bin/reporter ./autobahn
bin/gocovmerge:
go build -o bin/gocovmerge github.com/wadey/gocovmerge
.PHONY: autobahn
autobahn: clean bin/reporter
./autobahn/script/test.sh --build --follow-logs
bin/reporter $(PWD)/autobahn/report/index.json
.PHONY: autobahn/report
autobahn/report: bin/reporter
./bin/reporter -http localhost:5555 ./autobahn/report/index.json
test:
go test -coverprofile=ws.coverage .
go test -coverprofile=wsutil.coverage ./wsutil
go test -coverprofile=wsflate.coverage ./wsflate
# No statements to cover in ./tests (there are only tests).
go test ./tests
cover: bin/gocovmerge test autobahn
bin/gocovmerge ws.coverage wsutil.coverage wsflate.coverage autobahn/report/server.coverage > total.coverage
benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD)
benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX)
benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX)
benchcmp:
if [ ! -z "$(shell git status -s)" ]; then\
echo "could not compare with $(BENCH_BASE) found unstaged changes";\
exit 1;\
fi;\
if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\
echo "comparing the same branches";\
exit 1;\
fi;\
echo "benchmarking $(BENCH_BRANCH)...";\
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\
echo "benchmarking $(BENCH_BASE)...";\
git checkout -q $(BENCH_BASE);\
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\
git checkout -q $(BENCH_BRANCH);\
echo "\nresults:";\
echo "========\n";\
benchcmp $(BENCH_OLD) $(BENCH_NEW);\

541
vendor/github.com/gobwas/ws/README.md generated vendored Normal file
View file

@ -0,0 +1,541 @@
# ws
[![GoDoc][godoc-image]][godoc-url]
[![CI][ci-badge]][ci-url]
> [RFC6455][rfc-url] WebSocket implementation in Go.
# Features
- Zero-copy upgrade
- No intermediate allocations during I/O
- Low-level API which allows you to build your own logic of packet handling and
buffer reuse
- High-level wrappers and helpers around the API in the `wsutil` package, which allow
you to start fast without digging into the protocol internals
# Documentation
[GoDoc][godoc-url].
# Why
Existing WebSocket implementations do not allow users to reuse I/O buffers
between connections in a clear way. This library aims to export an efficient
low-level interface for working with the protocol without forcing only one way
it could be used.
If you want higher-level tools, you can use the `wsutil`
package.
# Status
The library is tagged as `v1*`, so its API will not be broken by
improvements or refactoring.
This implementation of RFC6455 passes [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) and currently has
about 78% coverage.
# Examples
Example applications using `ws` are developed in separate repository
[ws-examples](https://github.com/gobwas/ws-examples).
# Usage
The higher-level example of WebSocket echo server:
```go
package main
import (
"net/http"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
for {
msg, op, err := wsutil.ReadClientData(conn)
if err != nil {
// handle error
}
err = wsutil.WriteServerMessage(conn, op, msg)
if err != nil {
// handle error
}
}
}()
}))
}
```
Lower-level, but still high-level example:
```go
import (
"net/http"
"io"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
func main() {
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
var (
state = ws.StateServerSide
reader = wsutil.NewReader(conn, state)
writer = wsutil.NewWriter(conn, state, ws.OpText)
)
for {
header, err := reader.NextFrame()
if err != nil {
// handle error
}
// Reset writer to write frame with right operation code.
writer.Reset(conn, state, header.OpCode)
if _, err = io.Copy(writer, reader); err != nil {
// handle error
}
if err = writer.Flush(); err != nil {
// handle error
}
}
}()
}))
}
```
We can apply the same pattern to read and write structured responses through a JSON encoder and decoder:
```go
...
var (
r = wsutil.NewReader(conn, ws.StateServerSide)
w = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText)
decoder = json.NewDecoder(r)
encoder = json.NewEncoder(w)
)
for {
hdr, err = r.NextFrame()
if err != nil {
return err
}
if hdr.OpCode == ws.OpClose {
return io.EOF
}
var req Request
if err := decoder.Decode(&req); err != nil {
return err
}
var resp Response
if err := encoder.Encode(&resp); err != nil {
return err
}
if err = w.Flush(); err != nil {
return err
}
}
...
```
The lower-level example without `wsutil`:
```go
package main
import (
"net"
"io"
"log"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Fatal(err)
}
for {
conn, err := ln.Accept()
if err != nil {
// handle error
}
_, err = ws.Upgrade(conn)
if err != nil {
// handle error
}
go func() {
defer conn.Close()
for {
header, err := ws.ReadHeader(conn)
if err != nil {
// handle error
}
payload := make([]byte, header.Length)
_, err = io.ReadFull(conn, payload)
if err != nil {
// handle error
}
if header.Masked {
ws.Cipher(payload, header.Mask, 0)
}
// Reset the Masked flag, server frames must not be masked as
// RFC6455 says.
header.Masked = false
if err := ws.WriteHeader(conn, header); err != nil {
// handle error
}
if _, err := conn.Write(payload); err != nil {
// handle error
}
if header.OpCode == ws.OpClose {
return
}
}
}()
}
}
```
# Zero-copy upgrade
Zero-copy upgrade helps to avoid unnecessary allocations and copying while
handling the HTTP Upgrade request.
Processing of all non-websocket headers is done in place with registered
user callbacks whose arguments are only valid until the callback returns.
The simple example looks like this:
```go
package main
import (
"net"
"log"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
log.Fatal(err)
}
u := ws.Upgrader{
OnHeader: func(key, value []byte) (err error) {
log.Printf("non-websocket header: %q=%q", key, value)
return
},
}
for {
conn, err := ln.Accept()
if err != nil {
// handle error
}
_, err = u.Upgrade(conn)
if err != nil {
// handle error
}
}
}
```
Using `ws.Upgrader` here brings the ability to control incoming connections at
the TCP level and simply not accept them based on your own logic.
Zero-copy upgrade is for high-load services which have to control many
resources such as connection buffers.
A real-life example could look like this:
```go
package main
import (
"fmt"
"io"
"log"
"net"
"net/http"
"runtime"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
// handle error
}
// Prepare handshake header writer from http.Header mapping.
header := ws.HandshakeHeaderHTTP(http.Header{
"X-Go-Version": []string{runtime.Version()},
})
u := ws.Upgrader{
OnHost: func(host []byte) error {
if string(host) == "github.com" {
return nil
}
return ws.RejectConnectionError(
ws.RejectionStatus(403),
ws.RejectionHeader(ws.HandshakeHeaderString(
"X-Want-Host: github.com\r\n",
)),
)
},
OnHeader: func(key, value []byte) error {
if string(key) != "Cookie" {
return nil
}
ok := httphead.ScanCookie(value, func(key, value []byte) bool {
// Check session here or do some other stuff with cookies.
// Maybe copy some values for future use.
return true
})
if ok {
return nil
}
return ws.RejectConnectionError(
ws.RejectionReason("bad cookie"),
ws.RejectionStatus(400),
)
},
OnBeforeUpgrade: func() (ws.HandshakeHeader, error) {
return header, nil
},
}
for {
conn, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
_, err = u.Upgrade(conn)
if err != nil {
log.Printf("upgrade error: %s", err)
}
}
}
```
# Compression
There is a `ws/wsflate` package to support [Permessage-Deflate Compression
Extension][rfc-pmce].
It provides minimalistic I/O wrappers to be used in conjunction with any
deflate implementation (for example, the standard library's
[compress/flate][compress/flate]).
It is also compatible with `wsutil`'s reader and writer by providing
`wsflate.MessageState` type, which implements `wsutil.SendExtension` and
`wsutil.RecvExtension` interfaces.
```go
package main
import (
"bytes"
"log"
"net"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
)
func main() {
ln, err := net.Listen("tcp", "localhost:8080")
if err != nil {
// handle error
}
e := wsflate.Extension{
// We are using default parameters here since we use
// wsflate.{Compress,Decompress}Frame helpers below in the code.
// This assumes that we use standard compress/flate package as flate
// implementation.
Parameters: wsflate.DefaultParameters,
}
u := ws.Upgrader{
Negotiate: e.Negotiate,
}
for {
conn, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
// Reset extension after previous upgrades.
e.Reset()
_, err = u.Upgrade(conn)
if err != nil {
log.Printf("upgrade error: %s", err)
continue
}
if _, ok := e.Accepted(); !ok {
log.Printf("didn't negotiate compression for %s", conn.RemoteAddr())
conn.Close()
continue
}
go func() {
defer conn.Close()
for {
frame, err := ws.ReadFrame(conn)
if err != nil {
// Handle error.
return
}
frame = ws.UnmaskFrameInPlace(frame)
if wsflate.IsCompressed(frame.Header) {
// Note that even after successful negotiation of
// compression extension, both sides are able to send
// non-compressed messages.
frame, err = wsflate.DecompressFrame(frame)
if err != nil {
// Handle error.
return
}
}
// Do something with frame...
ack := ws.NewTextFrame([]byte("this is an acknowledgement"))
// Compress response unconditionally.
ack, err = wsflate.CompressFrame(ack)
if err != nil {
// Handle error.
return
}
if err = ws.WriteFrame(conn, ack); err != nil {
// Handle error.
return
}
}
}()
}
}
```
You can use compression with `wsutil` package this way:
```go
// Upgrade somehow and negotiate compression to get the conn...
// Initialize flate reader. We are using nil as a source io.Reader because
// we will Reset() it in the message i/o loop below.
fr := wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
})
// Initialize flate writer. We are using nil as a destination io.Writer
// because we will Reset() it in the message i/o loop below.
fw := wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
f, _ := flate.NewWriter(w, 9)
return f
})
// Declare compression message state variable.
//
// It has two goals:
// - Allow users to check whether received message is compressed or not.
// - Help wsutil.Reader and wsutil.Writer to set/unset appropriate
// WebSocket header bits while writing next frame to the wire (it
// implements wsutil.RecvExtension and wsutil.SendExtension).
var msg wsflate.MessageState
// Initialize WebSocket reader as previously.
// Please note the use of Reader.Extensions field as well as
// of ws.StateExtended flag.
rd := &wsutil.Reader{
Source: conn,
State: ws.StateServerSide | ws.StateExtended,
Extensions: []wsutil.RecvExtension{
&msg,
},
}
// Initialize WebSocket writer with ws.StateExtended flag as well.
wr := wsutil.NewWriter(conn, ws.StateServerSide|ws.StateExtended, 0)
// Use the message state as wsutil.SendExtension.
wr.SetExtensions(&msg)
for {
h, err := rd.NextFrame()
if err != nil {
// handle error.
}
if h.OpCode.IsControl() {
// handle control frame.
}
if !msg.IsCompressed() {
// handle uncompressed frame (skipped for the sake of example
// simplicity).
}
// Reset the writer to echo same op code.
wr.Reset(h.OpCode)
// Reset both flate reader and writer to start the new round of i/o.
fr.Reset(rd)
fw.Reset(wr)
// Copy whole message from reader to writer decompressing it and
// compressing again.
if _, err := io.Copy(fw, fr); err != nil {
// handle error.
}
// Flush any remaining buffers from flate writer to WebSocket writer.
if err := fw.Close(); err != nil {
// handle error.
}
// Flush the whole WebSocket message to the wire.
if err := wr.Flush(); err != nil {
// handle error.
}
}
```
[rfc-url]: https://tools.ietf.org/html/rfc6455
[rfc-pmce]: https://tools.ietf.org/html/rfc7692#section-7
[godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg
[godoc-url]: https://godoc.org/github.com/gobwas/ws
[compress/flate]: https://golang.org/pkg/compress/flate/
[ci-badge]: https://github.com/gobwas/ws/workflows/CI/badge.svg
[ci-url]: https://github.com/gobwas/ws/actions?query=workflow%3ACI

145
vendor/github.com/gobwas/ws/check.go generated vendored Normal file
View file

@ -0,0 +1,145 @@
package ws
import "unicode/utf8"
// State represents the state of a websocket endpoint.
// It is used by some functions to be more strict when checking compatibility with RFC6455.
type State uint8
const (
// StateServerSide means that endpoint (caller) is a server.
StateServerSide State = 0x1 << iota
// StateClientSide means that endpoint (caller) is a client.
StateClientSide
// StateExtended means that extension was negotiated during handshake.
StateExtended
// StateFragmented means that endpoint (caller) has received fragmented
// frame and waits for continuation parts.
StateFragmented
)
// Is checks whether the s has v enabled.
func (s State) Is(v State) bool {
return uint8(s)&uint8(v) != 0
}
// Set enables v state on s.
func (s State) Set(v State) State {
return s | v
}
// Clear disables v state on s.
func (s State) Clear(v State) State {
return s & (^v)
}
// ServerSide reports whether state represents the server side.
func (s State) ServerSide() bool { return s.Is(StateServerSide) }
// ClientSide reports whether state represents client side.
func (s State) ClientSide() bool { return s.Is(StateClientSide) }
// Extended reports whether state is extended.
func (s State) Extended() bool { return s.Is(StateExtended) }
// Fragmented reports whether state is fragmented.
func (s State) Fragmented() bool { return s.Is(StateFragmented) }
// ProtocolError describes error during checking/parsing websocket frames or
// headers.
type ProtocolError string
// Error implements error interface.
func (p ProtocolError) Error() string { return string(p) }
// Errors used by the protocol checkers.
var (
ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code")
ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded")
ErrProtocolControlNotFinal = ProtocolError("control frame is not final")
ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated")
ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked")
ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked")
ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame")
ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame")
ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use")
ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet")
ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec")
ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason")
)
// CheckHeader checks h to contain valid header data for given state s.
//
// Note that zero state (0) means that the state is clean:
// neither server nor client side, nor fragmented, nor extended.
func CheckHeader(h Header, s State) error {
if h.OpCode.IsReserved() {
return ErrProtocolOpCodeReserved
}
if h.OpCode.IsControl() {
if h.Length > MaxControlFramePayloadSize {
return ErrProtocolControlPayloadOverflow
}
if !h.Fin {
return ErrProtocolControlNotFinal
}
}
switch {
// [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
// non-zero values. If a nonzero value is received and none of the
// negotiated extensions defines the meaning of such a nonzero value, the
// receiving endpoint MUST _Fail the WebSocket Connection_.
case h.Rsv != 0 && !s.Extended():
return ErrProtocolNonZeroRsv
// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
// status code 1002 (protocol error) as defined in Section 7.4.1.
case s.ServerSide() && !h.Masked:
return ErrProtocolMaskRequired
case s.ClientSide() && h.Masked:
return ErrProtocolMaskUnexpected
// [RFC6455]: See detailed explanation in 5.4 section.
case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
return ErrProtocolContinuationExpected
case !s.Fragmented() && h.OpCode == OpContinuation:
return ErrProtocolContinuationUnexpected
default:
return nil
}
}
// CheckCloseFrameData checks received close information
// to be valid RFC6455 compatible close info.
//
// Note that code.Empty() or code.IsAppLevel() will raise error.
//
// If endpoint sends close frame without status code (with frame.Length = 0),
// application should not check its payload.
func CheckCloseFrameData(code StatusCode, reason string) error {
switch {
case code.IsNotUsed():
return ErrProtocolStatusCodeNotInUse
case code.IsProtocolReserved():
return ErrProtocolStatusCodeApplicationLevel
case code == StatusNoMeaningYet:
return ErrProtocolStatusCodeNoMeaning
case code.IsProtocolSpec() && !code.IsProtocolDefined():
return ErrProtocolStatusCodeUnknown
case !utf8.ValidString(reason):
return ErrProtocolInvalidUTF8
default:
return nil
}
}
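
A small illustrative sketch of `CheckHeader` from above, using a hand-built `ws.Header`; the field values are made up for the example:

```go
package main

import (
	"fmt"

	"github.com/gobwas/ws"
)

func main() {
	h := ws.Header{
		Fin:    true,
		OpCode: ws.OpText,
		Masked: true, // client-to-server frames must be masked
		Length: 5,
	}

	// A server-side endpoint expects masked frames, so this passes.
	fmt.Println(ws.CheckHeader(h, ws.StateServerSide)) // <nil>

	// A client-side endpoint must not receive masked frames.
	fmt.Println(ws.CheckHeader(h, ws.StateClientSide)) // frames from server to client must be not masked
}
```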

61
vendor/github.com/gobwas/ws/cipher.go generated vendored Normal file
View file

@ -0,0 +1,61 @@
package ws
import (
"encoding/binary"
)
// Cipher applies XOR cipher to the payload using mask.
// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
//
// To convert masked data into unmasked data, or vice versa, the following
// algorithm is applied. The same algorithm applies regardless of the
// direction of the translation, e.g., the same steps are applied to
// mask the data as to unmask the data.
func Cipher(payload []byte, mask [4]byte, offset int) {
n := len(payload)
if n < 8 {
for i := 0; i < n; i++ {
payload[i] ^= mask[(offset+i)%4]
}
return
}
// Calculate the position in the mask based on the number of previously processed bytes.
mpos := offset % 4
// Count the number of bytes that will be processed one by one from the beginning of the payload.
ln := remain[mpos]
// Count the number of bytes that will be processed one by one from the end of
// the payload. This is done to process the payload 16 bytes at a time in each
// iteration of the main loop.
rn := (n - ln) % 16
for i := 0; i < ln; i++ {
payload[i] ^= mask[(mpos+i)%4]
}
for i := n - rn; i < n; i++ {
payload[i] ^= mask[(mpos+i)%4]
}
// NOTE: we use binary.LittleEndian here regardless of the machine's real
// endianness. To do so, we have to use binary.LittleEndian in
// the masking loop below as well.
var (
m = binary.LittleEndian.Uint32(mask[:])
m2 = uint64(m)<<32 | uint64(m)
)
// Skip already processed right part.
// Get number of uint64 parts remaining to process.
n = (n - ln - rn) >> 4
j := ln
for i := 0; i < n; i++ {
chunk := payload[j : j+16]
p := binary.LittleEndian.Uint64(chunk) ^ m2
p2 := binary.LittleEndian.Uint64(chunk[8:]) ^ m2
binary.LittleEndian.PutUint64(chunk, p)
binary.LittleEndian.PutUint64(chunk[8:], p2)
j += 16
}
}
// remain maps position in masking key [0,4) to number
// of bytes that need to be processed manually inside Cipher().
var remain = [4]int{0, 3, 2, 1}
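
A minimal sketch showing that `Cipher` is its own inverse, which is why masking and unmasking share the same code path; the payload and mask values are arbitrary:

```go
package main

import (
	"fmt"

	"github.com/gobwas/ws"
)

func main() {
	payload := []byte("hello, websocket masking")
	mask := [4]byte{0x1f, 0x2e, 0x3d, 0x4c}

	// Mask the payload in place.
	ws.Cipher(payload, mask, 0)
	fmt.Printf("masked:   %q\n", payload)

	// Applying the same XOR cipher again restores the original bytes.
	ws.Cipher(payload, mask, 0)
	fmt.Printf("unmasked: %q\n", payload)
}
```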

573
vendor/github.com/gobwas/ws/dialer.go generated vendored Normal file
View file

@ -0,0 +1,573 @@
package ws
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/gobwas/httphead"
"github.com/gobwas/pool/pbufio"
)
// Constants used by Dialer.
const (
DefaultClientReadBufferSize = 4096
DefaultClientWriteBufferSize = 4096
)
// Handshake represents handshake result.
type Handshake struct {
// Protocol is the subprotocol selected during handshake.
Protocol string
// Extensions is the list of negotiated extensions.
Extensions []httphead.Option
}
// Errors used by the websocket client.
var (
ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
)
// DefaultDialer is dialer that holds no options and is used by Dial function.
var DefaultDialer Dialer
// Dial is like Dialer{}.Dial().
func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
return DefaultDialer.Dial(ctx, urlstr)
}
// Dialer contains options for establishing websocket connection to an url.
type Dialer struct {
// ReadBufferSize and WriteBufferSize are I/O buffer sizes.
// They are used to read and write HTTP data while upgrading to WebSocket.
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
//
// If a size is zero then default value is used.
ReadBufferSize, WriteBufferSize int
// Timeout is the maximum amount of time a Dial() will wait for a connect
// and a handshake to complete.
//
// The default is no timeout.
Timeout time.Duration
// Protocols is the list of subprotocols that the client wants to speak,
// ordered by preference.
//
// See https://tools.ietf.org/html/rfc6455#section-4.1
Protocols []string
// Extensions is the list of extensions that client wants to speak.
//
// Note that if the server decides to use some of these extensions, Dial() will
// return a Handshake struct containing a slice of items which are shallow
// copies of the items from this list. That is, the internals of the
// Extensions items are shared during Dial().
//
// See https://tools.ietf.org/html/rfc6455#section-4.1
// See https://tools.ietf.org/html/rfc6455#section-9.1
Extensions []httphead.Option
// Header is an optional HandshakeHeader instance that could be used to
// write additional headers to the handshake request.
//
// It is used instead of any key-value mappings to avoid allocations in user
// land.
Header HandshakeHeader
// Host is an optional string that could be used to specify the host during
// HTTP upgrade request by setting 'Host' header.
//
// Default value is an empty string, which results in setting 'Host' header
// equal to the URL hostname given to Dialer.Dial().
Host string
// OnStatusError is the callback that will be called after receiving a non
// "101 Switching Protocols" HTTP response status. It receives an io.Reader
// object representing the server response bytes. That is, it gives the ability
// to parse the HTTP response somehow (probably with an http.ReadResponse call)
// and decide on further logic.
//
// The arguments are only valid until the callback returns.
OnStatusError func(status int, reason []byte, resp io.Reader)
// OnHeader is the callback that will be called after successful parsing of
// header, that is not used during WebSocket handshake procedure. That is,
// it will be called with non-websocket headers, which could be relevant
// for application-level logic.
//
// The arguments are only valid until the callback returns.
//
// Returned value could be used to prevent processing response.
OnHeader func(key, value []byte) (err error)
// NetDial is the function that is used to get plain tcp connection.
// If it is not nil, then it is used instead of net.Dialer.
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
// TLSClient is the callback that will be called after successful dial with
// received connection and its remote host name. If it is nil, then the
// default tls.Client() will be used.
// If it is not nil, then TLSConfig field is ignored.
TLSClient func(conn net.Conn, hostname string) net.Conn
// TLSConfig is passed to tls.Client() to start TLS over established
// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
// non-nil and its ServerName is empty, then for every Dial() it will be
// cloned and appropriate ServerName will be set.
TLSConfig *tls.Config
// WrapConn is the optional callback that will be called when connection is
// ready for an i/o. That is, it will be called after successful dial and
// TLS initialization (for "wss" schemes). It may be helpful for different
// user land purposes such as end to end encryption.
//
// Note that for debugging purposes of an http handshake (e.g. sent request
// and received response), there is a wsutil.DebugDialer struct.
WrapConn func(conn net.Conn) net.Conn
}
// Dial connects to the url host and upgrades connection to WebSocket.
//
// If the server has sent frames right after a successful handshake, then the
// returned buffer will be non-nil. In other cases the buffer is always nil. For
// better memory efficiency, a received non-nil bufio.Reader should be returned
// to the inner pool with the PutReader() function after use.
//
// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
// If you want to dial non-ascii host name, take care of its name serialization
// avoiding bad request issues. For more info see net/http Request.Write()
// implementation, especially cleanHost() function.
func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
u, err := url.ParseRequestURI(urlstr)
if err != nil {
return nil, nil, hs, err
}
// Prepare context to dial with. Initially it is the same as original, but
// if d.Timeout is non-zero and points to time that is before ctx.Deadline,
// we use a shorter context for the dial.
dialctx := ctx
var deadline time.Time
if t := d.Timeout; t != 0 {
deadline = time.Now().Add(t)
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
var cancel context.CancelFunc
dialctx, cancel = context.WithDeadline(ctx, deadline)
defer cancel()
}
}
if conn, err = d.dial(dialctx, u); err != nil {
return conn, nil, hs, err
}
defer func() {
if err != nil {
conn.Close()
}
}()
if ctx == context.Background() {
// No need to start I/O interrupter goroutine which is not zero-cost.
conn.SetDeadline(deadline)
defer conn.SetDeadline(noDeadline)
} else {
// Context could be canceled or its deadline could be exceeded.
// Start the interrupter goroutine to handle context cancelation.
done := setupContextDeadliner(ctx, conn)
defer func() {
// Map Upgrade() error to a possible context expiration error. That
// is, even if Upgrade() err is nil, context could be already
// expired and connection be "poisoned" by SetDeadline() call.
// In that case we must not return ctx.Err() error.
done(&err)
}()
}
br, hs, err = d.Upgrade(conn, u)
return conn, br, hs, err
}
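
For completeness, a hypothetical client sketch using the `Dial` wrapper defined above; the URL is a placeholder and error handling is abbreviated:

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/gobwas/ws"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, br, hs, err := ws.Dial(ctx, "ws://localhost:8080")
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	if br != nil {
		// The server sent frames right after the handshake; a real client
		// should read from br before returning it to the pool.
		ws.PutReader(br)
	}
	log.Printf("negotiated subprotocol: %q", hs.Protocol)
}
```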
var (
// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
// Dialer.NetDial is not provided.
netEmptyDialer net.Dialer
// tlsEmptyConfig is an empty tls.Config used as default one.
tlsEmptyConfig tls.Config
)
func tlsDefaultConfig() *tls.Config {
return &tlsEmptyConfig
}
func hostport(host, defaultPort string) (hostname, addr string) {
var (
colon = strings.LastIndexByte(host, ':')
bracket = strings.IndexByte(host, ']')
)
if colon > bracket {
return host[:colon], host
}
return host, host + defaultPort
}
func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
dial := d.NetDial
if dial == nil {
dial = netEmptyDialer.DialContext
}
switch u.Scheme {
case "ws":
_, addr := hostport(u.Host, ":80")
conn, err = dial(ctx, "tcp", addr)
case "wss":
hostname, addr := hostport(u.Host, ":443")
conn, err = dial(ctx, "tcp", addr)
if err != nil {
return nil, err
}
tlsClient := d.TLSClient
if tlsClient == nil {
tlsClient = d.tlsClient
}
conn = tlsClient(conn, hostname)
default:
return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
}
if wrap := d.WrapConn; wrap != nil {
conn = wrap(conn)
}
return conn, err
}
func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
config := d.TLSConfig
if config == nil {
config = tlsDefaultConfig()
}
if config.ServerName == "" {
config = tlsCloneConfig(config)
config.ServerName = hostname
}
// Do not make conn.Handshake() here because downstairs we will prepare
// i/o on this conn with proper context's timeout handling.
return tls.Client(conn, config)
}
var (
// These variables are set like in net/net.go.
// noDeadline is just zero value for readability.
noDeadline = time.Time{}
// aLongTimeAgo is a non-zero time, far in the past, used for immediate
// cancelation of dials.
aLongTimeAgo = time.Unix(42, 0)
)
// Upgrade writes an upgrade request to the given io.ReadWriter conn at given
// url u and reads a response from it.
//
// It is a caller responsibility to manage I/O deadlines on conn.
//
// It returns handshake info and some bytes which could be written by the peer
// right after response and be caught by us during buffered read.
func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
// headerSeen constants help to report whether or not some header was seen
// while reading response bytes.
const (
headerSeenUpgrade = 1 << iota
headerSeenConnection
headerSeenSecAccept
// headerSeenAll is the value that we expect to receive at the end of
// headers read/parse loop.
headerSeenAll = 0 |
headerSeenUpgrade |
headerSeenConnection |
headerSeenSecAccept
)
br = pbufio.GetReader(conn,
nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
)
bw := pbufio.GetWriter(conn,
nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
)
defer func() {
pbufio.PutWriter(bw)
if br.Buffered() == 0 || err != nil {
// The server did not write additional bytes to the connection, or an
// error occurred. That is, there is no reason to return the buffer.
pbufio.PutReader(br)
br = nil
}
}()
nonce := make([]byte, nonceSize)
initNonce(nonce)
httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header, d.Host)
if err := bw.Flush(); err != nil {
return br, hs, err
}
// Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
sl, err := readLine(br)
if err != nil {
return br, hs, err
}
// Begin validation of the response.
// See https://tools.ietf.org/html/rfc6455#section-4.2.2
// Parse the status line data like HTTP version, status code and reason.
resp, err := httpParseResponseLine(sl)
if err != nil {
return br, hs, err
}
// Even though the RFC says "1.1 or higher" without mentioning which part of
// the version, we apply it only to the minor part.
if resp.major != 1 || resp.minor < 1 {
err = ErrHandshakeBadProtocol
return br, hs, err
}
if resp.status != http.StatusSwitchingProtocols {
err = StatusError(resp.status)
if onStatusError := d.OnStatusError; onStatusError != nil {
// Invoke the callback with a multireader of the status-line bytes and br.
onStatusError(resp.status, resp.reason,
io.MultiReader(
bytes.NewReader(sl),
strings.NewReader(crlf),
br,
),
)
}
return br, hs, err
}
// If the response status is 101, then we expect all technical headers to be
// valid. If not, then we stop processing the response without giving the user
// the ability to read non-technical headers. That is, we do not distinguish
// between technical errors (such as parsing errors) and protocol errors.
var headerSeen byte
for {
line, e := readLine(br)
if e != nil {
err = e
return br, hs, err
}
if len(line) == 0 {
// Blank line, no more lines to read.
break
}
k, v, ok := httpParseHeaderLine(line)
if !ok {
err = ErrMalformedResponse
return br, hs, err
}
switch btsToString(k) {
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
return br, hs, err
}
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
// Note that as RFC6455 says:
// > A |Connection| header field with value "Upgrade".
// That is, on the server side, the "Connection" header could contain
// multiple tokens. But in a response it must contain exactly one.
if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
err = ErrHandshakeBadConnection
return br, hs, err
}
case headerSecAcceptCanonical:
headerSeen |= headerSeenSecAccept
if !checkAcceptFromNonce(v, nonce) {
err = ErrHandshakeBadSecAccept
return br, hs, err
}
case headerSecProtocolCanonical:
// RFC6455 1.3:
// "The server selects one or none of the acceptable protocols
// and echoes that value in its handshake to indicate that it has
// selected that protocol."
for _, want := range d.Protocols {
if string(v) == want {
hs.Protocol = want
break
}
}
if hs.Protocol == "" {
// The server echoed a subprotocol that is not present in the
// client's requested protocols.
err = ErrHandshakeBadSubProtocol
return br, hs, err
}
case headerSecExtensionsCanonical:
hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
if err != nil {
return br, hs, err
}
default:
if onHeader := d.OnHeader; onHeader != nil {
if e := onHeader(k, v); e != nil {
err = e
return br, hs, err
}
}
}
}
if err == nil && headerSeen != headerSeenAll {
switch {
case headerSeen&headerSeenUpgrade == 0:
err = ErrHandshakeBadUpgrade
case headerSeen&headerSeenConnection == 0:
err = ErrHandshakeBadConnection
case headerSeen&headerSeenSecAccept == 0:
err = ErrHandshakeBadSecAccept
default:
panic("unknown headers state")
}
}
return br, hs, err
}
// PutReader returns bufio.Reader instance to the inner reuse pool.
// It is useful in the rare case when Dialer.Dial() returns a non-nil buffer
// that contains unprocessed data sent by the server right after the
// handshake.
func PutReader(br *bufio.Reader) {
pbufio.PutReader(br)
}
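
A caller-side sketch of the case this helper exists for, assuming the package's DefaultDialer and Dialer.Dial (which live in this file but outside this hunk): when Dial returns a non-nil reader, consume its buffered bytes and then hand it back to the pool.

package main

import (
	"context"
	"log"

	"github.com/gobwas/ws"
)

func main() {
	conn, br, _, err := ws.DefaultDialer.Dial(context.Background(), "ws://127.0.0.1:8080")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	if br != nil {
		// The server sent bytes right after the handshake; they are already
		// buffered here and must be consumed before reading from conn.
		// ... read from br ...
		ws.PutReader(br) // then return the reader to the reuse pool
	}
}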
// StatusError contains an unexpected status-line code from the server.
type StatusError int
func (s StatusError) Error() string {
return "unexpected HTTP response status: " + strconv.Itoa(int(s))
}
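
A short sketch of detecting a non-101 handshake response through StatusError, again assuming DefaultDialer/Dial from this file (outside this hunk); errors.As covers both a directly returned and a wrapped error.

package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/gobwas/ws"
)

func main() {
	_, _, _, err := ws.DefaultDialer.Dial(context.Background(), "ws://127.0.0.1:8080/plain-http")
	var se ws.StatusError
	if errors.As(err, &se) {
		// The server answered with a regular HTTP status instead of 101.
		fmt.Println("handshake rejected with HTTP status", int(se))
		return
	}
	if err != nil {
		log.Fatal(err)
	}
}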
func isTimeoutError(err error) bool {
t, ok := err.(net.Error)
return ok && t.Timeout()
}
func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
if len(selected) == 0 {
return received, nil
}
var (
index int
option httphead.Option
err error
)
index = -1
match := func() (ok bool) {
for _, want := range wanted {
// A server accepts one or more extensions by including a
// |Sec-WebSocket-Extensions| header field containing one or more
// extensions that were requested by the client.
//
// The interpretation of any extension parameters, and what
// constitutes a valid response by a server to a requested set of
// parameters by a client, will be defined by each such extension.
if bytes.Equal(option.Name, want.Name) {
// Check that the parsed extension is present in the client's
// requested extensions. We reuse the matched extension from the
// client list to avoid allocating httphead.Option.Name;
// httphead.Option.Parameters still has to be copied from the header.
want.Parameters, _ = option.Parameters.Copy(make([]byte, option.Parameters.Size()))
received = append(received, want)
return true
}
}
return false
}
ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
if i != index {
// Met next option.
index = i
if i != 0 && !match() {
// Server returned non-requested extension.
err = ErrHandshakeBadExtensions
return httphead.ControlBreak
}
option = httphead.Option{Name: name}
}
if attr != nil {
option.Parameters.Set(attr, val)
}
return httphead.ControlContinue
})
if !ok {
err = ErrMalformedResponse
return received, err
}
if !match() {
return received, ErrHandshakeBadExtensions
}
return received, err
}
// setupContextDeadliner is a helper function that starts connection I/O
// interrupter goroutine.
//
// The started goroutine calls SetDeadline() with a long-ago value when the
// context expires, which makes any in-flight I/O operation fail. It returns a
// done function that stops the started goroutine and maps an error received
// from the conn's I/O methods to a possible context expiration error.
//
// Because SetDeadline() may also be called inside the interrupter goroutine,
// the caller passes a pointer to its I/O error (even if it is nil) to
// done(&err). That is, even if the I/O error is nil, the context could already
// be expired and the connection "poisoned" by the SetDeadline() call. In that
// case done(&err) stores the ctx.Err() result at *err. If err was not caused
// by a timeout, it is left untouched.
func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
var (
quit = make(chan struct{})
interrupt = make(chan error, 1)
)
go func() {
select {
case <-quit:
interrupt <- nil
case <-ctx.Done():
// Cancel i/o immediately.
conn.SetDeadline(aLongTimeAgo)
interrupt <- ctx.Err()
}
}()
return func(err *error) {
close(quit)
// If ctx.Err() is non-nil and the original err is a net.Error with
// Timeout() == true, then the I/O was canceled either by our
// SetDeadline(aLongTimeAgo) call or previously by somebody else's
// conn.SetDeadline(x) call.
//
// Even in the race where both deadlines have expired (one set by us via the
// context and one set elsewhere), we prefer to return ctx.Err().
if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
*err = ctxErr
}
}
}
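
A sketch of the call pattern the comment above describes, written as it would appear inside package ws; the real caller is Dialer.Dial (outside this hunk), and dialSketch is a made-up name.

package ws

import (
	"context"
	"net"
)

// dialSketch (illustrative only): arm the interrupter, run blocking I/O on
// conn, and let done() replace a deadline-induced error with ctx.Err() when
// the context has expired.
func dialSketch(ctx context.Context, conn net.Conn) (err error) {
	done := setupContextDeadliner(ctx, conn)
	defer done(&err)
	_, err = conn.Write([]byte("...handshake bytes...")) // any blocking I/O
	return err
}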

35
vendor/github.com/gobwas/ws/dialer_tls_go17.go generated vendored Normal file
View file

@ -0,0 +1,35 @@
// +build !go1.8
package ws
import "crypto/tls"
func tlsCloneConfig(c *tls.Config) *tls.Config {
// NOTE: we copy SessionTicketsDisabled and SessionTicketKey here without
// calling the inner c.initOnceServer because we can only get here from the
// ws.Dialer code, which is obviously a client and calls tls.Client() when it
// gets a new net.Conn.
return &tls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
}
}

10
vendor/github.com/gobwas/ws/dialer_tls_go18.go generated vendored Normal file
View file

@ -0,0 +1,10 @@
//go:build go1.8
// +build go1.8
package ws
import "crypto/tls"
func tlsCloneConfig(c *tls.Config) *tls.Config {
return c.Clone()
}

81
vendor/github.com/gobwas/ws/doc.go generated vendored Normal file
View file

@ -0,0 +1,81 @@
/*
Package ws implements a client and server for the WebSocket protocol as
specified in RFC 6455.
The main purpose of this package is to provide a simple, low-level API for
efficient work with the protocol.
Overview.
Upgrade to WebSocket (or WebSocket handshake) can be done in two ways.
The first way is to use `net/http` server:
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
})
The second and much more efficient way is the so-called "zero-copy upgrade".
It avoids redundant allocations and copying of unused headers or other
request data. The user decides which data should be copied.
ln, err := net.Listen("tcp", ":8080")
if err != nil {
// handle error
}
conn, err := ln.Accept()
if err != nil {
// handle error
}
handshake, err := ws.Upgrade(conn)
if err != nil {
// handle error
}
For customization details see `ws.Upgrader` documentation.
After the WebSocket handshake you can work with the connection in multiple
ways. That is, `ws` does not force a single way of working with WebSocket:
header, err := ws.ReadHeader(conn)
if err != nil {
// handle err
}
buf := make([]byte, header.Length)
_, err = io.ReadFull(conn, buf)
if err != nil {
// handle err
}
frame := ws.NewBinaryFrame([]byte("hello, world!"))
if err := ws.WriteFrame(conn, frame); err != nil {
// handle err
}
As you can see, it is stream friendly:
const N = 42
ws.WriteHeader(ws.Header{
Fin: true,
Length: N,
OpCode: ws.OpBinary,
})
io.CopyN(conn, rand.Reader, N)
Or:
header, err := ws.ReadHeader(conn)
if err != nil {
// handle err
}
io.CopyN(ioutil.Discard, conn, header.Length)
For more info see the documentation.
*/
package ws
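
Putting the fragments above together, a minimal echo-server sketch using only identifiers that appear in this vendored copy (error handling kept deliberately terse):

package main

import (
	"log"
	"net"

	"github.com/gobwas/ws"
)

func main() {
	ln, err := net.Listen("tcp", ":8080")
	if err != nil {
		log.Fatal(err)
	}
	for {
		conn, err := ln.Accept()
		if err != nil {
			log.Fatal(err)
		}
		go func(conn net.Conn) {
			defer conn.Close()
			if _, err := ws.Upgrade(conn); err != nil {
				return
			}
			for {
				frame, err := ws.ReadFrame(conn)
				if err != nil {
					return
				}
				frame = ws.UnmaskFrameInPlace(frame) // client frames arrive masked
				if frame.Header.OpCode == ws.OpClose {
					return
				}
				// Server-to-client frames are not masked, so echo as a plain text frame.
				if err := ws.WriteFrame(conn, ws.NewTextFrame(frame.Payload)); err != nil {
					return
				}
			}
		}(conn)
	}
}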

59
vendor/github.com/gobwas/ws/errors.go generated vendored Normal file
View file

@ -0,0 +1,59 @@
package ws
// RejectOption represents an option used to control the way connection is
// rejected.
type RejectOption func(*ConnectionRejectedError)
// RejectionReason returns an option that makes the connection be rejected
// with the given reason.
func RejectionReason(reason string) RejectOption {
return func(err *ConnectionRejectedError) {
err.reason = reason
}
}
// RejectionStatus returns an option that makes the connection be rejected
// with the given HTTP status code.
func RejectionStatus(code int) RejectOption {
return func(err *ConnectionRejectedError) {
err.code = code
}
}
// RejectionHeader returns an option that makes the connection be rejected
// with the given HTTP headers.
func RejectionHeader(h HandshakeHeader) RejectOption {
return func(err *ConnectionRejectedError) {
err.header = h
}
}
// RejectConnectionError constructs an error that could be used to control the
// way handshake is rejected by Upgrader.
func RejectConnectionError(options ...RejectOption) error {
err := new(ConnectionRejectedError)
for _, opt := range options {
opt(err)
}
return err
}
// ConnectionRejectedError represents a rejection of the connection during
// the WebSocket handshake.
//
// It can be returned by Upgrader's On* hooks to indicate that WebSocket
// handshake should be rejected.
type ConnectionRejectedError struct {
reason string
code int
header HandshakeHeader
}
// Error implements error interface.
func (r *ConnectionRejectedError) Error() string {
return r.reason
}
func (r *ConnectionRejectedError) StatusCode() int {
return r.code
}
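
A sketch of the hook-based rejection flow these options exist for, using the Upgrader.OnRequest hook defined later in this diff; the endpoint path and status code are arbitrary choices.

package main

import (
	"bytes"
	"log"
	"net"

	"github.com/gobwas/ws"
)

func main() {
	u := ws.Upgrader{
		// Reject any request whose URI is not /ws; the status and reason end
		// up in the HTTP error response written by Upgrade itself.
		OnRequest: func(uri []byte) error {
			if !bytes.Equal(uri, []byte("/ws")) {
				return ws.RejectConnectionError(
					ws.RejectionStatus(403),
					ws.RejectionReason("unknown endpoint"),
				)
			}
			return nil
		},
	}
	ln, err := net.Listen("tcp", ":8080")
	if err != nil {
		log.Fatal(err)
	}
	for {
		conn, err := ln.Accept()
		if err != nil {
			log.Fatal(err)
		}
		go func(conn net.Conn) {
			defer conn.Close()
			if _, err := u.Upgrade(conn); err != nil {
				return // the rejection response has already been written
			}
			// ... serve frames ...
		}(conn)
	}
}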

420
vendor/github.com/gobwas/ws/frame.go generated vendored Normal file
View file

@ -0,0 +1,420 @@
package ws
import (
"bytes"
"encoding/binary"
"math/rand"
)
// Constants defined by specification.
const (
// All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented.
MaxControlFramePayloadSize = 125
)
// OpCode represents operation code.
type OpCode byte
// Operation codes defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-5.2
const (
OpContinuation OpCode = 0x0
OpText OpCode = 0x1
OpBinary OpCode = 0x2
OpClose OpCode = 0x8
OpPing OpCode = 0x9
OpPong OpCode = 0xa
)
// IsControl checks whether the c is control operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.5
func (c OpCode) IsControl() bool {
// RFC6455: Control frames are identified by opcodes where
// the most significant bit of the opcode is 1.
//
// Note that OpCode is only 4 bit length.
return c&0x8 != 0
}
// IsData checks whether the c is data operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.6
func (c OpCode) IsData() bool {
// RFC6455: Data frames (e.g., non-control frames) are identified by opcodes
// where the most significant bit of the opcode is 0.
//
// Note that OpCode is only 4 bit length.
return c&0x8 == 0
}
// IsReserved checks whether the c is reserved operation code.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func (c OpCode) IsReserved() bool {
// RFC6455:
// %x3-7 are reserved for further non-control frames
// %xB-F are reserved for further control frames
return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf)
}
// StatusCode represents the encoded reason for closure of websocket connection.
//
// There are a few helper methods on StatusCode that help to determine which
// range a given code lies in, according to the ranges defined in the
// specification.
//
// See https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode uint16
// StatusCodeRange describes range of StatusCode values.
type StatusCodeRange struct {
Min, Max StatusCode
}
// Status code ranges defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-7.4.2
var (
StatusRangeNotInUse = StatusCodeRange{0, 999}
StatusRangeProtocol = StatusCodeRange{1000, 2999}
StatusRangeApplication = StatusCodeRange{3000, 3999}
StatusRangePrivate = StatusCodeRange{4000, 4999}
)
// Status codes defined by specification.
// See https://tools.ietf.org/html/rfc6455#section-7.4.1
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
StatusNoMeaningYet StatusCode = 1004
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExt StatusCode = 1010
StatusInternalServerError StatusCode = 1011
StatusTLSHandshake StatusCode = 1015
// StatusAbnormalClosure is a special code designated for use in
// applications.
StatusAbnormalClosure StatusCode = 1006
// StatusNoStatusRcvd is a special code designated for use in applications.
StatusNoStatusRcvd StatusCode = 1005
)
// In reports whether the code is defined in given range.
func (s StatusCode) In(r StatusCodeRange) bool {
return r.Min <= s && s <= r.Max
}
// Empty reports whether the code is empty.
// An empty code has no meaning, neither as an application-level code nor
// otherwise. This method is useful just to check that the code is the Go
// zero value 0.
func (s StatusCode) Empty() bool {
return s == 0
}
// IsNotUsed reports whether the code is predefined in not used range.
func (s StatusCode) IsNotUsed() bool {
return s.In(StatusRangeNotInUse)
}
// IsApplicationSpec reports whether the code should be defined by
// application, framework or libraries specification.
func (s StatusCode) IsApplicationSpec() bool {
return s.In(StatusRangeApplication)
}
// IsPrivateSpec reports whether the code should be defined privately.
func (s StatusCode) IsPrivateSpec() bool {
return s.In(StatusRangePrivate)
}
// IsProtocolSpec reports whether the code should be defined by protocol specification.
func (s StatusCode) IsProtocolSpec() bool {
return s.In(StatusRangeProtocol)
}
// IsProtocolDefined reports whether the code is already defined by protocol specification.
func (s StatusCode) IsProtocolDefined() bool {
switch s {
case StatusNormalClosure,
StatusGoingAway,
StatusProtocolError,
StatusUnsupportedData,
StatusInvalidFramePayloadData,
StatusPolicyViolation,
StatusMessageTooBig,
StatusMandatoryExt,
StatusInternalServerError,
StatusNoStatusRcvd,
StatusAbnormalClosure,
StatusTLSHandshake:
return true
}
return false
}
// IsProtocolReserved reports whether the code is defined by protocol specification
// to be reserved only for application usage purpose.
func (s StatusCode) IsProtocolReserved() bool {
switch s {
// [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a
// Close control frame by an endpoint.
case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return true
default:
return false
}
}
// Compiled control frames for common use cases.
// For construct-serialize optimizations.
var (
CompiledPing = MustCompileFrame(NewPingFrame(nil))
CompiledPong = MustCompileFrame(NewPongFrame(nil))
CompiledClose = MustCompileFrame(NewCloseFrame(nil))
CompiledCloseNormalClosure = MustCompileFrame(closeFrameNormalClosure)
CompiledCloseGoingAway = MustCompileFrame(closeFrameGoingAway)
CompiledCloseProtocolError = MustCompileFrame(closeFrameProtocolError)
CompiledCloseUnsupportedData = MustCompileFrame(closeFrameUnsupportedData)
CompiledCloseNoMeaningYet = MustCompileFrame(closeFrameNoMeaningYet)
CompiledCloseInvalidFramePayloadData = MustCompileFrame(closeFrameInvalidFramePayloadData)
CompiledClosePolicyViolation = MustCompileFrame(closeFramePolicyViolation)
CompiledCloseMessageTooBig = MustCompileFrame(closeFrameMessageTooBig)
CompiledCloseMandatoryExt = MustCompileFrame(closeFrameMandatoryExt)
CompiledCloseInternalServerError = MustCompileFrame(closeFrameInternalServerError)
CompiledCloseTLSHandshake = MustCompileFrame(closeFrameTLSHandshake)
)
// Header represents websocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type Header struct {
Fin bool
Rsv byte
OpCode OpCode
Masked bool
Mask [4]byte
Length int64
}
// Rsv1 reports whether the header has first rsv bit set.
func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 }
// Rsv2 reports whether the header has second rsv bit set.
func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 }
// Rsv3 reports whether the header has third rsv bit set.
func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 }
// Rsv creates rsv byte representation from bits.
func Rsv(r1, r2, r3 bool) (rsv byte) {
if r1 {
rsv |= bit5
}
if r2 {
rsv |= bit6
}
if r3 {
rsv |= bit7
}
return rsv
}
// RsvBits returns rsv bits from bytes representation.
func RsvBits(rsv byte) (r1, r2, r3 bool) {
r1 = rsv&bit5 != 0
r2 = rsv&bit6 != 0
r3 = rsv&bit7 != 0
return r1, r2, r3
}
// Frame represents websocket frame.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type Frame struct {
Header Header
Payload []byte
}
// NewFrame creates frame with given operation code,
// flag of completeness and payload bytes.
func NewFrame(op OpCode, fin bool, p []byte) Frame {
return Frame{
Header: Header{
Fin: fin,
OpCode: op,
Length: int64(len(p)),
},
Payload: p,
}
}
// NewTextFrame creates text frame with p as payload.
// Note that p is not copied.
func NewTextFrame(p []byte) Frame {
return NewFrame(OpText, true, p)
}
// NewBinaryFrame creates binary frame with p as payload.
// Note that p is not copied.
func NewBinaryFrame(p []byte) Frame {
return NewFrame(OpBinary, true, p)
}
// NewPingFrame creates ping frame with p as payload.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewPingFrame(p []byte) Frame {
return NewFrame(OpPing, true, p)
}
// NewPongFrame creates pong frame with p as payload.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewPongFrame(p []byte) Frame {
return NewFrame(OpPong, true, p)
}
// NewCloseFrame creates close frame with given close body.
// Note that p is not copied.
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
// to RFC.
func NewCloseFrame(p []byte) Frame {
return NewFrame(OpClose, true, p)
}
// NewCloseFrameBody encodes a closure code and a reason into a binary
// representation.
//
// It returns slice which is at most MaxControlFramePayloadSize bytes length.
// If the reason is too big it will be cropped to fit the limit defined by the
// spec.
//
// See https://tools.ietf.org/html/rfc6455#section-5.5
func NewCloseFrameBody(code StatusCode, reason string) []byte {
n := min(2+len(reason), MaxControlFramePayloadSize)
p := make([]byte, n)
crop := min(MaxControlFramePayloadSize-2, len(reason))
PutCloseFrameBody(p, code, reason[:crop])
return p
}
// PutCloseFrameBody encodes code and reason into buf.
//
// It will panic if the buffer is too small to accommodate a code or a reason.
//
// PutCloseFrameBody does not check buffer to be RFC compliant, but note that
// by RFC it must be at most MaxControlFramePayloadSize.
func PutCloseFrameBody(p []byte, code StatusCode, reason string) {
_ = p[1+len(reason)]
binary.BigEndian.PutUint16(p, uint16(code))
copy(p[2:], reason)
}
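
A tiny worked example of the body layout these two functions produce and consume: a 2-byte big-endian status code followed by the reason.

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/gobwas/ws"
)

func main() {
	body := ws.NewCloseFrameBody(ws.StatusNormalClosure, "bye")
	// The first two bytes carry the status code big-endian; the rest is the reason.
	fmt.Println(binary.BigEndian.Uint16(body[:2])) // 1000
	fmt.Println(string(body[2:]))                  // bye
	frame := ws.NewCloseFrame(body)
	_ = frame // on a live connection: ws.WriteFrame(conn, frame)
}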
// MaskFrame masks frame and returns frame with masked payload and Mask header's field set.
// Note that it copies f payload to prevent collisions.
// For fewer allocations you could use MaskFrameInPlace or construct the frame manually.
func MaskFrame(f Frame) Frame {
return MaskFrameWith(f, NewMask())
}
// MaskFrameWith masks frame with given mask and returns frame
// with masked payload and Mask header's field set.
// Note that it copies f payload to prevent collisions.
// For fewer allocations you could use MaskFrameInPlaceWith or construct the frame manually.
func MaskFrameWith(f Frame, mask [4]byte) Frame {
// TODO(gobwas): check CopyCipher ws copy() Cipher().
p := make([]byte, len(f.Payload))
copy(p, f.Payload)
f.Payload = p
return MaskFrameInPlaceWith(f, mask)
}
// MaskFrameInPlace masks frame and returns frame with masked payload and Mask
// header's field set.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func MaskFrameInPlace(f Frame) Frame {
return MaskFrameInPlaceWith(f, NewMask())
}
var zeroMask [4]byte
// UnmaskFrame unmasks frame and returns frame with unmasked payload and Mask
// header's field cleared.
// Note that it copies f payload.
func UnmaskFrame(f Frame) Frame {
p := make([]byte, len(f.Payload))
copy(p, f.Payload)
f.Payload = p
return UnmaskFrameInPlace(f)
}
// UnmaskFrameInPlace unmasks frame and returns frame with unmasked payload and
// Mask header's field cleared.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func UnmaskFrameInPlace(f Frame) Frame {
Cipher(f.Payload, f.Header.Mask, 0)
f.Header.Masked = false
f.Header.Mask = zeroMask
return f
}
// MaskFrameInPlaceWith masks frame with given mask and returns frame
// with masked payload and Mask header's field set.
// Note that it applies xor cipher to f.Payload without copying, that is, it
// modifies f.Payload inplace.
func MaskFrameInPlaceWith(f Frame, m [4]byte) Frame {
f.Header.Masked = true
f.Header.Mask = m
Cipher(f.Payload, m, 0)
return f
}
// NewMask creates new random mask.
func NewMask() (ret [4]byte) {
binary.BigEndian.PutUint32(ret[:], rand.Uint32())
return ret
}
// CompileFrame returns byte representation of given frame.
// In terms of memory consumption it is useful to precompile static frames
// which are often used.
func CompileFrame(f Frame) (bts []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, 16))
err = WriteFrame(buf, f)
bts = buf.Bytes()
return bts, err
}
// MustCompileFrame is like CompileFrame but panics if frame can not be
// encoded.
func MustCompileFrame(f Frame) []byte {
bts, err := CompileFrame(f)
if err != nil {
panic(err)
}
return bts
}
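
A sketch of the precompilation pattern the CompileFrame comment hints at; the package and identifier names here are illustrative.

package wsdemo

import (
	"net"

	"github.com/gobwas/ws"
)

// pongOK is serialized once at package init; sending it is then a single
// conn.Write of the precompiled bytes.
var pongOK = ws.MustCompileFrame(ws.NewPongFrame([]byte("ok")))

func sendPong(conn net.Conn) error {
	_, err := conn.Write(pongOK)
	return err
}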
func makeCloseFrame(code StatusCode) Frame {
return NewCloseFrame(NewCloseFrameBody(code, ""))
}
var (
closeFrameNormalClosure = makeCloseFrame(StatusNormalClosure)
closeFrameGoingAway = makeCloseFrame(StatusGoingAway)
closeFrameProtocolError = makeCloseFrame(StatusProtocolError)
closeFrameUnsupportedData = makeCloseFrame(StatusUnsupportedData)
closeFrameNoMeaningYet = makeCloseFrame(StatusNoMeaningYet)
closeFrameInvalidFramePayloadData = makeCloseFrame(StatusInvalidFramePayloadData)
closeFramePolicyViolation = makeCloseFrame(StatusPolicyViolation)
closeFrameMessageTooBig = makeCloseFrame(StatusMessageTooBig)
closeFrameMandatoryExt = makeCloseFrame(StatusMandatoryExt)
closeFrameInternalServerError = makeCloseFrame(StatusInternalServerError)
closeFrameTLSHandshake = makeCloseFrame(StatusTLSHandshake)
)

18
vendor/github.com/gobwas/ws/hijack_go119.go generated vendored Normal file
View file

@ -0,0 +1,18 @@
//go:build !go1.20
// +build !go1.20
package ws
import (
"bufio"
"net"
"net/http"
)
func hijack(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.(http.Hijacker)
if ok {
return hj.Hijack()
}
return nil, nil, ErrNotHijacker
}

19
vendor/github.com/gobwas/ws/hijack_go120.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
//go:build go1.20
// +build go1.20
package ws
import (
"bufio"
"errors"
"net"
"net/http"
)
func hijack(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
conn, rw, err := http.NewResponseController(w).Hijack()
if errors.Is(err, http.ErrNotSupported) {
return nil, nil, ErrNotHijacker
}
return conn, rw, err
}

507
vendor/github.com/gobwas/ws/http.go generated vendored Normal file
View file

@ -0,0 +1,507 @@
package ws
import (
"bufio"
"bytes"
"io"
"net/http"
"net/url"
"strconv"
"github.com/gobwas/httphead"
)
const (
crlf = "\r\n"
colonAndSpace = ": "
commaAndSpace = ", "
)
const (
textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
)
var (
textHeadBadRequest = statusText(http.StatusBadRequest)
textHeadInternalServerError = statusText(http.StatusInternalServerError)
textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
)
const (
// Every new header must be added to TestHeaderNames test.
headerHost = "Host"
headerUpgrade = "Upgrade"
headerConnection = "Connection"
headerSecVersion = "Sec-WebSocket-Version"
headerSecProtocol = "Sec-WebSocket-Protocol"
headerSecExtensions = "Sec-WebSocket-Extensions"
headerSecKey = "Sec-WebSocket-Key"
headerSecAccept = "Sec-WebSocket-Accept"
headerHostCanonical = headerHost
headerUpgradeCanonical = headerUpgrade
headerConnectionCanonical = headerConnection
headerSecVersionCanonical = "Sec-Websocket-Version"
headerSecProtocolCanonical = "Sec-Websocket-Protocol"
headerSecExtensionsCanonical = "Sec-Websocket-Extensions"
headerSecKeyCanonical = "Sec-Websocket-Key"
headerSecAcceptCanonical = "Sec-Websocket-Accept"
)
var (
specHeaderValueUpgrade = []byte("websocket")
specHeaderValueConnection = []byte("Upgrade")
specHeaderValueConnectionLower = []byte("upgrade")
specHeaderValueSecVersion = []byte("13")
)
var (
httpVersion1_0 = []byte("HTTP/1.0")
httpVersion1_1 = []byte("HTTP/1.1")
httpVersionPrefix = []byte("HTTP/")
)
type httpRequestLine struct {
method, uri []byte
major, minor int
}
type httpResponseLine struct {
major, minor int
status int
reason []byte
}
// httpParseRequestLine parses http request line like "GET / HTTP/1.0".
func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
var proto []byte
req.method, req.uri, proto = bsplit3(line, ' ')
var ok bool
req.major, req.minor, ok = httpParseVersion(proto)
if !ok {
err = ErrMalformedRequest
}
return req, err
}
func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
var (
proto []byte
status []byte
)
proto, status, resp.reason = bsplit3(line, ' ')
var ok bool
resp.major, resp.minor, ok = httpParseVersion(proto)
if !ok {
return resp, ErrMalformedResponse
}
var convErr error
resp.status, convErr = asciiToInt(status)
if convErr != nil {
return resp, ErrMalformedResponse
}
return resp, nil
}
// httpParseVersion parses major and minor version of HTTP protocol. It returns
// parsed values and true if parse is ok.
func httpParseVersion(bts []byte) (major, minor int, ok bool) {
switch {
case bytes.Equal(bts, httpVersion1_0):
return 1, 0, true
case bytes.Equal(bts, httpVersion1_1):
return 1, 1, true
case len(bts) < 8:
return 0, 0, false
case !bytes.Equal(bts[:5], httpVersionPrefix):
return 0, 0, false
}
bts = bts[5:]
dot := bytes.IndexByte(bts, '.')
if dot == -1 {
return 0, 0, false
}
var err error
major, err = asciiToInt(bts[:dot])
if err != nil {
return major, 0, false
}
minor, err = asciiToInt(bts[dot+1:])
if err != nil {
return major, minor, false
}
return major, minor, true
}
// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
// values and true if parse is ok.
func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
colon := bytes.IndexByte(line, ':')
if colon == -1 {
return nil, nil, false
}
k = btrim(line[:colon])
// TODO(gobwas): maybe use just lower here?
canonicalizeHeaderKey(k)
v = btrim(line[colon+1:])
return k, v, true
}
// httpGetHeader is the same as textproto.MIMEHeader.Get, except that the key
// is assumed to be already canonical. This helps to increase performance.
func httpGetHeader(h http.Header, key string) string {
if h == nil {
return ""
}
v := h[key]
if len(v) == 0 {
return ""
}
return v[0]
}
// The request MAY include a header field with the name
// |Sec-WebSocket-Protocol|. If present, this value indicates one or more
// comma-separated subprotocol the client wishes to speak, ordered by
// preference. The elements that comprise this value MUST be non-empty strings
// with characters in the range U+0021 to U+007E not including separator
// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
// for the value of this header field is 1#token, where the definitions of
// constructs and rules are as given in [RFC2616].
func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
if check(btsToString(v)) {
ret = string(v)
return false
}
return true
})
return ret, ok
}
func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
var selected []byte
ok = httphead.ScanTokens(h, func(v []byte) bool {
if check(v) {
selected = v
return false
}
return true
})
if ok && selected != nil {
return string(selected), true
}
return ret, ok
}
func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
s := httphead.OptionSelector{
Flags: httphead.SelectCopy,
Check: check,
}
return s.Select(h, selected)
}
func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
if in.Size() == 0 {
return dest, nil
}
opt, err := f(in)
if err != nil {
return nil, err
}
if opt.Size() > 0 {
dest = append(dest, opt)
}
return dest, nil
}
func negotiateExtensions(
h []byte, dest []httphead.Option,
f func(httphead.Option) (httphead.Option, error),
) (_ []httphead.Option, err error) {
index := -1
var current httphead.Option
ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
if i != index {
dest, err = negotiateMaybe(current, dest, f)
if err != nil {
return httphead.ControlBreak
}
index = i
current = httphead.Option{Name: name}
}
if attr != nil {
current.Parameters.Set(attr, val)
}
return httphead.ControlContinue
})
if !ok {
return nil, ErrMalformedRequest
}
return negotiateMaybe(current, dest, f)
}
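
negotiateExtensions drives the user-supplied Negotiate callback once per offered extension. A sketch of such a callback follows; the extension name "permessage-foo" is made up for illustration.

package main

import (
	"bytes"

	"github.com/gobwas/httphead"
	"github.com/gobwas/ws"
)

func main() {
	u := ws.Upgrader{
		// Accept a single, made-up extension by name and echo it back without
		// parameters; a zero Option means "skip this offer".
		Negotiate: func(opt httphead.Option) (httphead.Option, error) {
			if bytes.Equal(opt.Name, []byte("permessage-foo")) {
				return httphead.Option{Name: []byte("permessage-foo")}, nil
			}
			return httphead.Option{}, nil
		},
	}
	_ = u // hand connections to u.Upgrade(conn) as usual
}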
func httpWriteHeader(bw *bufio.Writer, key, value string) {
httpWriteHeaderKey(bw, key)
bw.WriteString(value)
bw.WriteString(crlf)
}
func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
httpWriteHeaderKey(bw, key)
bw.Write(value)
bw.WriteString(crlf)
}
func httpWriteHeaderKey(bw *bufio.Writer, key string) {
bw.WriteString(key)
bw.WriteString(colonAndSpace)
}
func httpWriteUpgradeRequest(
bw *bufio.Writer,
u *url.URL,
nonce []byte,
protocols []string,
extensions []httphead.Option,
header HandshakeHeader,
host string,
) {
bw.WriteString("GET ")
bw.WriteString(u.RequestURI())
bw.WriteString(" HTTP/1.1\r\n")
if host == "" {
host = u.Host
}
httpWriteHeader(bw, headerHost, host)
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
// NOTE: write nonce bytes as a string to prevent a heap allocation:
// WriteString() copies the given string into its inner buffer, unlike
// Write(), which may pass p directly to the underlying io.Writer and in
// turn cause p to escape to the heap.
httpWriteHeader(bw, headerSecKey, btsToString(nonce))
if len(protocols) > 0 {
httpWriteHeaderKey(bw, headerSecProtocol)
for i, p := range protocols {
if i > 0 {
bw.WriteString(commaAndSpace)
}
bw.WriteString(p)
}
bw.WriteString(crlf)
}
if len(extensions) > 0 {
httpWriteHeaderKey(bw, headerSecExtensions)
httphead.WriteOptions(bw, extensions)
bw.WriteString(crlf)
}
if header != nil {
header.WriteTo(bw)
}
bw.WriteString(crlf)
}
func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
bw.WriteString(textHeadUpgrade)
httpWriteHeaderKey(bw, headerSecAccept)
writeAccept(bw, nonce)
bw.WriteString(crlf)
if hs.Protocol != "" {
httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
}
if len(hs.Extensions) > 0 {
httpWriteHeaderKey(bw, headerSecExtensions)
httphead.WriteOptions(bw, hs.Extensions)
bw.WriteString(crlf)
}
if header != nil {
header(bw)
}
bw.WriteString(crlf)
}
func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
switch code {
case http.StatusBadRequest:
bw.WriteString(textHeadBadRequest)
case http.StatusInternalServerError:
bw.WriteString(textHeadInternalServerError)
case http.StatusUpgradeRequired:
bw.WriteString(textHeadUpgradeRequired)
default:
writeStatusText(bw, code)
}
// Write custom headers.
if header != nil {
header(bw)
}
switch err {
case ErrHandshakeBadProtocol:
bw.WriteString(textTailErrHandshakeBadProtocol)
case ErrHandshakeBadMethod:
bw.WriteString(textTailErrHandshakeBadMethod)
case ErrHandshakeBadHost:
bw.WriteString(textTailErrHandshakeBadHost)
case ErrHandshakeBadUpgrade:
bw.WriteString(textTailErrHandshakeBadUpgrade)
case ErrHandshakeBadConnection:
bw.WriteString(textTailErrHandshakeBadConnection)
case ErrHandshakeBadSecAccept:
bw.WriteString(textTailErrHandshakeBadSecAccept)
case ErrHandshakeBadSecKey:
bw.WriteString(textTailErrHandshakeBadSecKey)
case ErrHandshakeBadSecVersion:
bw.WriteString(textTailErrHandshakeBadSecVersion)
case ErrHandshakeUpgradeRequired:
bw.WriteString(textTailErrUpgradeRequired)
case nil:
bw.WriteString(crlf)
default:
writeErrorText(bw, err)
}
}
func writeStatusText(bw *bufio.Writer, code int) {
bw.WriteString("HTTP/1.1 ")
bw.WriteString(strconv.Itoa(code))
bw.WriteByte(' ')
bw.WriteString(http.StatusText(code))
bw.WriteString(crlf)
bw.WriteString("Content-Type: text/plain; charset=utf-8")
bw.WriteString(crlf)
}
func writeErrorText(bw *bufio.Writer, err error) {
body := err.Error()
bw.WriteString("Content-Length: ")
bw.WriteString(strconv.Itoa(len(body)))
bw.WriteString(crlf)
bw.WriteString(crlf)
bw.WriteString(body)
}
// httpError is like the http.Error with WebSocket context exception.
func httpError(w http.ResponseWriter, body string, code int) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.WriteHeader(code)
w.Write([]byte(body))
}
// statusText is a non-performant status text generator.
// NOTE: Used only to generate constants.
func statusText(code int) string {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
writeStatusText(bw, code)
bw.Flush()
return buf.String()
}
// errorText is a non-performant error text generator.
// NOTE: Used only to generate constants.
func errorText(err error) string {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
writeErrorText(bw, err)
bw.Flush()
return buf.String()
}
// HandshakeHeader is the interface that writes both upgrade request or
// response headers into a given io.Writer.
type HandshakeHeader interface {
io.WriterTo
}
// HandshakeHeaderString is an adapter to allow the use of headers represented
// by ordinary string as HandshakeHeader.
type HandshakeHeaderString string
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
n, err := io.WriteString(w, string(s))
return int64(n), err
}
// HandshakeHeaderBytes is an adapter to allow the use of headers represented
// by ordinary slice of bytes as HandshakeHeader.
type HandshakeHeaderBytes []byte
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(b)
return int64(n), err
}
// HandshakeHeaderFunc is an adapter to allow the use of headers represented by
// ordinary function as HandshakeHeader.
type HandshakeHeaderFunc func(io.Writer) (int64, error)
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
return f(w)
}
// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
// HandshakeHeader.
type HandshakeHeaderHTTP http.Header
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
wr := writer{w: w}
err := http.Header(h).Write(&wr)
return wr.n, err
}
type writer struct {
n int64
w io.Writer
}
func (w *writer) WriteString(s string) (int, error) {
n, err := io.WriteString(w.w, s)
w.n += int64(n)
return n, err
}
func (w *writer) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.n += int64(n)
return n, err
}

78
vendor/github.com/gobwas/ws/nonce.go generated vendored Normal file
View file

@ -0,0 +1,78 @@
package ws
import (
"bufio"
"bytes"
"crypto/sha1"
"encoding/base64"
"fmt"
"math/rand"
)
const (
// RFC6455: The value of this header field MUST be a nonce consisting of a
// randomly selected 16-byte value that has been base64-encoded (see
// Section 4 of [RFC4648]). The nonce MUST be selected randomly for each
// connection.
nonceKeySize = 16
nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
// RFC6455: The value of this header field is constructed by concatenating
// /key/, defined above in step 4 in Section 4.2.2, with the string
// "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
// concatenated value to obtain a 20-byte value and base64- encoding (see
// Section 4 of [RFC4648]) this 20-byte hash.
acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
)
// initNonce fills given slice with random base64-encoded nonce bytes.
func initNonce(dst []byte) {
// NOTE: bts does not escape.
bts := make([]byte, nonceKeySize)
if _, err := rand.Read(bts); err != nil {
panic(fmt.Sprintf("rand read error: %s", err))
}
base64.StdEncoding.Encode(dst, bts)
}
// checkAcceptFromNonce reports whether given accept bytes are valid for given
// nonce bytes.
func checkAcceptFromNonce(accept, nonce []byte) bool {
if len(accept) != acceptSize {
return false
}
// NOTE: expect does not escape.
expect := make([]byte, acceptSize)
initAcceptFromNonce(expect, nonce)
return bytes.Equal(expect, accept)
}
// initAcceptFromNonce fills given slice with accept bytes generated from given
// nonce bytes. Given buffer should be exactly acceptSize bytes.
func initAcceptFromNonce(accept, nonce []byte) {
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
if len(accept) != acceptSize {
panic("accept buffer is invalid")
}
if len(nonce) != nonceSize {
panic("nonce is invalid")
}
p := make([]byte, nonceSize+len(magic))
copy(p[:nonceSize], nonce)
copy(p[nonceSize:], magic)
sum := sha1.Sum(p)
base64.StdEncoding.Encode(accept, sum[:])
}
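
The same recipe applied to the well-known example from RFC 6455, section 1.3, as a standalone snippet (initAcceptFromNonce itself is unexported):

package main

import (
	"crypto/sha1"
	"encoding/base64"
	"fmt"
)

func main() {
	const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
	nonce := "dGhlIHNhbXBsZSBub25jZQ==" // client Sec-WebSocket-Key from RFC 6455, section 1.3
	sum := sha1.Sum([]byte(nonce + magic))
	fmt.Println(base64.StdEncoding.EncodeToString(sum[:]))
	// Output: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
}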
func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) {
accept := make([]byte, acceptSize)
initAcceptFromNonce(accept, nonce)
// NOTE: write accept bytes as a string to prevent a heap allocation:
// WriteString() copies the given string into its inner buffer, unlike
// Write(), which may pass p directly to the underlying io.Writer and in
// turn cause p to escape to the heap.
return bw.WriteString(btsToString(accept))
}

147
vendor/github.com/gobwas/ws/read.go generated vendored Normal file
View file

@ -0,0 +1,147 @@
package ws
import (
"encoding/binary"
"fmt"
"io"
)
// Errors used by frame reader.
var (
ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0")
ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits")
)
// ReadHeader reads a frame header from r.
func ReadHeader(r io.Reader) (h Header, err error) {
// Make a slice of bytes with capacity 12 that can hold any header.
//
// The maximum header size is 14, but since the header is read in two hops
// and the first hop reads the first 2 constant bytes, the remaining buffer
// only needs 14 - 2 = 12 bytes.
bts := make([]byte, 2, MaxHeaderSize-2)
// Prepare to hold first 2 bytes to choose size of next read.
_, err = io.ReadFull(r, bts)
if err != nil {
return h, err
}
h.Fin = bts[0]&bit0 != 0
h.Rsv = (bts[0] & 0x70) >> 4
h.OpCode = OpCode(bts[0] & 0x0f)
var extra int
if bts[1]&bit0 != 0 {
h.Masked = true
extra += 4
}
length := bts[1] & 0x7f
switch {
case length < 126:
h.Length = int64(length)
case length == 126:
extra += 2
case length == 127:
extra += 8
default:
err = ErrHeaderLengthUnexpected
return h, err
}
if extra == 0 {
return h, err
}
// Grow bts to the number of extra bytes that need to be read,
// overwriting the first 2 bytes that were read before.
bts = bts[:extra]
_, err = io.ReadFull(r, bts)
if err != nil {
return h, err
}
switch {
case length == 126:
h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
bts = bts[2:]
case length == 127:
if bts[0]&0x80 != 0 {
err = ErrHeaderLengthMSB
return h, err
}
h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
bts = bts[8:]
}
if h.Masked {
copy(h.Mask[:], bts)
}
return h, nil
}
// ReadFrame reads a frame from r.
// It is not designed for highly optimized use cases because it allocates a
// buffer of frame.Header.Length bytes to read the frame payload into.
//
// Note that ReadFrame does not unmask payload.
func ReadFrame(r io.Reader) (f Frame, err error) {
f.Header, err = ReadHeader(r)
if err != nil {
return f, err
}
if f.Header.Length > 0 {
// int(f.Header.Length) is safe here because we have
// checked it for overflow above in ReadHeader.
f.Payload = make([]byte, int(f.Header.Length))
_, err = io.ReadFull(r, f.Payload)
}
return f, err
}
// MustReadFrame is like ReadFrame but panics if frame can not be read.
func MustReadFrame(r io.Reader) Frame {
f, err := ReadFrame(r)
if err != nil {
panic(err)
}
return f
}
// ParseCloseFrameData parses the close frame status code and the closure
// reason, if any is provided. If there is no status code in the payload,
// the empty status code is returned (code.Empty()) with an empty string as
// the reason.
func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) {
if len(payload) < 2 {
// We return an empty StatusCode here, preventing the situation where the
// endpoint really sent code 1005 and we would have to return ProtocolError
// for that.
//
// In other words, we ignore this rule [RFC6455:7.1.5]:
// If this Close control frame contains no status code, _The WebSocket
// Connection Close Code_ is considered to be 1005.
return code, reason
}
code = StatusCode(binary.BigEndian.Uint16(payload))
reason = string(payload[2:])
return code, reason
}
// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except that it does
// not copy the payload bytes into reason, but performs an unsafe cast instead.
func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) {
if len(payload) < 2 {
return code, reason
}
code = StatusCode(binary.BigEndian.Uint16(payload))
reason = btsToString(payload[2:])
return code, reason
}
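
A sketch tying ReadFrame and ParseCloseFrameData together; the package and function names are illustrative.

package wsdemo

import (
	"fmt"
	"net"

	"github.com/gobwas/ws"
)

// readOne reads a single frame and, if it is a close frame, prints the
// parsed status code and reason.
func readOne(conn net.Conn) error {
	frame, err := ws.ReadFrame(conn)
	if err != nil {
		return err
	}
	if frame.Header.Masked {
		frame = ws.UnmaskFrameInPlace(frame) // client-to-server payloads arrive masked
	}
	if frame.Header.OpCode == ws.OpClose {
		code, reason := ws.ParseCloseFrameData(frame.Payload)
		fmt.Println("peer closed:", code, reason)
	}
	return nil
}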

658
vendor/github.com/gobwas/ws/server.go generated vendored Normal file
View file

@ -0,0 +1,658 @@
package ws
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/gobwas/httphead"
"github.com/gobwas/pool/pbufio"
)
// Constants used by ConnUpgrader.
const (
DefaultServerReadBufferSize = 4096
DefaultServerWriteBufferSize = 512
)
// Errors used by both client and server when preparing WebSocket handshake.
var (
ErrHandshakeBadProtocol = RejectConnectionError(
RejectionStatus(http.StatusHTTPVersionNotSupported),
RejectionReason("handshake error: bad HTTP protocol version"),
)
ErrHandshakeBadMethod = RejectConnectionError(
RejectionStatus(http.StatusMethodNotAllowed),
RejectionReason("handshake error: bad HTTP request method"),
)
ErrHandshakeBadHost = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)),
)
ErrHandshakeBadUpgrade = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)),
)
ErrHandshakeBadConnection = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)),
)
ErrHandshakeBadSecAccept = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)),
)
ErrHandshakeBadSecKey = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)),
)
ErrHandshakeBadSecVersion = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
)
)
// ErrMalformedResponse is returned by Dialer to indicate that server response
// can not be parsed.
var ErrMalformedResponse = fmt.Errorf("malformed HTTP response")
// ErrMalformedRequest is returned when HTTP request can not be parsed.
var ErrMalformedRequest = RejectConnectionError(
RejectionStatus(http.StatusBadRequest),
RejectionReason("malformed HTTP request"),
)
// ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that
// connection is rejected because given WebSocket version is malformed.
//
// According to RFC6455:
// If this version does not match a version understood by the server, the
// server MUST abort the WebSocket handshake described in this section and
// instead send an appropriate HTTP error code (such as 426 Upgrade Required)
// and a |Sec-WebSocket-Version| header field indicating the version(s) the
// server is capable of understanding.
var ErrHandshakeUpgradeRequired = RejectConnectionError(
RejectionStatus(http.StatusUpgradeRequired),
RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")),
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
)
// ErrNotHijacker is an error returned when http.ResponseWriter does not
// implement http.Hijacker interface.
var ErrNotHijacker = RejectConnectionError(
RejectionStatus(http.StatusInternalServerError),
RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
)
// DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by
// UpgradeHTTP function.
var DefaultHTTPUpgrader HTTPUpgrader
// UpgradeHTTP is like HTTPUpgrader{}.Upgrade().
func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) {
return DefaultHTTPUpgrader.Upgrade(r, w)
}
// DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade
// function.
var DefaultUpgrader Upgrader
// Upgrade is like Upgrader{}.Upgrade().
func Upgrade(conn io.ReadWriter) (Handshake, error) {
return DefaultUpgrader.Upgrade(conn)
}
// HTTPUpgrader contains options for upgrading connection to websocket from
// net/http Handler arguments.
type HTTPUpgrader struct {
// Timeout is the maximum amount of time an Upgrade() will spend while
// writing the handshake response.
//
// The default is no timeout.
Timeout time.Duration
// Header is an optional http.Header mapping that could be used to
// write additional headers to the handshake response.
//
// Note that if present, it will be written in any result of handshake.
Header http.Header
// Protocol is the select function that is used to select the subprotocol
// from the list requested by the client. If this field is set, then the
// first matched protocol is sent to the client as negotiated.
Protocol func(string) bool
// Extension is the select function that is used to select extensions from
// the list requested by the client. If this field is set, then all matched
// extensions are sent to the client as negotiated.
//
// Deprecated: use Negotiate instead.
Extension func(httphead.Option) bool
// Negotiate is the callback that is used to negotiate extensions from
// the client's offer. If this field is set, then the returned non-zero
// extensions are sent to the client as accepted extensions in the
// response.
//
// The argument is only valid until the Negotiate callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
Negotiate func(httphead.Option) (httphead.Option, error)
}
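
A sketch of plugging HTTPUpgrader into a standard net/http handler; the path, port, and subprotocol name are arbitrary choices.

package main

import (
	"log"
	"net/http"
	"time"

	"github.com/gobwas/ws"
)

func main() {
	u := ws.HTTPUpgrader{
		Timeout:  5 * time.Second,
		Protocol: func(sub string) bool { return sub == "chat" }, // accept only "chat"
	}
	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		conn, _, hs, err := u.Upgrade(r, w)
		if err != nil {
			return // the error response has already been written
		}
		defer conn.Close()
		_ = hs.Protocol // "chat" if the client offered it, otherwise ""
		// ... serve frames on conn ...
	})
	log.Fatal(http.ListenAndServe(":8080", nil))
}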
// Upgrade upgrades http connection to the websocket connection.
//
// It hijacks net.Conn from w and returns received net.Conn and
// bufio.ReadWriter. On successful handshake it returns Handshake struct
// describing handshake info.
func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
// Hijack connection first to get the ability to write rejection errors the
// same way as in Upgrader.
conn, rw, err = hijack(w)
if err != nil {
httpError(w, err.Error(), http.StatusInternalServerError)
return conn, rw, hs, err
}
// See https://tools.ietf.org/html/rfc6455#section-4.1
// The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
var nonce string
if r.Method != http.MethodGet {
err = ErrHandshakeBadMethod
} else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
err = ErrHandshakeBadProtocol
} else if r.Host == "" {
err = ErrHandshakeBadHost
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
err = ErrHandshakeBadUpgrade
} else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
err = ErrHandshakeBadConnection
} else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
err = ErrHandshakeBadSecKey
} else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
// According to RFC6455:
//
// If this version does not match a version understood by the server,
// the server MUST abort the WebSocket handshake described in this
// section and instead send an appropriate HTTP error code (such as 426
// Upgrade Required) and a |Sec-WebSocket-Version| header field
// indicating the version(s) the server is capable of understanding.
//
// So we branch here because an empty or absent version does not
// meet the ABNF rules of RFC6455:
//
// version = DIGIT | (NZDIGIT DIGIT) |
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
// ; Limited to 0-255 range, with no leading zeros
//
// That is, if the version is really invalid we send the 426 status; if it
// is not present or empty, the status is 400.
if v != "" {
err = ErrHandshakeUpgradeRequired
} else {
err = ErrHandshakeBadSecVersion
}
}
if check := u.Protocol; err == nil && check != nil {
ps := r.Header[headerSecProtocolCanonical]
for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
var ok bool
hs.Protocol, ok = strSelectProtocol(ps[i], check)
if !ok {
err = ErrMalformedRequest
}
}
}
if f := u.Negotiate; err == nil && f != nil {
for _, h := range r.Header[headerSecExtensionsCanonical] {
hs.Extensions, err = negotiateExtensions(strToBytes(h), hs.Extensions, f)
if err != nil {
break
}
}
}
// DEPRECATED path.
if check := u.Extension; err == nil && check != nil && u.Negotiate == nil {
xs := r.Header[headerSecExtensionsCanonical]
for i := 0; i < len(xs) && err == nil; i++ {
var ok bool
hs.Extensions, ok = btsSelectExtensions(strToBytes(xs[i]), hs.Extensions, check)
if !ok {
err = ErrMalformedRequest
}
}
}
// Clear deadlines set by server.
conn.SetDeadline(noDeadline)
if t := u.Timeout; t != 0 {
conn.SetWriteDeadline(time.Now().Add(t))
defer conn.SetWriteDeadline(noDeadline)
}
var header handshakeHeader
if h := u.Header; h != nil {
header[0] = HandshakeHeaderHTTP(h)
}
if err == nil {
httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo)
err = rw.Writer.Flush()
} else {
var code int
if rej, ok := err.(*ConnectionRejectedError); ok {
code = rej.code
header[1] = rej.header
}
if code == 0 {
code = http.StatusInternalServerError
}
httpWriteResponseError(rw.Writer, err, code, header.WriteTo)
// Do not store Flush() error to not override already existing one.
_ = rw.Writer.Flush()
}
return conn, rw, hs, err
}
// Upgrader contains options for upgrading connection to websocket.
type Upgrader struct {
// ReadBufferSize and WriteBufferSize are the I/O buffer sizes.
// They are used to read and write HTTP data while upgrading to WebSocket.
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
//
// If a size is zero then the default value is used.
//
// Usually it is useful to set the read buffer size bigger than the write
// buffer size because the incoming request could contain long header values,
// such as Cookie. The response, on the other hand, could be big only if the
// user writes multiple custom headers. Usually the response takes less than
// 256 bytes.
ReadBufferSize, WriteBufferSize int
// Protocol is a select function that is used to select the subprotocol
// from the list requested by the client. If this field is set, then the
// first matched protocol is sent to the client as negotiated.
//
// The argument is only valid until the callback returns.
Protocol func([]byte) bool
// ProtocolCustom allows the user to parse the Sec-WebSocket-Protocol header
// manually. Note that the returned bytes must be valid until Upgrade returns.
// If ProtocolCustom is set, it is used instead of the Protocol function.
ProtocolCustom func([]byte) (string, bool)
// Extension is a select function that is used to select extensions
// from the list requested by the client. If this field is set, then all
// matched extensions are sent to the client as negotiated.
//
// Note that Extension may be called multiple times and implementations
// must track uniqueness of accepted extensions manually.
//
// The argument is only valid until the callback returns.
//
// According to the RFC6455 order of extensions passed by a client is
// significant. That is, returning true from this function means that no
// other extension with the same name should be checked because server
// accepted the most preferable extension right now:
// "Note that the order of extensions is significant. Any interactions between
// multiple extensions MAY be defined in the documents defining the extensions.
// In the absence of such definitions, the interpretation is that the header
// fields listed by the client in its request represent a preference of the
// header fields it wishes to use, with the first options listed being most
// preferable."
//
// Deprecated: use Negotiate instead.
Extension func(httphead.Option) bool
// ExtensionCustom allow user to parse Sec-WebSocket-Extensions header
// manually.
//
// If ExtensionCustom() decides to accept received extension, it must
// append appropriate option to the given slice of httphead.Option.
// It returns the result of append() on the given slice and a flag that
// reports whether the given header value is well-formed or not.
//
// Note that ExtensionCustom may be called multiple times and
// implementations must track uniqueness of accepted extensions manually.
//
// Note that returned options should be valid until Upgrade returns.
// If ExtensionCustom is set, it is used instead of the Extension function.
ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
// Negotiate is the callback that is used to negotiate extensions from
// the client's offer. If this field is set, then the returned non-zero
// extensions are sent to the client as accepted extensions in the
// response.
//
// The argument is only valid until the Negotiate callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
Negotiate func(httphead.Option) (httphead.Option, error)
// Header is an optional HandshakeHeader instance that could be used to
// write additional headers to the handshake response.
//
// It is used instead of any key-value mapping to avoid allocations in
// user land.
//
// Note that if present, it will be written in any result of handshake.
Header HandshakeHeader
// OnRequest is a callback that will be called after request line
// successful parsing.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnRequest func(uri []byte) error
// OnHost is a callback that will be called after "Host" header successful
// parsing.
//
// It is separated from OnHeader callback because the Host header must be
// present in each request since HTTP/1.1. Thus Host header is non-optional
// and required for every WebSocket handshake.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnHost func(host []byte) error
// OnHeader is a callback that will be called after successful parsing of a
// header that is not used during the WebSocket handshake procedure. That is,
// it will be called with non-websocket headers, which could be relevant
// for application-level logic.
//
// The arguments are only valid until the callback returns.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnHeader func(key, value []byte) error
// OnBeforeUpgrade is a callback that will be called before sending
// successful upgrade response.
//
// Setting OnBeforeUpgrade allows the user to make final application-level
// checks and decide whether this connection is allowed to successfully
// upgrade to WebSocket.
//
// It must return either a non-nil HandshakeHeader or a non-nil error, and
// never both.
//
// If returned error is non-nil then connection is rejected and response is
// sent with appropriate HTTP error code and body set to error message.
//
// RejectConnectionError could be used to get more control on response.
OnBeforeUpgrade func() (header HandshakeHeader, err error)
}
// Upgrade zero-copy upgrades the connection to WebSocket. It interprets the
// given conn as a connection with an incoming HTTP Upgrade request.
//
// It is the caller's responsibility to manage i/o timeouts on conn.
//
// A non-nil error means that the request for the WebSocket upgrade is invalid
// or malformed, and usually the connection should be closed.
// Even when the error is non-nil, Upgrade writes an appropriate response into
// the connection in compliance with the RFC.
func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
// headerSeen constants help to report whether or not some header was seen
// while reading the request bytes.
const (
headerSeenHost = 1 << iota
headerSeenUpgrade
headerSeenConnection
headerSeenSecVersion
headerSeenSecKey
// headerSeenAll is the value that we expect to receive at the end of
// headers read/parse loop.
headerSeenAll = 0 |
headerSeenHost |
headerSeenUpgrade |
headerSeenConnection |
headerSeenSecVersion |
headerSeenSecKey
)
// Prepare I/O buffers.
// TODO(gobwas): make it configurable.
br := pbufio.GetReader(conn,
nonZero(u.ReadBufferSize, DefaultServerReadBufferSize),
)
bw := pbufio.GetWriter(conn,
nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize),
)
defer func() {
pbufio.PutReader(br)
pbufio.PutWriter(bw)
}()
// Read HTTP request line like "GET /ws HTTP/1.1".
rl, err := readLine(br)
if err != nil {
return hs, err
}
// Parse request line data like HTTP version, uri and method.
req, err := httpParseRequestLine(rl)
if err != nil {
return hs, err
}
// Prepare stack-based handshake header list.
header := handshakeHeader{
0: u.Header,
}
// Parse and check HTTP request.
// As RFC6455 says:
// The client's opening handshake consists of the following parts. If the
// server, while reading the handshake, finds that the client did not
// send a handshake that matches the description below (note that as per
// [RFC2616], the order of the header fields is not important), including
// but not limited to any violations of the ABNF grammar specified for
// the components of the handshake, the server MUST stop processing the
// client's handshake and return an HTTP response with an appropriate
// error code (such as 400 Bad Request).
//
// See https://tools.ietf.org/html/rfc6455#section-4.2.1
// An HTTP/1.1 or higher GET request, including a "Request-URI".
//
// Even if RFC says "1.1 or higher" without mentioning the part of the
// version, we apply it only to minor part.
switch {
case req.major != 1 || req.minor < 1:
// Abort processing the whole request because we do not even know how
// to actually parse it.
err = ErrHandshakeBadProtocol
case btsToString(req.method) != http.MethodGet:
err = ErrHandshakeBadMethod
default:
if onRequest := u.OnRequest; onRequest != nil {
err = onRequest(req.uri)
}
}
// Start headers read/parse loop.
var (
// headerSeen reports which header was seen by setting corresponding
// bit on.
headerSeen byte
nonce = make([]byte, nonceSize)
)
for err == nil {
line, e := readLine(br)
if e != nil {
return hs, e
}
if len(line) == 0 {
// Blank line, no more lines to read.
break
}
k, v, ok := httpParseHeaderLine(line)
if !ok {
err = ErrMalformedRequest
break
}
switch btsToString(k) {
case headerHostCanonical:
headerSeen |= headerSeenHost
if onHost := u.OnHost; onHost != nil {
err = onHost(v)
}
case headerUpgradeCanonical:
headerSeen |= headerSeenUpgrade
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
err = ErrHandshakeBadUpgrade
}
case headerConnectionCanonical:
headerSeen |= headerSeenConnection
if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
err = ErrHandshakeBadConnection
}
case headerSecVersionCanonical:
headerSeen |= headerSeenSecVersion
if !bytes.Equal(v, specHeaderValueSecVersion) {
err = ErrHandshakeUpgradeRequired
}
case headerSecKeyCanonical:
headerSeen |= headerSeenSecKey
if len(v) != nonceSize {
err = ErrHandshakeBadSecKey
} else {
copy(nonce, v)
}
case headerSecProtocolCanonical:
if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
var ok bool
if custom != nil {
hs.Protocol, ok = custom(v)
} else {
hs.Protocol, ok = btsSelectProtocol(v, check)
}
if !ok {
err = ErrMalformedRequest
}
}
case headerSecExtensionsCanonical:
if f := u.Negotiate; err == nil && f != nil {
hs.Extensions, err = negotiateExtensions(v, hs.Extensions, f)
}
// DEPRECATED path.
if custom, check := u.ExtensionCustom, u.Extension; u.Negotiate == nil && (custom != nil || check != nil) {
var ok bool
if custom != nil {
hs.Extensions, ok = custom(v, hs.Extensions)
} else {
hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check)
}
if !ok {
err = ErrMalformedRequest
}
}
default:
if onHeader := u.OnHeader; onHeader != nil {
err = onHeader(k, v)
}
}
}
switch {
case err == nil && headerSeen != headerSeenAll:
switch {
case headerSeen&headerSeenHost == 0:
// As RFC2616 says:
// A client MUST include a Host header field in all HTTP/1.1
// request messages. If the requested URI does not include an
// Internet host name for the service being requested, then the
// Host header field MUST be given with an empty value. An
// HTTP/1.1 proxy MUST ensure that any request message it
// forwards does contain an appropriate Host header field that
// identifies the service being requested by the proxy. All
// Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad
// Request) status code to any HTTP/1.1 request message which
// lacks a Host header field.
err = ErrHandshakeBadHost
case headerSeen&headerSeenUpgrade == 0:
err = ErrHandshakeBadUpgrade
case headerSeen&headerSeenConnection == 0:
err = ErrHandshakeBadConnection
case headerSeen&headerSeenSecVersion == 0:
// In case of empty or not present version we do not send 426 status,
// because it does not meet the ABNF rules of RFC6455:
//
// version = DIGIT | (NZDIGIT DIGIT) |
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
// ; Limited to 0-255 range, with no leading zeros
//
// That is, if the version is really invalid we send the 426 status as above;
// if it is not present we send 400.
err = ErrHandshakeBadSecVersion
case headerSeen&headerSeenSecKey == 0:
err = ErrHandshakeBadSecKey
default:
panic("unknown headers state")
}
case err == nil && u.OnBeforeUpgrade != nil:
header[1], err = u.OnBeforeUpgrade()
}
if err != nil {
var code int
if rej, ok := err.(*ConnectionRejectedError); ok {
code = rej.code
header[1] = rej.header
}
if code == 0 {
code = http.StatusInternalServerError
}
httpWriteResponseError(bw, err, code, header.WriteTo)
// Do not store Flush() error to not override already existing one.
_ = bw.Flush()
return hs, err
}
httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
err = bw.Flush()
return hs, err
}
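// exampleUpgrade is a minimal illustrative sketch (not part of the upstream
// package): a server-side zero-copy upgrade over an accepted connection. The
// "/ws" path and the "chat" subprotocol are assumptions made for the example
// only.
func exampleUpgrade(conn io.ReadWriter) (Handshake, error) {
	u := Upgrader{
		// Accept only the illustrative "chat" subprotocol.
		Protocol: func(proto string) bool {
			return proto == "chat"
		},
		// Reject any request URI other than the illustrative "/ws".
		OnRequest: func(uri []byte) error {
			if string(uri) != "/ws" {
				return RejectConnectionError(
					RejectionStatus(http.StatusNotFound),
				)
			}
			return nil
		},
	}
	return u.Upgrade(conn)
}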
type handshakeHeader [2]HandshakeHeader
func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) {
for i := 0; i < len(hs) && err == nil; i++ {
if h := hs[i]; h != nil {
var m int64
m, err = h.WriteTo(w)
n += m
}
}
return n, err
}

199
vendor/github.com/gobwas/ws/util.go generated vendored Normal file

@ -0,0 +1,199 @@
package ws
import (
"bufio"
"bytes"
"fmt"
"github.com/gobwas/httphead"
)
// SelectFromSlice creates an accept function that could be used as a
// Protocol/Extension selector during upgrade.
func SelectFromSlice(accept []string) func(string) bool {
if len(accept) > 16 {
mp := make(map[string]struct{}, len(accept))
for _, p := range accept {
mp[p] = struct{}{}
}
return func(p string) bool {
_, ok := mp[p]
return ok
}
}
return func(p string) bool {
for _, ok := range accept {
if p == ok {
return true
}
}
return false
}
}
// SelectEqual creates an accept function that could be used as a
// Protocol/Extension selector during upgrade.
func SelectEqual(v string) func(string) bool {
return func(p string) bool {
return v == p
}
}
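// exampleSelectProtocol is a minimal illustrative sketch (not part of the
// upstream package): plugging an accept function built by SelectFromSlice
// into Upgrader.Protocol. The subprotocol names are illustrative only.
func exampleSelectProtocol() Upgrader {
	return Upgrader{
		Protocol: SelectFromSlice([]string{"chat", "v2.chat"}),
	}
}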
// asciiToInt converts bytes to int.
func asciiToInt(bts []byte) (ret int, err error) {
// ASCII digits all start with the high-order bits 0011.
// If you see that, and the low-order bits are 0-9 (0000 - 1001), you can take
// those bits and interpret them directly as an integer.
var n int
if n = len(bts); n < 1 {
return 0, fmt.Errorf("converting empty bytes to int")
}
for i := 0; i < n; i++ {
if bts[i]&0xf0 != 0x30 {
return 0, fmt.Errorf("%s is not a numeric character", string(bts[i]))
}
ret += int(bts[i]&0xf) * pow(10, n-i-1)
}
return ret, nil
}
// pow for integers implementation.
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3.
func pow(a, b int) int {
p := 1
for b > 0 {
if b&1 != 0 {
p *= a
}
b >>= 1
a *= a
}
return p
}
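// exampleAsciiToInt is a minimal illustrative sketch (not part of the
// upstream package): asciiToInt relies on ASCII digits '0'..'9' being
// 0x30..0x39, so the low nibble of each byte is the digit value itself.
func exampleAsciiToInt() (int, error) {
	return asciiToInt([]byte("125")) // 125, nil
}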
func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) {
a := bytes.IndexByte(bts, sep)
b := bytes.IndexByte(bts[a+1:], sep)
if a == -1 || b == -1 {
return bts, nil, nil
}
b += a + 1
return bts[:a], bts[a+1 : b], bts[b+1:]
}
func btrim(bts []byte) []byte {
var i, j int
for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); {
i++
}
for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); {
j--
}
return bts[i:j]
}
func strHasToken(header, token string) (has bool) {
return btsHasToken(strToBytes(header), strToBytes(token))
}
func btsHasToken(header, token []byte) (has bool) {
httphead.ScanTokens(header, func(v []byte) bool {
has = bytes.EqualFold(v, token)
return !has
})
return has
}
const (
toLower = 'a' - 'A' // for use with OR.
toUpper = ^byte(toLower) // for use with AND.
toLower8 = uint64(toLower) |
uint64(toLower)<<8 |
uint64(toLower)<<16 |
uint64(toLower)<<24 |
uint64(toLower)<<32 |
uint64(toLower)<<40 |
uint64(toLower)<<48 |
uint64(toLower)<<56
)
// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except
// that it operates with slice of bytes and modifies it inplace without copying.
func canonicalizeHeaderKey(k []byte) {
upper := true
for i, c := range k {
if upper && 'a' <= c && c <= 'z' {
k[i] &= toUpper
} else if !upper && 'A' <= c && c <= 'Z' {
k[i] |= toLower
}
upper = c == '-'
}
}
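// exampleCanonicalizeHeaderKey is a minimal illustrative sketch (not part of
// the upstream package): canonicalizeHeaderKey rewrites the key in place
// without allocating, similar in spirit to textproto.CanonicalMIMEHeaderKey.
func exampleCanonicalizeHeaderKey() string {
	k := []byte("sec-websocket-version")
	canonicalizeHeaderKey(k)
	return string(k) // "Sec-Websocket-Version"
}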
// readLine reads a line from br. It reads until '\n' and returns the bytes
// without the trailing '\n' or '\r\n'.
// It returns an error if and only if the line does not end in '\n'. Note that
// the bytes read so far are returned in any case of error.
//
// It is much like textproto/Reader.ReadLine() except that it returns raw
// bytes instead of a string. That is, it avoids copying bytes read from br.
//
// textproto/Reader.ReadLineBytes() also makes a copy of the resulting bytes
// to be safe with future I/O operations on br.
//
// We control I/O operations on br ourselves, so no additional copy is needed
// for safety.
//
// NOTE: it may return a copied flag to notify that the returned buffer is
// safe to use.
func readLine(br *bufio.Reader) ([]byte, error) {
var line []byte
for {
bts, err := br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
// Copy bytes because next read will discard them.
line = append(line, bts...)
continue
}
// Avoid copy of single read.
if line == nil {
line = bts
} else {
line = append(line, bts...)
}
if err != nil {
return line, err
}
// The size of line is at least 1; otherwise bufio.ReadSlice() would have
// returned an error.
n := len(line)
// Cut '\n' or '\r\n'.
if n > 1 && line[n-2] == '\r' {
line = line[:n-2]
} else {
line = line[:n-1]
}
return line, nil
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func nonZero(a, b int) int {
if a != 0 {
return a
}
return b
}

12
vendor/github.com/gobwas/ws/util_purego.go generated vendored Normal file

@ -0,0 +1,12 @@
//go:build purego
// +build purego
package ws
func strToBytes(str string) (bts []byte) {
return []byte(str)
}
func btsToString(bts []byte) (str string) {
return string(bts)
}

22
vendor/github.com/gobwas/ws/util_unsafe.go generated vendored Normal file

@ -0,0 +1,22 @@
//go:build !purego
// +build !purego
package ws
import (
"reflect"
"unsafe"
)
func strToBytes(str string) (bts []byte) {
s := (*reflect.StringHeader)(unsafe.Pointer(&str))
b := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
b.Data = s.Data
b.Len = s.Len
b.Cap = s.Len
return bts
}
func btsToString(bts []byte) (str string) {
return *(*string)(unsafe.Pointer(&bts))
}

104
vendor/github.com/gobwas/ws/write.go generated vendored Normal file

@ -0,0 +1,104 @@
package ws
import (
"encoding/binary"
"io"
)
// Header size length bounds in bytes.
const (
MaxHeaderSize = 14
MinHeaderSize = 2
)
const (
bit0 = 0x80
bit1 = 0x40
bit2 = 0x20
bit3 = 0x10
bit4 = 0x08
bit5 = 0x04
bit6 = 0x02
bit7 = 0x01
len7 = int64(125)
len16 = int64(^(uint16(0)))
len64 = int64(^(uint64(0)) >> 1)
)
// HeaderSize returns number of bytes that are needed to encode given header.
// It returns -1 if header is malformed.
func HeaderSize(h Header) (n int) {
switch {
case h.Length < 126:
n = 2
case h.Length <= len16:
n = 4
case h.Length <= len64:
n = 10
default:
return -1
}
if h.Masked {
n += len(h.Mask)
}
return n
}
// WriteHeader writes header binary representation into w.
func WriteHeader(w io.Writer, h Header) error {
// Make slice of bytes with capacity 14 that could hold any header.
bts := make([]byte, MaxHeaderSize)
if h.Fin {
bts[0] |= bit0
}
bts[0] |= h.Rsv << 4
bts[0] |= byte(h.OpCode)
var n int
switch {
case h.Length <= len7:
bts[1] = byte(h.Length)
n = 2
case h.Length <= len16:
bts[1] = 126
binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length))
n = 4
case h.Length <= len64:
bts[1] = 127
binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length))
n = 10
default:
return ErrHeaderLengthUnexpected
}
if h.Masked {
bts[1] |= bit0
n += copy(bts[n:], h.Mask[:])
}
_, err := w.Write(bts[:n])
return err
}
// WriteFrame writes frame binary representation into w.
func WriteFrame(w io.Writer, f Frame) error {
err := WriteHeader(w, f.Header)
if err != nil {
return err
}
_, err = w.Write(f.Payload)
return err
}
// MustWriteFrame is like WriteFrame but panics if the frame cannot be written.
func MustWriteFrame(w io.Writer, f Frame) {
if err := WriteFrame(w, f); err != nil {
panic(err)
}
}
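// exampleWriteTextFrame is a minimal illustrative sketch (not part of the
// upstream package): encoding a small, unmasked server-to-client text frame.
// Only fields used by WriteHeader above are set; the payload is illustrative.
func exampleWriteTextFrame(w io.Writer) error {
	payload := []byte("hello")
	return WriteFrame(w, Frame{
		Header: Header{
			Fin:    true,
			OpCode: OpText,
			Length: int64(len(payload)),
		},
		Payload: payload,
	})
}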

134
vendor/github.com/gobwas/ws/wsflate/cbuf.go generated vendored Normal file

@ -0,0 +1,134 @@
package wsflate
import (
"io"
)
// cbuf is a tiny proxy buffer that writes all but the last 4 bytes to the
// destination.
type cbuf struct {
buf [4]byte
n int
dst io.Writer
err error
}
// Write implements io.Writer interface.
func (c *cbuf) Write(p []byte) (int, error) {
if c.err != nil {
return 0, c.err
}
head, tail := c.split(p)
n := c.n + len(tail)
if n > len(c.buf) {
x := n - len(c.buf)
c.flush(c.buf[:x])
copy(c.buf[:], c.buf[x:])
c.n -= x
}
if len(head) > 0 {
c.flush(head)
}
copy(c.buf[c.n:], tail)
c.n = min(c.n+len(tail), len(c.buf))
return len(p), c.err
}
func (c *cbuf) flush(p []byte) {
if c.err == nil {
_, c.err = c.dst.Write(p)
}
}
func (c *cbuf) split(p []byte) (head, tail []byte) {
if n := len(p); n > len(c.buf) {
x := n - len(c.buf)
head = p[:x]
tail = p[x:]
return head, tail
}
return nil, p
}
func (c *cbuf) reset(dst io.Writer) {
c.n = 0
c.err = nil
c.buf = [4]byte{0, 0, 0, 0}
c.dst = dst
}
type suffixedReader struct {
r io.Reader
pos int // position in the suffix.
suffix [9]byte
rx struct{ io.Reader }
}
func (r *suffixedReader) iface() io.Reader {
if _, ok := r.r.(io.ByteReader); ok {
// If the source io.Reader implements io.ByteReader, return the full set of
// methods of the suffixedReader struct (Read() and ReadByte()).
// This is actually an optimization needed for those Decompressor
// implementations (such as the default flate.Reader) which check whether the
// given source is already "buffered" by checking if the source implements
// io.ByteReader. Without this check we would always end up with
// double-buffering for the default decompressors.
return r
}
// The source io.Reader doesn't support io.ByteReader, so we should cut off
// the ReadByte() method from the suffixedReader struct. We use the r.rx field
// to avoid allocations.
r.rx.Reader = r
return &r.rx
}
func (r *suffixedReader) Read(p []byte) (n int, err error) {
if r.r != nil {
n, err = r.r.Read(p)
if err == io.EOF {
err = nil
r.r = nil
}
return n, err
}
if r.pos >= len(r.suffix) {
return 0, io.EOF
}
n = copy(p, r.suffix[r.pos:])
r.pos += n
return n, nil
}
func (r *suffixedReader) ReadByte() (b byte, err error) {
if r.r != nil {
br, ok := r.r.(io.ByteReader)
if !ok {
panic("wsflate: internal error: incorrect use of suffixedReader")
}
b, err = br.ReadByte()
if err == io.EOF {
err = nil
r.r = nil
}
return b, err
}
if r.pos >= len(r.suffix) {
return 0, io.EOF
}
b = r.suffix[r.pos]
r.pos++
return b, nil
}
func (r *suffixedReader) reset(src io.Reader) {
r.r = src
r.pos = 0
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

208
vendor/github.com/gobwas/ws/wsflate/extension.go generated vendored Normal file

@ -0,0 +1,208 @@
package wsflate
import (
"bytes"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
)
// Extension contains the logic of compression extension parameter negotiation
// made during the HTTP WebSocket handshake.
// It might be reused between different upgrades (but not concurrently), with
// Reset() being called after each.
type Extension struct {
// Parameters is specification of extension parameters server is going to
// accept.
Parameters Parameters
accepted bool
params Parameters
}
// Negotiate parses the given HTTP header option and returns (if any) a header
// option which describes the accepted parameters.
//
// It may return a zero option (i.e. one whose Size() returns 0) alongside a
// nil error.
func (n *Extension) Negotiate(opt httphead.Option) (accept httphead.Option, err error) {
if !bytes.Equal(opt.Name, ExtensionNameBytes) {
return accept, nil
}
if n.accepted {
// Negotiate might be called multiple times during an upgrade.
// We stick to the first accepted extension since extensions are passed
// ordered by preference.
return accept, nil
}
want := n.Parameters
// NOTE: Parse() resets params inside, so no worries.
if err := n.params.Parse(opt); err != nil {
return accept, err
}
{
offer := n.params.ServerMaxWindowBits
want := want.ServerMaxWindowBits
if offer > want {
// A server declines an extension negotiation offer
// with this parameter if the server doesn't support
// it.
return accept, nil
}
}
{
// If a received extension negotiation offer has the
// "client_max_window_bits" extension parameter, the server MAY
// include the "client_max_window_bits" extension parameter in the
// corresponding extension negotiation response to the offer.
offer := n.params.ClientMaxWindowBits
want := want.ClientMaxWindowBits
if want > offer {
return accept, nil
}
}
{
offer := n.params.ServerNoContextTakeover
want := want.ServerNoContextTakeover
if offer && !want {
return accept, nil
}
}
n.accepted = true
return want.Option(), nil
}
// Accepted returns parameters parsed during last negotiation and a flag that
// reports whether they were accepted.
func (n *Extension) Accepted() (_ Parameters, accepted bool) {
return n.params, n.accepted
}
// Reset resets extension for further reuse.
func (n *Extension) Reset() {
n.accepted = false
n.params = Parameters{}
}
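// exampleNegotiate is a minimal illustrative sketch (not part of the upstream
// package): wiring Extension into the zero-copy server upgrade via
// ws.Upgrader.Negotiate. The parameter values are illustrative; Reset()
// should be called after each upgrade when the Extension value is reused.
func exampleNegotiate() (*Extension, ws.Upgrader) {
	e := &Extension{
		Parameters: Parameters{
			ServerNoContextTakeover: true,
			ClientNoContextTakeover: true,
		},
	}
	u := ws.Upgrader{
		Negotiate: e.Negotiate,
	}
	return e, u
}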
var ErrUnexpectedCompressionBit = ws.ProtocolError(
"control frame or non-first fragment of data contains compression bit set",
)
// UnsetBit clears the Per-Message Compression bit in header h and returns its
// modified copy. It reports whether compression bit was set in header h.
// It returns non-nil error if compression bit has unexpected value.
//
// This function's main purpose is to be compatible with "Framing" section of
// the Compression Extensions for WebSocket RFC. If you don't need to work with
// chains of extensions then IsCompressed() could be enough to check if
// message is compressed.
// See https://tools.ietf.org/html/rfc7692#section-6.2
func UnsetBit(h ws.Header) (_ ws.Header, wasSet bool, err error) {
var s MessageState
h, err = s.UnsetBits(h)
return h, s.IsCompressed(), err
}
// SetBit sets the Per-Message Compression bit in header h and returns its
// modified copy.
// It returns non-nil error if compression bit has unexpected value.
func SetBit(h ws.Header) (_ ws.Header, err error) {
var s MessageState
s.SetCompressed(true)
return s.SetBits(h)
}
// IsCompressed reports whether the Per-Message Compression bit is set in
// header h.
// It returns non-nil error if compression bit has unexpected value.
//
// If you need to be fully compatible with Compression Extensions for WebSocket
// RFC and work with chains of extensions, take a look at the UnsetBit()
// instead. That is, IsCompressed() is a shortcut for UnsetBit() with reduced
// number of return values.
func IsCompressed(h ws.Header) (bool, error) {
_, isSet, err := UnsetBit(h)
return isSet, err
}
// MessageState holds message compression state.
//
// It is consulted during SetBits(h) call to make a decision whether we must
// set the Per-Message Compression bit for given header h argument.
// It is updated during UnsetBits(h) to reflect compression state of a message
// represented by header h argument.
// It can also be consulted/updated directly by calling
// IsCompressed()/SetCompressed().
//
// In general MessageState should be used when there is no direct access to
// connection to read frame from, but it is still needed to know if message
// being read is compressed. For other cases SetBit() and UnsetBit() should be
// used instead.
//
// NOTE: the compression state is updated during UnsetBits(h) only when header
// h argument represents data (text or binary) frame.
type MessageState struct {
compressed bool
}
// SetCompressed marks message as "compressed" or "uncompressed".
// See https://tools.ietf.org/html/rfc7692#section-6
func (s *MessageState) SetCompressed(v bool) {
s.compressed = v
}
// IsCompressed reports whether message is "compressed".
// See https://tools.ietf.org/html/rfc7692#section-6
func (s *MessageState) IsCompressed() bool {
return s.compressed
}
// UnsetBits changes RSV bits of the given frame header h as if compression
// extension was negotiated. It returns modified copy of h and error if header
// is malformed from the RFC perspective.
func (s *MessageState) UnsetBits(h ws.Header) (ws.Header, error) {
r1, r2, r3 := ws.RsvBits(h.Rsv)
switch {
case h.OpCode.IsData() && h.OpCode != ws.OpContinuation:
h.Rsv = ws.Rsv(false, r2, r3)
s.SetCompressed(r1)
return h, nil
case r1:
// An endpoint MUST NOT set the "Per-Message Compressed"
// bit of control frames and non-first fragments of a data
// message. An endpoint receiving such a frame MUST _Fail
// the WebSocket Connection_.
return h, ErrUnexpectedCompressionBit
default:
// NOTE: do not change the state of s.compressed since UnsetBits()
// might also be called for (intermediate) control frames.
return h, nil
}
}
// SetBits changes the RSV bits of the frame header h which is being sent, as
// if the compression extension was negotiated. It returns a modified copy of h
// and an error if the header is malformed from the RFC perspective.
func (s *MessageState) SetBits(h ws.Header) (ws.Header, error) {
r1, r2, r3 := ws.RsvBits(h.Rsv)
if r1 {
return h, ErrUnexpectedCompressionBit
}
if !h.OpCode.IsData() || h.OpCode == ws.OpContinuation {
// An endpoint MUST NOT set the "Per-Message Compressed"
// bit of control frames and non-first fragments of a data
// message. An endpoint receiving such a frame MUST _Fail
// the WebSocket Connection_.
return h, nil
}
if s.IsCompressed() {
h.Rsv = ws.Rsv(true, r2, r3)
}
return h, nil
}
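// exampleFragmentedRead is a minimal illustrative sketch (not part of the
// upstream package): MessageState persists across the fragments of a message,
// so the compression flag taken from the first (text/binary) header still
// applies when a continuation header is processed.
func exampleFragmentedRead(first, continuation ws.Header) (compressed bool, err error) {
	var s MessageState
	if _, err = s.UnsetBits(first); err != nil {
		return false, err
	}
	if _, err = s.UnsetBits(continuation); err != nil {
		return false, err
	}
	return s.IsCompressed(), nil
}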

195
vendor/github.com/gobwas/ws/wsflate/helper.go generated vendored Normal file

@ -0,0 +1,195 @@
package wsflate
import (
"bytes"
"compress/flate"
"fmt"
"io"
"github.com/gobwas/ws"
)
// DefaultHelper is a default helper instance holding standard library's
// `compress/flate` compressor and decompressor under the hood.
//
// Note that use of DefaultHelper methods assumes that DefaultParameters were
// used for extension negotiation during WebSocket handshake.
var DefaultHelper = Helper{
Compressor: func(w io.Writer) Compressor {
// No error can be returned here as NewWriter() doc says.
f, _ := flate.NewWriter(w, 9)
return f
},
Decompressor: func(r io.Reader) Decompressor {
return flate.NewReader(r)
},
}
// DefaultParameters holds deflate extension parameters which are assumed by
// DefaultHelper to be used during WebSocket handshake.
var DefaultParameters = Parameters{
ServerNoContextTakeover: true,
ClientNoContextTakeover: true,
}
// CompressFrame is a shortcut for DefaultHelper.CompressFrame().
//
// Note that use of DefaultHelper methods assumes that DefaultParameters were
// used for extension negotiation during WebSocket handshake.
func CompressFrame(f ws.Frame) (ws.Frame, error) {
return DefaultHelper.CompressFrame(f)
}
// CompressFrameBuffer is a shortcut for DefaultHelper.CompressFrameBuffer().
//
// Note that use of DefaultHelper methods assumes that DefaultParameters were
// used for extension negotiation during WebSocket handshake.
func CompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
return DefaultHelper.CompressFrameBuffer(buf, f)
}
// DecompressFrame is a shortcut for DefaultHelper.DecompressFrame().
//
// Note that use of DefaultHelper methods assumes that DefaultParameters were
// used for extension negotiation during WebSocket handshake.
func DecompressFrame(f ws.Frame) (ws.Frame, error) {
return DefaultHelper.DecompressFrame(f)
}
// DecompressFrameBuffer is a shortcut for
// DefaultHelper.DecompressFrameBuffer().
//
// Note that use of DefaultHelper methods assumes that DefaultParameters were
// used for extension negotiation during WebSocket handshake.
func DecompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
return DefaultHelper.DecompressFrameBuffer(buf, f)
}
// Helper is a helper struct that holds common code for compression and
// decompression bytes or WebSocket frames.
//
// Its purpose is to reduce boilerplate code in WebSocket applications.
type Helper struct {
Compressor func(w io.Writer) Compressor
Decompressor func(r io.Reader) Decompressor
}
// Buffer is an interface representing some bytes buffering object.
type Buffer interface {
io.Writer
Bytes() []byte
}
// CompressFrame returns compressed version of a frame.
// Note that it does memory allocations internally. To control those
// allocations consider using CompressFrameBuffer().
func (h *Helper) CompressFrame(in ws.Frame) (f ws.Frame, err error) {
var buf bytes.Buffer
return h.CompressFrameBuffer(&buf, in)
}
// DecompressFrame returns decompressed version of a frame.
// Note that it does memory allocations internally. To control those
// allocations consider using DecompressFrameBuffer().
func (h *Helper) DecompressFrame(in ws.Frame) (f ws.Frame, err error) {
var buf bytes.Buffer
return h.DecompressFrameBuffer(&buf, in)
}
// CompressFrameBuffer compresses a frame using given buffer.
// Returned frame's payload holds bytes returned by buf.Bytes().
func (h *Helper) CompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
if !f.Header.Fin {
return f, fmt.Errorf("wsflate: fragmented messages are not allowed")
}
if err := h.CompressTo(buf, f.Payload); err != nil {
return f, err
}
var err error
f.Payload = buf.Bytes()
f.Header.Length = int64(len(f.Payload))
f.Header, err = SetBit(f.Header)
if err != nil {
return f, err
}
return f, nil
}
// DecompressFrameBuffer decompresses a frame using given buffer.
// Returned frame's payload holds bytes returned by buf.Bytes().
func (h *Helper) DecompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
if !f.Header.Fin {
return f, fmt.Errorf(
"wsflate: fragmented messages are not supported by helper",
)
}
var (
compressed bool
err error
)
f.Header, compressed, err = UnsetBit(f.Header)
if err != nil {
return f, err
}
if !compressed {
return f, nil
}
if err := h.DecompressTo(buf, f.Payload); err != nil {
return f, err
}
f.Payload = buf.Bytes()
f.Header.Length = int64(len(f.Payload))
return f, nil
}
// Compress compresses given bytes.
// Note that it does memory allocations internally. To control those
// allocations consider using CompressTo().
func (h *Helper) Compress(p []byte) ([]byte, error) {
var buf bytes.Buffer
if err := h.CompressTo(&buf, p); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Decompress decompresses given bytes.
// Note that it does memory allocations internally. To control those
// allocations consider using DecompressTo().
func (h *Helper) Decompress(p []byte) ([]byte, error) {
var buf bytes.Buffer
if err := h.DecompressTo(&buf, p); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// CompressTo compresses bytes into given buffer.
func (h *Helper) CompressTo(w io.Writer, p []byte) (err error) {
c := NewWriter(w, h.Compressor)
if _, err = c.Write(p); err != nil {
return err
}
if err := c.Flush(); err != nil {
return err
}
if err := c.Close(); err != nil {
return err
}
return nil
}
// DecompressTo decompresses bytes into given buffer.
// Returned bytes are bytes returned by buf.Bytes().
func (h *Helper) DecompressTo(w io.Writer, p []byte) (err error) {
fr := NewReader(bytes.NewReader(p), h.Decompressor)
if _, err = io.Copy(w, fr); err != nil {
return err
}
if err := fr.Close(); err != nil {
return err
}
return nil
}
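// exampleFrameRoundTrip is a minimal illustrative sketch (not part of the
// upstream package): a round trip through the package-level shortcuts, which
// assume DefaultParameters were negotiated during the handshake. The payload
// is illustrative.
func exampleFrameRoundTrip() (ws.Frame, error) {
	f := ws.NewTextFrame([]byte("hello, deflate"))
	compressed, err := CompressFrame(f)
	if err != nil {
		return f, err
	}
	return DecompressFrame(compressed)
}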

197
vendor/github.com/gobwas/ws/wsflate/parameters.go generated vendored Normal file

@ -0,0 +1,197 @@
package wsflate
import (
"fmt"
"strconv"
"github.com/gobwas/httphead"
)
const (
ExtensionName = "permessage-deflate"
serverNoContextTakeover = "server_no_context_takeover"
clientNoContextTakeover = "client_no_context_takeover"
serverMaxWindowBits = "server_max_window_bits"
clientMaxWindowBits = "client_max_window_bits"
)
var (
ExtensionNameBytes = []byte(ExtensionName)
serverNoContextTakeoverBytes = []byte(serverNoContextTakeover)
clientNoContextTakeoverBytes = []byte(clientNoContextTakeover)
serverMaxWindowBitsBytes = []byte(serverMaxWindowBits)
clientMaxWindowBitsBytes = []byte(clientMaxWindowBits)
)
var windowBits [8][]byte
func init() {
for i := range windowBits {
windowBits[i] = []byte(strconv.Itoa(i + 8))
}
}
// Parameters contains compression extension options.
type Parameters struct {
ServerNoContextTakeover bool
ClientNoContextTakeover bool
ServerMaxWindowBits WindowBits
ClientMaxWindowBits WindowBits
}
// WindowBits specifies window size accordingly to RFC.
// Use its Bytes() method to obtain actual size of window in bytes.
type WindowBits byte
// Defined reports whether window bits were specified.
func (b WindowBits) Defined() bool {
return b > 0
}
// Bytes returns window size in number of bytes.
func (b WindowBits) Bytes() int {
return 1 << uint(b)
}
const (
MaxLZ77WindowSize = 32768 // 2^15
)
// Parse reads parameters from given HTTP header option accordingly to RFC.
//
// It returns non-nil error at least in these cases:
// - The negotiation offer contains an extension parameter not defined for
// use in an offer/response.
// - The negotiation offer/response contains an extension parameter with an
// invalid value.
// - The negotiation offer/response contains multiple extension parameters
// with the same name.
func (p *Parameters) Parse(opt httphead.Option) (err error) {
const (
clientMaxWindowBitsSeen = 1 << iota
serverMaxWindowBitsSeen
clientNoContextTakeoverSeen
serverNoContextTakeoverSeen
)
// Reset to not mix parsed data from previous Parse() calls.
*p = Parameters{}
var seen byte
opt.Parameters.ForEach(func(key, val []byte) (ok bool) {
switch string(key) {
case clientMaxWindowBits:
if len(val) == 0 {
p.ClientMaxWindowBits = 1
return true
}
if seen&clientMaxWindowBitsSeen != 0 {
err = paramError("duplicate", key, val)
return false
}
seen |= clientMaxWindowBitsSeen
if p.ClientMaxWindowBits, ok = bitsFromASCII(val); !ok {
err = paramError("invalid", key, val)
return false
}
case serverMaxWindowBits:
if len(val) == 0 {
err = paramError("invalid", key, val)
return false
}
if seen&serverMaxWindowBitsSeen != 0 {
err = paramError("duplicate", key, val)
return false
}
seen |= serverMaxWindowBitsSeen
if p.ServerMaxWindowBits, ok = bitsFromASCII(val); !ok {
err = paramError("invalid", key, val)
return false
}
case clientNoContextTakeover:
if len(val) > 0 {
err = paramError("invalid", key, val)
return false
}
if seen&clientNoContextTakeoverSeen != 0 {
err = paramError("duplicate", key, val)
return false
}
seen |= clientNoContextTakeoverSeen
p.ClientNoContextTakeover = true
case serverNoContextTakeover:
if len(val) > 0 {
err = paramError("invalid", key, val)
return false
}
if seen&serverNoContextTakeoverSeen != 0 {
err = paramError("duplicate", key, val)
return false
}
seen |= serverNoContextTakeoverSeen
p.ServerNoContextTakeover = true
default:
err = paramError("unexpected", key, val)
return false
}
return true
})
return err
}
// Option encodes parameters into HTTP header option.
func (p Parameters) Option() httphead.Option {
opt := httphead.Option{
Name: ExtensionNameBytes,
}
setBool(&opt, serverNoContextTakeoverBytes, p.ServerNoContextTakeover)
setBool(&opt, clientNoContextTakeoverBytes, p.ClientNoContextTakeover)
setBits(&opt, serverMaxWindowBitsBytes, p.ServerMaxWindowBits)
setBits(&opt, clientMaxWindowBitsBytes, p.ClientMaxWindowBits)
return opt
}
func isValidBits(x int) bool {
return 8 <= x && x <= 15
}
func bitsFromASCII(p []byte) (WindowBits, bool) {
n, ok := httphead.IntFromASCII(p)
if !ok || !isValidBits(n) {
return 0, false
}
return WindowBits(n), true
}
func setBits(opt *httphead.Option, name []byte, bits WindowBits) {
if bits == 0 {
return
}
if bits == 1 {
opt.Parameters.Set(name, nil)
return
}
if !isValidBits(int(bits)) {
panic(fmt.Sprintf("wsflate: invalid bits value: %d", bits))
}
opt.Parameters.Set(name, windowBits[bits-8])
}
func setBool(opt *httphead.Option, name []byte, flag bool) {
if flag {
opt.Parameters.Set(name, nil)
}
}
func paramError(reason string, key, val []byte) error {
return fmt.Errorf(
"wsflate: %s extension parameter %q: %q",
reason, key, val,
)
}
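// exampleParameters is a minimal illustrative sketch (not part of the
// upstream package): encoding negotiation parameters into an HTTP header
// option and parsing them back. The values are illustrative.
func exampleParameters() (Parameters, error) {
	opt := Parameters{
		ServerNoContextTakeover: true,
		ClientMaxWindowBits:     10,
	}.Option()

	var parsed Parameters
	err := parsed.Parse(opt)
	return parsed, err
}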

84
vendor/github.com/gobwas/ws/wsflate/reader.go generated vendored Normal file

@ -0,0 +1,84 @@
package wsflate
import (
"io"
)
// Decompressor is an interface holding deflate decompression implementation.
type Decompressor interface {
io.Reader
}
// ReadResetter is an optional interface that Decompressor can implement.
type ReadResetter interface {
Reset(io.Reader)
}
// Reader implements decompression from an io.Reader object using Decompressor.
// Essentially Reader is a thin wrapper around Decompressor interface to meet
// PMCE specs.
//
// If any error occurs while reading from Reader, all subsequent calls to
// Read() or Close() will return the error.
//
// Reader might be reused for different io.Reader objects after its Reset()
// method has been called.
type Reader struct {
src io.Reader
ctor func(io.Reader) Decompressor
d Decompressor
sr suffixedReader
err error
}
// NewReader returns a new Reader.
func NewReader(r io.Reader, ctor func(io.Reader) Decompressor) *Reader {
ret := &Reader{
src: r,
ctor: ctor,
sr: suffixedReader{
suffix: compressionReadTail,
},
}
ret.Reset(r)
return ret
}
// Reset resets Reader to decompress data from src.
func (r *Reader) Reset(src io.Reader) {
r.err = nil
r.src = src
r.sr.reset(src)
if x, ok := r.d.(ReadResetter); ok {
x.Reset(r.sr.iface())
} else {
r.d = r.ctor(r.sr.iface())
}
}
// Read implements io.Reader.
func (r *Reader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
return r.d.Read(p)
}
// Close closes Reader and a Decompressor instance used under the hood (if it
// implements io.Closer interface).
func (r *Reader) Close() error {
if r.err != nil {
return r.err
}
if c, ok := r.d.(io.Closer); ok {
r.err = c.Close()
}
return r.err
}
// Err returns an error happened during any operation.
func (r *Reader) Err() error {
return r.err
}

129
vendor/github.com/gobwas/ws/wsflate/writer.go generated vendored Normal file

@ -0,0 +1,129 @@
package wsflate
import (
"fmt"
"io"
)
var (
compressionTail = [4]byte{
0, 0, 0xff, 0xff,
}
compressionReadTail = [9]byte{
0, 0, 0xff, 0xff,
1,
0, 0, 0xff, 0xff,
}
)
// Compressor is an interface holding deflate compression implementation.
type Compressor interface {
io.Writer
Flush() error
}
// WriteResetter is an optional interface that Compressor can implement.
type WriteResetter interface {
Reset(io.Writer)
}
// Writer implements compression for an io.Writer object using Compressor.
// Essentially Writer is a thin wrapper around Compressor interface to meet
// PMCE specs.
//
// After all data has been written, the client should call the Flush() method.
// If any error occurs after writing to or flushing a Writer, all subsequent
// calls to Write(), Flush() or Close() will return the error.
//
// Writer might be reused for different io.Writer objects after its Reset()
// method has been called.
type Writer struct {
// NOTE: Writer uses compressor constructor function instead of field to
// reach these goals:
// 1. To shrink Compressor interface and make it easier to be implemented.
// 2. If used as a field (and argument to the NewWriter()), Compressor object
// will probably be initialized twice - first time to pass into Writer, and
// second time during Writer initialization (which does Reset() internally).
// 3. To get rid of wrappers if Reset() would be a part of Compressor.
// E.g. non conformant implementations would have to provide it somehow,
// probably making a wrapper with the same constructor function.
// 4. To make Reader and Writer API the same. That is, there is no Reset()
// method for flate.Reader already, so we need to provide it as a wrapper
// (see point #3), or drop the Reader.Reset() method.
dest io.Writer
ctor func(io.Writer) Compressor
c Compressor
cbuf cbuf
err error
}
// NewWriter returns a new Writer.
func NewWriter(w io.Writer, ctor func(io.Writer) Compressor) *Writer {
// NOTE: NewWriter() is chosen over a structure with exported fields here due
// to its Reset() method, which, in the case of a structure, would change an
// exported field.
ret := &Writer{
dest: w,
ctor: ctor,
}
ret.Reset(w)
return ret
}
// Reset resets Writer to compress data into dest.
// Any not flushed data will be lost.
func (w *Writer) Reset(dest io.Writer) {
w.err = nil
w.cbuf.reset(dest)
if x, ok := w.c.(WriteResetter); ok {
x.Reset(&w.cbuf)
} else {
w.c = w.ctor(&w.cbuf)
}
}
// Write implements io.Writer.
func (w *Writer) Write(p []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
n, w.err = w.c.Write(p)
return n, w.err
}
// Flush writes any pending data into w.Dest.
func (w *Writer) Flush() error {
if w.err != nil {
return w.err
}
w.err = w.c.Flush()
w.checkTail()
return w.err
}
// Close closes Writer and a Compressor instance used under the hood (if it
// implements io.Closer interface).
func (w *Writer) Close() error {
if w.err != nil {
return w.err
}
if c, ok := w.c.(io.Closer); ok {
w.err = c.Close()
}
w.checkTail()
return w.err
}
// Err returns an error happened during any operation.
func (w *Writer) Err() error {
return w.err
}
func (w *Writer) checkTail() {
if w.err == nil && w.cbuf.buf != compressionTail {
w.err = fmt.Errorf(
"wsflate: bad compressor: unexpected stream tail: %#x vs %#x",
w.cbuf.buf, compressionTail,
)
}
}
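// exampleWriterSequence is a minimal illustrative sketch (not part of the
// upstream package): the documented call sequence for Writer. The ctor
// argument would typically wrap compress/flate, as DefaultHelper does; all
// names here are illustrative.
func exampleWriterSequence(dst io.Writer, ctor func(io.Writer) Compressor, msg []byte) error {
	w := NewWriter(dst, ctor)
	if _, err := w.Write(msg); err != nil {
		return err
	}
	// Flush must be called once the whole message has been written.
	if err := w.Flush(); err != nil {
		return err
	}
	// Close also closes the underlying Compressor if it implements io.Closer.
	return w.Close()
}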

72
vendor/github.com/gobwas/ws/wsutil/cipher.go generated vendored Normal file

@ -0,0 +1,72 @@
package wsutil
import (
"io"
"github.com/gobwas/pool/pbytes"
"github.com/gobwas/ws"
)
// CipherReader implements io.Reader that applies xor-cipher to the bytes read
// from source.
// It could help to unmask WebSocket frame payload on the fly.
type CipherReader struct {
r io.Reader
mask [4]byte
pos int
}
// NewCipherReader creates xor-cipher reader from r with given mask.
func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader {
return &CipherReader{r, mask, 0}
}
// Reset resets CipherReader to read from r with given mask.
func (c *CipherReader) Reset(r io.Reader, mask [4]byte) {
c.r = r
c.mask = mask
c.pos = 0
}
// Read implements io.Reader interface. It applies mask given during
// initialization to every read byte.
func (c *CipherReader) Read(p []byte) (n int, err error) {
n, err = c.r.Read(p)
ws.Cipher(p[:n], c.mask, c.pos)
c.pos += n
return n, err
}
// CipherWriter implements io.Writer that applies xor-cipher to the bytes
// written to the destination writer. It does not modify the original bytes.
type CipherWriter struct {
w io.Writer
mask [4]byte
pos int
}
// NewCipherWriter creates xor-cipher writer to w with given mask.
func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter {
return &CipherWriter{w, mask, 0}
}
// Reset resets CipherWriter to write to w with the given mask.
func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) {
c.w = w
c.mask = mask
c.pos = 0
}
// Write implements io.Writer interface. It applies the mask given during
// initialization to every written byte. It does not modify the original slice.
func (c *CipherWriter) Write(p []byte) (n int, err error) {
cp := pbytes.GetLen(len(p))
defer pbytes.Put(cp)
copy(cp, p)
ws.Cipher(cp, c.mask, c.pos)
n, err = c.w.Write(cp)
c.pos += n
return n, err
}
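// exampleCipherRoundTrip is a minimal illustrative sketch (not part of the
// upstream package): because the cipher is a plain XOR, masking a payload
// through CipherWriter and reading it back through CipherReader with the same
// mask restores the original bytes (e.g. pass a *bytes.Buffer as rw).
func exampleCipherRoundTrip(rw io.ReadWriter, mask [4]byte, payload []byte) ([]byte, error) {
	if _, err := NewCipherWriter(rw, mask).Write(payload); err != nil {
		return nil, err
	}
	out := make([]byte, len(payload))
	_, err := io.ReadFull(NewCipherReader(rw, mask), out)
	return out, err
}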

147
vendor/github.com/gobwas/ws/wsutil/dialer.go generated vendored Normal file

@ -0,0 +1,147 @@
package wsutil
import (
"bufio"
"bytes"
"context"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/gobwas/ws"
)
// DebugDialer is a wrapper around ws.Dialer. It tracks the i/o of the
// WebSocket handshake. That is, it gives the ability to receive copies of the
// HTTP request and response bytes made inside Dialer.Dial().
//
// Note that it must not be used in production applications that require
// Dial() to be efficient.
type DebugDialer struct {
// Dialer contains WebSocket connection establishment options.
Dialer ws.Dialer
// OnRequest and OnResponse are the callbacks that will be called with the
// HTTP request and response respectively.
OnRequest, OnResponse func([]byte)
}
// Dial connects to the url host and upgrades the connection to WebSocket. It
// does so by calling d.Dialer.Dial().
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
// Need to copy Dialer to prevent original object mutation.
dialer := d.Dialer
var (
reqBuf bytes.Buffer
resBuf bytes.Buffer
resContentLength int64
)
userWrap := dialer.WrapConn
dialer.WrapConn = func(c net.Conn) net.Conn {
if userWrap != nil {
c = userWrap(c)
}
// Save the pointer to the raw connection.
conn = c
var (
r io.Reader = conn
w io.Writer = conn
)
if d.OnResponse != nil {
r = &prefetchResponseReader{
source: conn,
buffer: &resBuf,
contentLength: &resContentLength,
}
}
if d.OnRequest != nil {
w = io.MultiWriter(conn, &reqBuf)
}
return rwConn{conn, r, w}
}
_, br, hs, err = dialer.Dial(ctx, urlstr)
if onRequest := d.OnRequest; onRequest != nil {
onRequest(reqBuf.Bytes())
}
if onResponse := d.OnResponse; onResponse != nil {
// We must split the response inside the buffered bytes from other bytes
// received from the server.
p := resBuf.Bytes()
n := bytes.Index(p, headEnd)
h := n + len(headEnd) // Head end index.
n = h + int(resContentLength) // Body end index.
onResponse(p[:n])
if br != nil {
// If br is non-nil, then it means two things. First, the handshake is
// OK and the server has sent additional bytes, probably immediately
// sent frames (or a weird but possible response body).
// Second, the bad one, is that br's buffer source is now the rwConn
// instance from the WrapConn call above. That is incorrect, so we must
// fix it.
var r io.Reader = conn
if len(p) > h {
// Buffer contains more than just HTTP headers bytes.
r = io.MultiReader(
bytes.NewReader(p[h:]),
conn,
)
}
br.Reset(r)
// Must make br.Buffered() to be non-zero.
br.Peek(len(p[h:]))
}
}
return conn, br, hs, err
}
type rwConn struct {
net.Conn
r io.Reader
w io.Writer
}
func (rwc rwConn) Read(p []byte) (int, error) {
return rwc.r.Read(p)
}
func (rwc rwConn) Write(p []byte) (int, error) {
return rwc.w.Write(p)
}
var headEnd = []byte("\r\n\r\n")
type prefetchResponseReader struct {
source io.Reader // Original connection source.
reader io.Reader // Wrapped reader used to read from by clients.
buffer *bytes.Buffer
contentLength *int64
}
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
if r.reader == nil {
resp, err := http.ReadResponse(bufio.NewReader(
io.TeeReader(r.source, r.buffer),
), nil)
if err == nil {
*r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
}
bts := r.buffer.Bytes()
r.reader = io.MultiReader(
bytes.NewReader(bts),
r.source,
)
}
return r.reader.Read(p)
}
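// exampleDebugDial is a minimal illustrative sketch (not part of the upstream
// package): capturing the raw handshake bytes of a dial. The use of
// ws.DefaultDialer and the callback bodies are illustrative choices.
func exampleDebugDial(ctx context.Context, url string) (req, res []byte, err error) {
	d := DebugDialer{
		Dialer:     ws.DefaultDialer,
		OnRequest:  func(p []byte) { req = append(req, p...) },
		OnResponse: func(p []byte) { res = append(res, p...) },
	}
	conn, _, _, err := d.Dial(ctx, url)
	if err != nil {
		return nil, nil, err
	}
	defer conn.Close()
	return req, res, nil
}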

31
vendor/github.com/gobwas/ws/wsutil/extenstion.go generated vendored Normal file

@ -0,0 +1,31 @@
package wsutil
import "github.com/gobwas/ws"
// RecvExtension is an interface for clearing fragment header RSV bits.
type RecvExtension interface {
UnsetBits(ws.Header) (ws.Header, error)
}
// RecvExtensionFunc is an adapter to allow the use of ordinary functions as
// RecvExtension.
type RecvExtensionFunc func(ws.Header) (ws.Header, error)
// UnsetBits implements RecvExtension.
func (fn RecvExtensionFunc) UnsetBits(h ws.Header) (ws.Header, error) {
return fn(h)
}
// SendExtension is an interface for setting fragment header RSV bits.
type SendExtension interface {
SetBits(ws.Header) (ws.Header, error)
}
// SendExtensionFunc is an adapter to allow the use of ordinary functions as
// SendExtension.
type SendExtensionFunc func(ws.Header) (ws.Header, error)
// SetBits implements SendExtension.
func (fn SendExtensionFunc) SetBits(h ws.Header) (ws.Header, error) {
return fn(h)
}
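// exampleClearRsv1 is a minimal illustrative sketch (not part of the upstream
// package): adapting an ordinary function into a RecvExtension with
// RecvExtensionFunc; this one simply clears RSV1 on every header.
var exampleClearRsv1 RecvExtension = RecvExtensionFunc(func(h ws.Header) (ws.Header, error) {
	_, r2, r3 := ws.RsvBits(h.Rsv)
	h.Rsv = ws.Rsv(false, r2, r3)
	return h, nil
})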

219
vendor/github.com/gobwas/ws/wsutil/handler.go generated vendored Normal file

@ -0,0 +1,219 @@
package wsutil
import (
"errors"
"io"
"io/ioutil"
"strconv"
"github.com/gobwas/pool/pbytes"
"github.com/gobwas/ws"
)
// ClosedError is returned when the peer has closed the connection with an
// appropriate code and a textual reason.
type ClosedError struct {
Code ws.StatusCode
Reason string
}
// Error implements error interface.
func (err ClosedError) Error() string {
return "ws closed: " + strconv.FormatUint(uint64(err.Code), 10) + " " + err.Reason
}
// ControlHandler contains the logic of handling control frames.
//
// The intended way to use it is to read the next frame header from the
// connection, optionally check its validity via ws.CheckHeader(), and if it
// is not a ws.OpText, ws.OpBinary (or ws.OpContinuation) frame, pass it to
// the Handle() method.
//
// That is, the passed header should be checked to get rid of unexpected
// errors.
//
// The Handle() method will read out the whole control frame payload (if any)
// and write the necessary bytes as an RFC-compatible response.
type ControlHandler struct {
Src io.Reader
Dst io.Writer
State ws.State
// DisableSrcCiphering disables unmasking payload data read from Src.
// It is useful when wsutil.Reader is used or when the frame payload has
// already been pulled and unmasked from the connection (and introduced by a
// bytes.Reader, for example).
DisableSrcCiphering bool
}
// ErrNotControlFrame is returned by ControlHandler to indicate that given
// header could not be handled.
var ErrNotControlFrame = errors.New("not a control frame")
// Handle handles control frames according to c.State and writes responses
// to c.Dst when needed.
//
// It returns ErrNotControlFrame when given header is not of ws.OpClose,
// ws.OpPing or ws.OpPong operation code.
func (c ControlHandler) Handle(h ws.Header) error {
switch h.OpCode {
case ws.OpPing:
return c.HandlePing(h)
case ws.OpPong:
return c.HandlePong(h)
case ws.OpClose:
return c.HandleClose(h)
}
return ErrNotControlFrame
}
// HandlePing handles ping frame and writes specification compatible response
// to the c.Dst.
func (c ControlHandler) HandlePing(h ws.Header) error {
if h.Length == 0 {
// The most common case when ping is empty.
// Note that when sending masked frame the mask for empty payload is
// just four zero bytes.
return ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpPong,
Masked: c.State.ClientSide(),
})
}
// Otherwise, reply with a Pong frame carrying the copied payload.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)
// Deal with ciphering i/o:
// Masking key is used to mask the "Payload data" defined in the same
// section as frame-payload-data, which includes "Extension data" and
// "Application data".
//
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// NOTE: We prefer ControlWriter with preallocated buffer to
// ws.WriteHeader because it performs one syscall instead of two.
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p)
r := c.Src
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}
_, err := io.Copy(w, r)
if err == nil {
err = w.Flush()
}
return err
}
// HandlePong handles pong frame by discarding it.
func (c ControlHandler) HandlePong(h ws.Header) error {
if h.Length == 0 {
return nil
}
buf := pbytes.GetLen(int(h.Length))
defer pbytes.Put(buf)
// Discard pong message according to the RFC6455:
// A Pong frame MAY be sent unsolicited. This serves as a
// unidirectional heartbeat. A response to an unsolicited Pong frame
// is not expected.
_, err := io.CopyBuffer(ioutil.Discard, c.Src, buf)
return err
}
// HandleClose handles close frame, makes protocol validity checks and writes
// specification compatible response to the c.Dst.
func (c ControlHandler) HandleClose(h ws.Header) error {
if h.Length == 0 {
err := ws.WriteHeader(c.Dst, ws.Header{
Fin: true,
OpCode: ws.OpClose,
Masked: c.State.ClientSide(),
})
if err != nil {
return err
}
// Per the RFC, we should interpret this as no status code
// received:
// If this Close control frame contains no status code, _The WebSocket
// Connection Close Code_ is considered to be 1005.
//
// See https://tools.ietf.org/html/rfc6455#section-7.1.5
return ClosedError{
Code: ws.StatusNoStatusRcvd,
}
}
// Prepare bytes both for reading reason and sending response.
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
Length: h.Length,
Masked: c.State.ClientSide(),
}))
defer pbytes.Put(p)
// Get the subslice to read the frame payload out.
subp := p[:h.Length]
r := c.Src
if c.State.ServerSide() && !c.DisableSrcCiphering {
r = NewCipherReader(r, h.Mask)
}
if _, err := io.ReadFull(r, subp); err != nil {
return err
}
code, reason := ws.ParseCloseFrameData(subp)
if err := ws.CheckCloseFrameData(code, reason); err != nil {
// Here we cannot use the prepared bytes because there is no
// guarantee that they can fit our protocol error close code and a
// reason.
c.closeWithProtocolError(err)
return err
}
// Deal with ciphering i/o:
// Masking key is used to mask the "Payload data" defined in the same
// section as frame-payload-data, which includes "Extension data" and
// "Application data".
//
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// NOTE: We prefer ControlWriter with preallocated buffer to
// ws.WriteHeader because it performs one syscall instead of two.
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpClose, p)
// RFC6455#5.5.1:
// If an endpoint receives a Close frame and did not previously
// send a Close frame, the endpoint MUST send a Close frame in
// response. (When sending a Close frame in response, the endpoint
// typically echoes the status code it received.)
_, err := w.Write(p[:2])
if err != nil {
return err
}
if err := w.Flush(); err != nil {
return err
}
return ClosedError{
Code: code,
Reason: reason,
}
}
func (c ControlHandler) closeWithProtocolError(reason error) error {
f := ws.NewCloseFrame(ws.NewCloseFrameBody(
ws.StatusProtocolError, reason.Error(),
))
if c.State.ClientSide() {
ws.MaskFrameInPlace(f)
}
return ws.WriteFrame(c.Dst, f)
}
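// exampleHandleControl is a minimal illustrative sketch (not part of the
// upstream package): a single step of a server-side read loop that lets
// ControlHandler answer control frames. Data frames are merely discarded
// here; a real handler would unmask and process their payload.
func exampleHandleControl(conn io.ReadWriter) error {
	h, err := ws.ReadHeader(conn)
	if err != nil {
		return err
	}
	if h.OpCode.IsControl() {
		return ControlHandler{
			Src:   conn,
			Dst:   conn,
			State: ws.StateServerSide,
		}.Handle(h)
	}
	_, err = io.CopyN(ioutil.Discard, conn, h.Length)
	return err
}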

Some files were not shown because too many files have changed in this diff.