// Package unixconn facilitates creating reverse proxy connections.
//
// If file descriptor 3 is a Unix domain socket when the program starts, Listen
// requests ports from the reverse proxy over that socket and receives accepted
// connections from it; otherwise Listen is equivalent to net.Listen.
package unixconn // import "vimagination.zapto.org/reverseproxy/unixconn"

import (
	"errors"
	"net"
	"net/http"
	"os"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
	"syscall"
	"time"
)

// buffer holds the payload of a message received from the reverse proxy.
type buffer [http.DefaultMaxHeaderBytes]byte

// listenSocket is the per-port state shared between runListenLoop, the
// sendConn goroutines and the listener.
//
// conns is never closed. Closure of the port is signalled by closing done
// instead, so that a sendConn goroutine blocked sending on conns can never
// panic with "send on closed channel".
type listenSocket struct {
	conns chan net.Conn
	done  chan struct{}
}

// ns is runListenLoop's reply to a Listen request: either a new listenSocket
// or the error reported by the reverse proxy.
type ns struct {
	s   listenSocket
	err error
}

var (
	// fallback is 1 when the reverse proxy is unavailable and Listen defers to
	// net.Listen. It is read and written atomically once runListenLoop runs.
	fallback = uint32(1)

	// ucMu serialises writes to uc, and Listen's request/reply exchange.
	ucMu sync.Mutex
	uc   *net.UnixConn

	// socketsMu guards listeningSockets, which Listen reads and runListenLoop
	// updates. It is separate from ucMu because runListenLoop must be able to
	// update the map while a Listen call holds ucMu waiting for a reply.
	socketsMu        sync.Mutex
	listeningSockets map[uint16]struct{}

	// newSocket carries replies from runListenLoop to Listen. It is closed when
	// runListenLoop exits.
	newSocket chan ns

	bufPool = sync.Pool{
		New: func() interface{} {
			return new(buffer)
		},
	}
)

// init uses file descriptor 3 as the connection to the reverse proxy if it is
// a Unix domain socket, starting runListenLoop; otherwise fallback remains
// set and Listen uses the net package.
func init() {
	f := os.NewFile(3, "")
	if f == nil {
		return
	}

	c, err := net.FileConn(f)

	// FileConn duplicates the descriptor. Close the original now rather than
	// leaving it to f's finalizer, which could otherwise close an unrelated
	// file that later reuses descriptor 3.
	f.Close()

	if err != nil {
		return
	}

	u, ok := c.(*net.UnixConn)
	if !ok {
		c.Close()

		return
	}

	uc = u
	newSocket = make(chan ns)
	listeningSockets = make(map[uint16]struct{})
	fallback = 0

	go runListenLoop()
}

// runListenLoop reads messages from the reverse proxy until a read fails. It
// then closes every open listener, switches Listen to the net package and
// closes newSocket so that a Listen call waiting for a reply is released.
//
// Messages without control data are handled by handleReply; messages carrying
// control data are handled by handleConn.
func runListenLoop() {
	buf := bufPool.Get().(*buffer)
	oob := make([]byte, syscall.CmsgLen(4))
	sockets := make(map[uint16]listenSocket)

	for {
		n, oobn, _, _, err := uc.ReadMsgUnix(buf[:], oob)
		if err != nil {
			break
		}

		if oobn == 0 {
			handleReply(sockets, buf[:n])
		} else if handleConn(sockets, buf, n, oob[:oobn]) {
			// buf now belongs to the delivered connection.
			buf = bufPool.Get().(*buffer)

			continue
		}

		zero(buf[:n])
	}

	atomic.StoreUint32(&fallback, 1)

	for _, s := range sockets {
		close(s.done)
	}

	close(newSocket)
}

