1 // Package unixconn facilitates creating reverse proxy connections 2 package unixconn // import "vimagination.zapto.org/reverseproxy/unixconn" 3 4 import ( 5 "errors" 6 "net" 7 "net/http" 8 "os" 9 "runtime" 10 "strconv" 11 "sync" 12 "sync/atomic" 13 "syscall" 14 "time" 15 ) 16 17 type buffer [http.DefaultMaxHeaderBytes]byte 18 19 type ns struct { 20 c chan net.Conn 21 err error 22 } 23 24 var ( 25 fallback = uint32(1) 26 ucMu sync.Mutex 27 uc *net.UnixConn 28 listeningSockets map[uint16]struct{} 29 newSocket chan ns 30 bufPool = sync.Pool{ 31 New: func() interface{} { 32 return new(buffer) 33 }, 34 } 35 ) 36 37 func init() { 38 c, err := net.FileConn(os.NewFile(3, "")) 39 if err == nil { 40 u, ok := c.(*net.UnixConn) 41 uc = u 42 if ok { 43 fallback = 0 44 newSocket = make(chan ns) 45 listeningSockets = make(map[uint16]struct{}) 46 go runListenLoop() 47 } 48 } 49 } 50 51 func runListenLoop() { 52 buf := bufPool.Get().(*buffer) 53 oob := make([]byte, syscall.CmsgLen(4)) 54 sockets := make(map[uint16]chan net.Conn) 55 for { 56 n, oobn, _, _, err := uc.ReadMsgUnix(buf[:], oob[:]) 57 if err != nil { 58 for _, c := range sockets { 59 close(c) 60 } 61 atomic.StoreUint32(&fallback, 1) 62 break 63 } 64 if oobn == 0 { 65 if n == 2 { 66 port := uint16(buf[1])<<8 | uint16(buf[0]) 67 if s, ok := sockets[port]; ok { 68 close(s) 69 delete(sockets, port) 70 delete(listeningSockets, port) 71 } else { 72 listeningSockets[port] = struct{}{} 73 c := make(chan net.Conn) 74 sockets[port] = c 75 newSocket <- ns{c: c} 76 } 77 } else if n > 2 { 78 newSocket <- ns{err: errors.New(string(buf[2:n]))} 79 } 80 } else if msg, err := syscall.ParseSocketControlMessage(oob[:oobn]); err == nil && len(msg) == 1 { 81 if fd, err := syscall.ParseUnixRights(&msg[0]); err == nil && len(fd) == 1 { 82 nf := os.NewFile(uintptr(fd[0]), "") 83 if cn, err := net.FileConn(nf); err == nil { 84 if ra := cn.RemoteAddr(); ra != nil { 85 var port uint16 86 if tcpaddr, ok := cn.LocalAddr().(*net.TCPAddr); ok { 87 port = uint16(tcpaddr.Port) 88 } else { 89 port = getPort(cn.LocalAddr().String()) 90 } 91 c, ok := sockets[port] 92 if ok { 93 if ka, ok := cn.(keepAlive); ok { 94 if err := ka.SetKeepAlive(true); err == nil { 95 ka.SetKeepAlivePeriod(3 * time.Minute) 96 } 97 } 98 cc := &conn{ 99 Conn: cn, 100 buf: buf, 101 length: n, 102 } 103 buf = bufPool.Get().(*buffer) 104 runtime.SetFinalizer(cc, (*conn).Close) 105 go sendConn(c, cc) 106 continue 107 } else { 108 cn.Close() 109 } 110 } else { 111 cn.Close() 112 } 113 } 114 nf.Close() 115 } 116 } 117 for n := range buf[:n] { 118 buf[n] = 0 119 } 120 } 121 } 122 123 func sendConn(c chan net.Conn, conn *conn) { 124 t := time.NewTimer(time.Minute * 3) 125 select { 126 case <-t.C: 127 conn.Close() 128 case c <- conn: 129 } 130 t.Stop() 131 } 132 133 type keepAlive interface { 134 SetKeepAlive(bool) error 135 SetKeepAlivePeriod(time.Duration) error 136 } 137 138 type conn struct { 139 net.Conn 140 buf *buffer 141 pos int 142 length int 143 } 144 145 func (c *conn) Read(b []byte) (int, error) { 146 if c.buf != nil { 147 n := copy(b, c.buf[c.pos:c.length]) 148 c.pos += n 149 if c.pos == c.length { 150 c.clearBuffer() 151 } 152 return n, nil 153 } 154 return c.Conn.Read(b) 155 } 156 157 func (c *conn) clearBuffer() { 158 for n := range c.buf[:c.length] { 159 c.buf[n] = 0 160 } 161 bufPool.Put(c.buf) 162 c.buf = nil 163 } 164 165 func (c *conn) Close() error { 166 if c.buf != nil { 167 c.clearBuffer() 168 } 169 runtime.SetFinalizer(c, nil) 170 return c.Conn.Close() 171 } 172 173 type listener struct { 174 socket uint16 175 c chan net.Conn 176 addr 177 } 178 179 func (l *listener) Accept() (net.Conn, error) { 180 c, ok := <-l.c 181 if !ok { 182 return nil, net.ErrClosed 183 } 184 return c, nil 185 } 186 187 func (l *listener) Close() error { 188 if l.socket == 0 { 189 return net.ErrClosed 190 } 191 runtime.SetFinalizer(l, nil) 192 var buf [2]byte 193 buf[0] = byte(l.socket) 194 buf[1] = byte(l.socket >> 8) 195 l.socket = 0 196 ucMu.Lock() 197 _, _, err := uc.WriteMsgUnix(buf[:], nil, nil) 198 ucMu.Unlock() 199 return err 200 } 201 202 func (l *listener) Addr() net.Addr { 203 return l.addr 204 } 205 206 type addr struct { 207 network, address string 208 } 209 210 func (a addr) Network() string { 211 return a.network 212 } 213 214 func (a addr) String() string { 215 return a.address 216 } 217 218 // Listen creates a reverse proxy connection, falling back to the net package if 219 // the reverse proxy is not available 220 func Listen(network, address string) (net.Listener, error) { 221 if atomic.LoadUint32(&fallback) == 1 { 222 return net.Listen(network, address) 223 } 224 port := getPort(address) 225 if port == 0 { 226 return nil, ErrInvalidAddress 227 } 228 var buf [2]byte 229 buf[0] = byte(port) 230 buf[1] = byte(port >> 8) 231 ucMu.Lock() 232 if _, ok := listeningSockets[port]; ok { 233 ucMu.Unlock() 234 return nil, ErrAlreadyListening 235 } 236 _, _, err := uc.WriteMsgUnix(buf[:], nil, nil) 237 if err != nil { 238 ucMu.Unlock() 239 return nil, err 240 } 241 ns := <-newSocket 242 ucMu.Unlock() 243 if ns.err != nil { 244 return nil, ns.err 245 } 246 l := &listener{ 247 socket: port, 248 c: ns.c, 249 addr: addr{ 250 network: network, 251 address: address, 252 }, 253 } 254 runtime.SetFinalizer(l, (*listener).Close) 255 return l, nil 256 } 257 258 func getPort(address string) uint16 { 259 _, portStr, _ := net.SplitHostPort(address) 260 port, _ := strconv.ParseUint(portStr, 10, 16) 261 return uint16(port) 262 } 263 264 // Errors 265 var ( 266 ErrInvalidAddress = errors.New("port must be 0 < port < 2^16") 267 ErrAlreadyListening = errors.New("port already being listened on") 268 ) 269