well-goknown/vendor/github.com/nbd-wtf/go-nostr/connection.go

182 lines
4.2 KiB
Go

package nostr
import (
"bytes"
"compress/flate"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
)
// Connection is a client-side websocket connection with optional per-message
// compression. Values are created by NewConnection.
type Connection struct {
	// conn is the underlying network connection obtained from the dialer.
	conn net.Conn
	// enableCompression records whether the handshake response listed the
	// wsflate extension.
	enableCompression bool
	// controlHandler processes control frames; it is also installed as the
	// reader's OnIntermediate callback.
	controlHandler wsutil.FrameHandlerFunc

	// Read side. flateReader is nil unless compression was enabled.
	reader      *wsutil.Reader
	flateReader *wsflate.Reader
	msgStateR   *wsflate.MessageState

	// Write side. flateWriter is nil unless compression was enabled.
	writer      *wsutil.Writer
	flateWriter *wsflate.Writer
	msgStateW   *wsflate.MessageState
}
// NewConnection dials the websocket endpoint at url and returns a Connection
// wrapping the resulting socket.
//
// requestHeader is sent with the opening handshake and tlsConfig is handed to
// the dialer unchanged. The wsflate extension (with its default parameters) is
// always offered; compression is enabled on the returned Connection only when
// the handshake response lists that extension.
//
// A dial failure is returned wrapped; no Connection is returned in that case.
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
	dialer := ws.Dialer{
		Header: ws.HandshakeHeaderHTTP(requestHeader),
		Extensions: []httphead.Option{
			wsflate.DefaultParameters.Option(),
		},
		TLSConfig: tlsConfig,
	}
	conn, br, hs, err := dialer.Dial(ctx, url)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	// Dial hands back a non-nil buffered reader when the server sent data
	// right behind the handshake response. Those bytes have already been
	// pulled off conn, so frames must be read through that reader; reading
	// from conn directly would silently drop them.
	var source io.Reader = conn
	if br != nil {
		source = br
	}

	enableCompression := false
	state := ws.StateClientSide
	for _, extension := range hs.Extensions {
		if string(extension.Name) == wsflate.ExtensionName {
			enableCompression = true
			state |= ws.StateExtended
			break
		}
	}

	// reader
	var flateReader *wsflate.Reader
	var msgStateR wsflate.MessageState
	if enableCompression {
		msgStateR.SetCompressed(true)
		// The source is attached per message via Reset in ReadMessage.
		flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
			return flate.NewReader(r)
		})
	}
	// Control-frame replies are written straight to the socket.
	controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
	reader := &wsutil.Reader{
		Source:         source,
		State:          state,
		OnIntermediate: controlHandler,
		CheckUTF8:      false,
		Extensions: []wsutil.RecvExtension{
			&msgStateR,
		},
	}

	// writer
	var flateWriter *wsflate.Writer
	var msgStateW wsflate.MessageState
	if enableCompression {
		msgStateW.SetCompressed(true)
		// The destination is attached per message via Reset in WriteMessage.
		flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
			// 4 is a valid compress/flate level, so NewWriter is not
			// expected to fail here; the error is only logged.
			fw, err := flate.NewWriter(w, 4)
			if err != nil {
				InfoLogger.Printf("Failed to create flate writer: %v", err)
			}
			return fw
		})
	}
	writer := wsutil.NewWriter(conn, state, ws.OpText)
	writer.SetExtensions(&msgStateW)

	return &Connection{
		conn:              conn,
		enableCompression: enableCompression,
		controlHandler:    controlHandler,
		flateReader:       flateReader,
		reader:            reader,
		msgStateR:         &msgStateR,
		flateWriter:       flateWriter,
		writer:            writer,
		msgStateW:         &msgStateW,
	}, nil
}
// WriteMessage sends data as a single message on the connection.
//
// ctx is polled once, before anything is written; if it is already done the
// call returns an error without touching the socket. When compression is
// active the payload is routed through the flate writer, which is closed
// before the frame writer is flushed. Every failure is returned wrapped with
// the step that failed.
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
	select {
	case <-ctx.Done():
		return errors.New("context canceled")
	default:
	}

	// Pick the destination once: the frame writer itself, or the flate
	// writer layered on top of it.
	compressed := c.msgStateW.IsCompressed() && c.enableCompression
	var dst io.Writer = c.writer
	if compressed {
		c.flateWriter.Reset(c.writer)
		dst = c.flateWriter
	}

	if _, err := io.Copy(dst, bytes.NewReader(data)); err != nil {
		return fmt.Errorf("failed to write message: %w", err)
	}

	if compressed {
		if err := c.flateWriter.Close(); err != nil {
			return fmt.Errorf("failed to close flate writer: %w", err)
		}
	}

	if err := c.writer.Flush(); err != nil {
		return fmt.Errorf("failed to flush writer: %w", err)
	}
	return nil
}
// ReadMessage reads the next text or binary message and copies its payload
// into buf.
//
// Frames are consumed one at a time. Control frames are passed to the
// connection's control handler and then discarded; any other frame that is
// not a text or binary frame is discarded as well. ctx is polled before each
// frame is read; nothing in this method ties ctx to the read itself. If
// advancing to the next frame fails, the underlying connection is closed
// before the error is returned. When compression is active the payload is
// read through the flate reader.
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
frames:
	for {
		select {
		case <-ctx.Done():
			return errors.New("context canceled")
		default:
		}

		hdr, err := c.reader.NextFrame()
		if err != nil {
			c.conn.Close()
			return fmt.Errorf("failed to advance frame: %w", err)
		}

		switch {
		case hdr.OpCode.IsControl():
			if err := c.controlHandler(hdr, c.reader); err != nil {
				return fmt.Errorf("failed to handle control frame: %w", err)
			}
		case hdr.OpCode == ws.OpBinary, hdr.OpCode == ws.OpText:
			// A data frame: stop skipping and read its payload below.
			break frames
		}

		if err := c.reader.Discard(); err != nil {
			return fmt.Errorf("failed to discard: %w", err)
		}
	}

	// Read either straight from the frame reader or through the flate
	// reader layered on top of it.
	var src io.Reader = c.reader
	if c.msgStateR.IsCompressed() && c.enableCompression {
		c.flateReader.Reset(c.reader)
		src = c.flateReader
	}

	if _, err := io.Copy(buf, src); err != nil {
		return fmt.Errorf("failed to read message: %w", err)
	}
	return nil
}
// Close closes the underlying network connection and returns the error, if
// any, reported by that close.
func (c *Connection) Close() error {
	closeErr := c.conn.Close()
	return closeErr
}