// handleReply processes a message from the reverse proxy that carried no
// control data. A 2-byte message is a little-endian port number: if that port
// is open its listener is closed, otherwise a new listenSocket is created and
// sent to the waiting Listen call. A longer message is an error for the
// waiting Listen call, its text being everything after the first two bytes.
func handleReply(sockets map[uint16]listenSocket, msg []byte) {
	switch {
	case len(msg) == 2:
		port := uint16(msg[1])<<8 | uint16(msg[0])

		if s, ok := sockets[port]; ok {
			close(s.done)
			delete(sockets, port)

			socketsMu.Lock()
			delete(listeningSockets, port)
			socketsMu.Unlock()

			return
		}

		s := listenSocket{
			conns: make(chan net.Conn),
			done:  make(chan struct{}),
		}
		sockets[port] = s

		socketsMu.Lock()
		listeningSockets[port] = struct{}{}
		socketsMu.Unlock()

		newSocket <- ns{s: s}
	case len(msg) > 2:
		newSocket <- ns{err: errors.New(string(msg[2:]))}
	}
}

// handleConn takes ownership of the file descriptor passed in oob. If it is a
// connection whose local port has an open listener, the connection is handed
// to that listener; the n bytes in buf are returned by the connection's Read
// before any data from the connection itself. Otherwise the connection is
// closed.
//
// It reports whether buf was handed over to the connection; if not, buf is
// still owned by the caller.
func handleConn(sockets map[uint16]listenSocket, buf *buffer, n int, oob []byte) bool {
	msgs, err := syscall.ParseSocketControlMessage(oob)
	if err != nil || len(msgs) != 1 {
		return false
	}

	fds, err := syscall.ParseUnixRights(&msgs[0])
	if err != nil {
		return false
	}

	if len(fds) != 1 {
		// Unexpected descriptors would otherwise leak.
		for _, fd := range fds {
			syscall.Close(fd)
		}

		return false
	}

	f := os.NewFile(uintptr(fds[0]), "")
	cn, err := net.FileConn(f)

	// FileConn duplicates the descriptor, so the original must be closed on
	// every path, including success.
	f.Close()

	if err != nil {
		return false
	}

	if cn.RemoteAddr() == nil {
		cn.Close()

		return false
	}

	s, ok := sockets[connPort(cn)]
	if !ok {
		cn.Close()

		return false
	}

	if ka, ok := cn.(keepAlive); ok {
		if err := ka.SetKeepAlive(true); err == nil {
			ka.SetKeepAlivePeriod(3 * time.Minute)
		}
	}

	cc := &conn{Conn: cn}

	// An empty payload is not attached, so Read goes straight to the
	// connection and the buffer is reused by the caller.
	if n > 0 {
		cc.buf = buf
		cc.length = n
	}

	runtime.SetFinalizer(cc, (*conn).Close)

	go sendConn(s, cc)

	return n > 0
}

// connPort returns the local port of cn, or 0 if it cannot be determined.
// Listen never opens port 0, so 0 never matches a listener.
func connPort(cn net.Conn) uint16 {
	switch la := cn.LocalAddr().(type) {
	case nil:
		return 0
	case *net.TCPAddr:
		if la == nil {
			return 0
		}

		return uint16(la.Port)
	default:
		return getPort(la.String())
	}
}

// zero overwrites b with zero bytes.
func zero(b []byte) {
	for i := range b {
		b[i] = 0
	}
}

// sendConn offers cn to the listener for s, closing cn instead if the
// listener is closed or does not accept it within three minutes.
func sendConn(s listenSocket, cn *conn) {
	t := time.NewTimer(3 * time.Minute)
	defer t.Stop()

	select {
	case s.conns <- cn:
	case <-s.done:
		cn.Close()
	case <-t.C:
		cn.Close()
	}
}

type keepAlive interface {
	SetKeepAlive(bool) error
	SetKeepAlivePeriod(time.Duration) error
}

// conn is a connection received from the reverse proxy. While buf is non-nil,
// Read returns buf[pos:length] before reading from the embedded Conn.
type conn struct {
	net.Conn
	buf    *buffer
	pos    int
	length int
}

