1 // Package reverseproxy implements a basic HTTP/TLS connection forwarder based either the passed Host header or SNI extension 2 package reverseproxy // import "vimagination.zapto.org/reverseproxy" 3 4 import ( 5 "errors" 6 "io" 7 "net" 8 "net/http" 9 "sync" 10 ) 11 12 var ( 13 lMu sync.RWMutex 14 listeners = make(map[uint16]*listener) 15 ) 16 17 type listener struct { 18 *net.TCPListener 19 20 mu sync.RWMutex 21 ports map[*Port]struct{} 22 } 23 24 var ( 25 httpPool = sync.Pool{ 26 New: func() interface{} { 27 b := make([]byte, http.DefaultMaxHeaderBytes) 28 return &b 29 }, 30 } 31 tlsPool = sync.Pool{ 32 New: func() interface{} { 33 b := make([]byte, maxTLSRead) 34 return &b 35 }, 36 } 37 ) 38 39 func (l *listener) listen() { 40 for { 41 c, err := l.AcceptTCP() 42 if err != nil { 43 if errors.Is(err, net.ErrClosed) { 44 return 45 } 46 l.Close() 47 l.mu.Lock() 48 for p := range l.ports { 49 p.closed = true 50 delete(l.ports, p) 51 } 52 l.mu.Unlock() 53 return 54 } 55 go l.transfer(c) 56 } 57 } 58 59 func (l *listener) transfer(c *net.TCPConn) { 60 var tlsByte [1]byte 61 if n, err := io.ReadFull(c, tlsByte[:]); n == 1 && err == nil { 62 var ( 63 name string 64 pool *sync.Pool 65 readServerName func(io.Reader, []byte) (string, []byte, error) 66 ) 67 if tlsByte[0] == 22 { 68 pool = &tlsPool 69 readServerName = readTLSServerName 70 } else { 71 pool = &httpPool 72 readServerName = readHTTPServerName 73 } 74 b := pool.Get().(*[]byte) 75 buf := *b 76 buf[0] = tlsByte[0] 77 name, buf, err = readServerName(c, buf) 78 if err == nil { 79 if host, _, err := net.SplitHostPort(name); err == nil { 80 name = host 81 } 82 var port *Port 83 l.mu.RLock() 84 for p := range l.ports { 85 if p.MatchService(name) { 86 port = p 87 break 88 } 89 } 90 l.mu.RUnlock() 91 if port != nil { 92 port.Transfer(buf, c) 93 } 94 } else { 95 c.Close() 96 } 97 for n := range buf { 98 buf[n] = 0 99 } 100 pool.Put(b) 101 } else { 102 c.Close() 103 } 104 } 105 106 type service interface { 107 MatchServiceName 108 Transfer([]byte, *net.TCPConn) 109 Active() bool 110 } 111 112 // Port represents a service waiting on a port 113 type Port struct { 114 service 115 port uint16 116 closed bool 117 } 118 119 func addPort(port uint16, service service) (*Port, error) { 120 if port == 0 { 121 return nil, ErrInvalidPort 122 } 123 lMu.Lock() 124 l, ok := listeners[port] 125 if !ok { 126 nl, err := net.ListenTCP("tcp", &net.TCPAddr{Port: int(port)}) 127 if err != nil { 128 return nil, err 129 } 130 l = &listener{ 131 TCPListener: nl, 132 ports: make(map[*Port]struct{}), 133 } 134 go l.listen() 135 listeners[port] = l 136 } 137 lMu.Unlock() 138 p := &Port{ 139 service: service, 140 port: port, 141 } 142 l.mu.Lock() 143 l.ports[p] = struct{}{} 144 l.mu.Unlock() 145 return p, nil 146 } 147 148 // Close closes this port connection 149 func (p *Port) Close() error { 150 lMu.Lock() 151 if !p.closed { 152 l, ok := listeners[p.port] 153 if ok { 154 l.mu.Lock() 155 delete(l.ports, p) 156 if len(l.ports) == 0 { 157 delete(listeners, p.port) 158 l.Close() 159 } 160 l.mu.Unlock() 161 } 162 p.closed = true 163 } 164 lMu.Unlock() 165 return nil 166 } 167 168 // Closed returns whether the port has been closed or not 169 func (p *Port) Closed() bool { 170 return p.closed 171 } 172 173 // Status constains the status of a Port 174 type Status struct { 175 Ports []uint16 176 Closing, Active bool 177 } 178 179 // Status retrieves the status of a Port 180 func (p *Port) Status() Status { 181 lMu.RLock() 182 closed := p.closed 183 lMu.RUnlock() 184 return Status{ 185 Ports: []uint16{p.port}, 186 Closing: closed, 187 Active: p.service.Active(), 188 } 189 } 190 191 // Errors 192 var ( 193 ErrInvalidPort = errors.New("cannot register on port 0") 194 ) 195