1 package reverseproxy 2 3 import ( 4 "errors" 5 "net" 6 "os" 7 "os/exec" 8 "sync" 9 "sync/atomic" 10 "syscall" 11 ) 12 13 type unixService struct { 14 transferring uint64 15 MatchServiceName 16 conn *net.UnixConn 17 } 18 19 func (u *unixService) Transfer(buf []byte, conn *net.TCPConn) { 20 f, err := conn.File() 21 conn.Close() 22 if err == nil { 23 atomic.AddUint64(&u.transferring, 1) 24 u.conn.WriteMsgUnix(buf, syscall.UnixRights(int(f.Fd())), nil) 25 atomic.AddUint64(&u.transferring, ^uint64(0)) 26 f.Close() 27 } 28 } 29 30 func (u *unixService) Active() bool { 31 return atomic.LoadUint64(&u.transferring) > 0 32 } 33 34 // UnixCmd holds the information required to control (close) a server and its 35 // resources 36 type UnixCmd struct { 37 cmd *exec.Cmd 38 conn *net.UnixConn 39 40 mu sync.Mutex 41 open map[uint16]*Port 42 closed bool 43 exited bool 44 } 45 46 // Close closes all ports for the server and sends a signal to the server to 47 // close 48 func (u *UnixCmd) Close() error { 49 u.mu.Lock() 50 if u.closed { 51 u.mu.Unlock() 52 return ErrClosed 53 } 54 for port, p := range u.open { 55 delete(u.open, port) 56 p.Close() 57 } 58 err := u.conn.Close() 59 errr := u.cmd.Process.Signal(os.Interrupt) 60 u.closed = true 61 u.mu.Unlock() 62 if err != nil { 63 return err 64 } 65 return errr 66 } 67 68 // Status retrieves the Status of the UnixCmd 69 func (u *UnixCmd) Status() Status { 70 u.mu.Lock() 71 closed := u.closed 72 ports := make([]uint16, 0, len(u.open)) 73 for p := range u.open { 74 ports = append(ports, p) 75 } 76 u.mu.Unlock() 77 return Status{ 78 Ports: ports, 79 Closing: closed, 80 Active: !u.exited, 81 } 82 } 83 84 // RegisterCmd runs the given command and waits for incoming listeners from it 85 func RegisterCmd(msn MatchServiceName, cmd *exec.Cmd) (*UnixCmd, error) { 86 fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) 87 if err != nil { 88 return nil, err 89 } 90 nf := os.NewFile(uintptr(fds[0]), "") 91 fconn, err := net.FileConn(nf) 92 if err := nf.Close(); err != nil { 93 return nil, err 94 } 95 if err != nil { 96 return nil, err 97 } 98 f := os.NewFile(uintptr(fds[1]), "") 99 cmd.ExtraFiles = append([]*os.File{}, f) 100 err = cmd.Start() 101 f.Close() 102 if err != nil { 103 return nil, err 104 } 105 u := &UnixCmd{ 106 cmd: cmd, 107 conn: fconn.(*net.UnixConn), 108 open: make(map[uint16]*Port), 109 } 110 go u.runCmdLoop(msn) 111 return u, nil 112 } 113 114 func (u *UnixCmd) runCmdLoop(msn MatchServiceName) { 115 var ( 116 buf [2]byte 117 srv = &unixService{ 118 MatchServiceName: msn, 119 conn: u.conn, 120 } 121 ) 122 for { 123 n, _, _, _, err := u.conn.ReadMsgUnix(buf[:], nil) 124 if err != nil { 125 u.mu.Lock() 126 if !u.closed { 127 for port, p := range u.open { 128 delete(u.open, port) 129 p.Close() 130 } 131 u.conn.Close() 132 u.closed = true 133 } 134 u.mu.Unlock() 135 u.cmd.Wait() 136 u.mu.Lock() 137 u.exited = true 138 u.mu.Unlock() 139 return 140 } 141 if n < 2 { 142 continue 143 } 144 u.mu.Lock() 145 if !u.closed { 146 port := uint16(buf[1])<<8 | uint16(buf[0]) 147 if p, ok := u.open[port]; ok { 148 delete(u.open, port) 149 p.Close() 150 } else { 151 p, err = addPort(port, srv) 152 if err != nil { 153 errStr := err.Error() 154 b := make([]byte, 2, 2+len(errStr)) 155 b[0] = buf[0] 156 b[1] = buf[1] 157 b = append(b, errStr...) 158 u.conn.WriteMsgUnix(b, nil, nil) 159 } else { 160 u.open[port] = p 161 u.conn.WriteMsgUnix(buf[:], nil, nil) 162 } 163 } 164 } 165 u.mu.Unlock() 166 } 167 } 168 169 // Error 170 var ( 171 ErrClosed = errors.New("closed") 172 ) 173