// Read returns any buffered data first, then reads from the embedded Conn.
func (c *conn) Read(b []byte) (int, error) {
	if c.buf == nil {
		return c.Conn.Read(b)
	}

	n := copy(b, c.buf[c.pos:c.length])
	c.pos += n

	if c.pos == c.length {
		c.clearBuffer()
	}

	if n == 0 && len(b) > 0 {
		// Nothing was buffered; read from the connection rather than
		// returning (0, nil).
		return c.Conn.Read(b)
	}

	return n, nil
}

// clearBuffer zeroes the used portion of buf and returns it to bufPool.
func (c *conn) clearBuffer() {
	zero(c.buf[:c.length])
	bufPool.Put(c.buf)

	c.buf = nil
}

// Close releases any buffered data, removes the finalizer and closes the
// embedded Conn.
func (c *conn) Close() error {
	if c.buf != nil {
		c.clearBuffer()
	}

	runtime.SetFinalizer(c, nil)

	return c.Conn.Close()
}

// listener is a net.Listener for a port opened via the reverse proxy. port is
// zero once Close has been called.
type listener struct {
	port uint16
	listenSocket
	addr
}

// Accept waits for the next connection, returning net.ErrClosed once the
// reverse proxy has confirmed the port closed or the proxy connection failed.
func (l *listener) Accept() (net.Conn, error) {
	select {
	case c := <-l.conns:
		return c, nil
	case <-l.done:
		return nil, net.ErrClosed
	}
}

// Close asks the reverse proxy to stop listening on the port. The listener's
// done channel is closed when the proxy's reply is processed by runListenLoop.
func (l *listener) Close() error {
	if l.port == 0 {
		return net.ErrClosed
	}

	runtime.SetFinalizer(l, nil)

	buf := [2]byte{byte(l.port), byte(l.port >> 8)}

	l.port = 0

	ucMu.Lock()
	_, _, err := uc.WriteMsgUnix(buf[:], nil, nil)
	ucMu.Unlock()

	return err
}

func (l *listener) Addr() net.Addr {
	return l.addr
}

type addr struct {
	network, address string
}

func (a addr) Network() string {
	return a.network
}

func (a addr) String() string {
	return a.address
}

// Listen creates a reverse proxy connection, falling back to the net package if
// the reverse proxy is not available.
func Listen(network, address string) (net.Listener, error) {
	if atomic.LoadUint32(&fallback) == 1 {
		return net.Listen(network, address)
	}

	port := getPort(address)
	if port == 0 {
		return nil, ErrInvalidAddress
	}

	buf := [2]byte{byte(port), byte(port >> 8)}

	ucMu.Lock()

	socketsMu.Lock()
	_, listening := listeningSockets[port]
	socketsMu.Unlock()

	if listening {
		ucMu.Unlock()

		return nil, ErrAlreadyListening
	}

	if _, _, err := uc.WriteMsgUnix(buf[:], nil, nil); err != nil {
		ucMu.Unlock()

		return nil, err
	}

	nss, ok := <-newSocket

	ucMu.Unlock()

	if !ok {
		// The connection to the reverse proxy failed before it replied.
		return net.Listen(network, address)
	}

	if nss.err != nil {
		return nil, nss.err
	}

	l := &listener{
		port:         port,
		listenSocket: nss.s,
		addr: addr{
			network: network,
			address: address,
		},
	}

	runtime.SetFinalizer(l, (*listener).Close)

	return l, nil
}

// getPort returns the port of a host:port address, or 0 if it is missing or
// not a valid 16-bit port number.
func getPort(address string) uint16 {
	_, portStr, _ := net.SplitHostPort(address)
	port, _ := strconv.ParseUint(portStr, 10, 16)

	return uint16(port)
}

// Errors.
var (
	ErrInvalidAddress   = errors.New("port must be 0 < port < 2^16")
	ErrAlreadyListening = errors.New("port already being listened on")
)