package fasthttp

import (
	"crypto/tls"
	"net"
	"sync"
)

// perIPConnCounter tracks the number of currently registered connections
// per IPv4 address (encoded as uint32). It also owns the pools used to
// recycle the perIPConn / perIPTLSConn wrappers handed out by
// acquirePerIPConn.
type perIPConnCounter struct {
	perIPConnPool    sync.Pool
	perIPTLSConnPool sync.Pool

	// m maps an IP to its current connection count. IPs whose count
	// drops to zero are removed, so a missing key means "0 connections".
	m    map[uint32]int
	lock sync.Mutex
}

// Register increments the connection count for ip and returns the new
// count. The map is lazily allocated on first use.
func (cc *perIPConnCounter) Register(ip uint32) int {
	cc.lock.Lock()
	if cc.m == nil {
		cc.m = make(map[uint32]int)
	}
	n := cc.m[ip] + 1
	cc.m[ip] = n
	cc.lock.Unlock()
	return n
}

// Unregister decrements the connection count for ip, never going below
// zero. It panics if Register has never been called on cc.
func (cc *perIPConnCounter) Unregister(ip uint32) {
	cc.lock.Lock()
	defer cc.lock.Unlock()
	if cc.m == nil {
		// developer safeguard
		panic("BUG: perIPConnCounter.Register() wasn't called")
	}
	n := cc.m[ip] - 1
	if n <= 0 {
		// Delete rather than store 0: otherwise the map retains an entry
		// for every distinct IP ever seen and grows without bound. A
		// missing key reads back as 0, so Register is unaffected.
		delete(cc.m, ip)
		return
	}
	cc.m[ip] = n
}

// perIPConn wraps a plain net.Conn so that closing it unregisters ip from
// perIPConnCounter.
type perIPConn struct {
	net.Conn

	perIPConnCounter *perIPConnCounter
	ip               uint32
}

// perIPTLSConn is the *tls.Conn counterpart of perIPConn; it embeds
// *tls.Conn so the TLS-specific methods remain reachable on the wrapper.
type perIPTLSConn struct {
	*tls.Conn

	perIPConnCounter *perIPConnCounter
	ip               uint32
}

// acquirePerIPConn wraps conn so that its Close unregisters ip from
// counter. *tls.Conn values get a perIPTLSConn wrapper; everything else
// gets a perIPConn. Wrappers are taken from counter's pools when available.
//
// The caller is expected to have called counter.Register(ip) already.
func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn {
	if tlsConn, ok := conn.(*tls.Conn); ok {
		v := counter.perIPTLSConnPool.Get()
		if v == nil {
			return &perIPTLSConn{
				perIPConnCounter: counter,
				Conn:             tlsConn,
				ip:               ip,
			}
		}
		// Pooled wrappers were Put into this counter's own pool by Close,
		// so their perIPConnCounter field already equals counter.
		c := v.(*perIPTLSConn)
		c.Conn = tlsConn
		c.ip = ip
		return c
	}

	v := counter.perIPConnPool.Get()
	if v == nil {
		return &perIPConn{
			perIPConnCounter: counter,
			Conn:             conn,
			ip:               ip,
		}
	}
	c := v.(*perIPConn)
	c.Conn = conn
	c.ip = ip
	return c
}

// Close closes the underlying connection, unregisters the IP, and returns
// the wrapper to the pool. It returns the error from the underlying Close.
//
// Close must be called at most once: afterwards the wrapper may already be
// reused for a different connection.
func (c *perIPConn) Close() error {
	err := c.Conn.Close()
	c.perIPConnCounter.Unregister(c.ip)
	c.Conn = nil
	c.perIPConnCounter.perIPConnPool.Put(c)
	return err
}

// Close closes the underlying TLS connection, unregisters the IP, and
// returns the wrapper to the pool. It returns the error from the underlying
// Close.
//
// Close must be called at most once: afterwards the wrapper may already be
// reused for a different connection.
func (c *perIPTLSConn) Close() error {
	err := c.Conn.Close()
	c.perIPConnCounter.Unregister(c.ip)
	c.Conn = nil
	c.perIPConnCounter.perIPTLSConnPool.Put(c)
	return err
}

// getUint32IP returns the remote IPv4 address of c encoded as uint32, or 0
// when the remote address is not an IPv4 TCP address.
func getUint32IP(c net.Conn) uint32 {
	return ip2uint32(getConnIP4(c))
}

// getConnIP4 returns the 4-byte form of c's remote TCP IP. It returns nil
// for IPv6 remotes (To4 fails) and net.IPv4zero for non-TCP addresses.
func getConnIP4(c net.Conn) net.IP {
	addr := c.RemoteAddr()
	ipAddr, ok := addr.(*net.TCPAddr)
	if !ok {
		return net.IPv4zero
	}
	return ipAddr.IP.To4()
}

// ip2uint32 packs a 4-byte IP into a big-endian uint32. Any IP whose slice
// length is not exactly 4 (including the 16-byte form) yields 0.
func ip2uint32(ip net.IP) uint32 {
	if len(ip) != 4 {
		return 0
	}
	return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}

// uint322ip is the inverse of ip2uint32: it unpacks a big-endian uint32
// into a freshly allocated 4-byte net.IP.
func uint322ip(ip uint32) net.IP {
	return net.IP{byte(ip >> 24), byte(ip >> 16), byte(ip >> 8), byte(ip)}
}