1 package main 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "flag" 9 "fmt" 10 "io" 11 "net" 12 "net/http" 13 "os" 14 "os/signal" 15 "sync" 16 17 "golang.org/x/crypto/acme/autocert" 18 "vimagination.zapto.org/reverseproxy/unixconn" 19 ) 20 21 const bufSize = 1<<16 + 16 22 23 var ( 24 forwardHeader bool 25 headerPool = sync.Pool{ 26 New: func() interface{} { 27 return &[bufSize]byte{} 28 }, 29 } 30 eol = []byte{'\r', '\n'} 31 forward = []byte{'\r', '\n', 'F', 'o', 'r', 'w', 'a', 'r', 'd', 'e', 'd', ':', ' ', 'f', 'o', 'r', '='} 32 proxy string 33 wg sync.WaitGroup 34 ) 35 36 type serverNames []string 37 38 func (s *serverNames) String() string { 39 return "" 40 } 41 42 func (s *serverNames) Set(serverName string) error { 43 *s = append(*s, serverName) 44 45 return nil 46 } 47 48 func copyConn(a io.Writer, b io.Reader) { 49 io.Copy(a, b) 50 wg.Done() 51 } 52 53 func proxyConn(c net.Conn) { 54 defer wg.Done() 55 56 pc, err := net.Dial("tcp", proxy) 57 if err != nil { 58 c.Close() 59 60 return 61 } 62 63 if forwardHeader { 64 buf := headerPool.Get().(*[bufSize]byte) 65 n := 0 66 l := 0 67 68 for { 69 m, err := c.Read(buf[n:]) 70 n += m 71 72 if l = bytes.Index(buf[:n], eol); l >= 0 { 73 pc.Write(buf[:l]) 74 pc.Write(forward) 75 io.WriteString(pc, c.RemoteAddr().String()) 76 77 break 78 } 79 80 if err != nil { 81 return 82 } 83 } 84 85 pc.Write(buf[l:n]) 86 87 for p := range buf[:n] { 88 buf[p] = 0 89 } 90 91 headerPool.Put(buf) 92 } 93 94 wg.Add(2) 95 96 go copyConn(c, pc) 97 go copyConn(pc, c) 98 } 99 100 func proxySSL(l net.Listener) { 101 wg.Add(1) 102 103 for { 104 c, err := l.Accept() 105 if err != nil { 106 wg.Done() 107 108 return 109 } 110 111 wg.Add(1) 112 113 go proxyConn(c) 114 } 115 } 116 117 func main() { 118 if err := run(); err != nil { 119 fmt.Fprintf(os.Stderr, "error: %s", err) 120 } 121 } 122 123 func run() error { 124 var ( 125 sNames serverNames 126 server http.Server 127 ) 128 129 flag.Var(&sNames, "s", "server name(s) for TLS") 130 flag.StringVar(&proxy, "p", "", "proxy address") 131 flag.BoolVar(&forwardHeader, "f", false, "add forward headers") 132 flag.Parse() 133 134 if len(sNames) == 0 { 135 return errors.New("need server name") 136 } 137 138 if proxy == "" { 139 return errors.New("need proxy address") 140 } 141 142 leManager := &autocert.Manager{ 143 Prompt: autocert.AcceptTOS, 144 Cache: autocert.DirCache("./certcache/"), 145 HostPolicy: autocert.HostWhitelist(sNames...), 146 } 147 148 l, err := unixconn.Listen("tcp", ":80") 149 if err != nil { 150 return errors.New("unable to open port 80") 151 } 152 153 sl, err := unixconn.Listen("tcp", ":443") 154 if err != nil { 155 return errors.New("unable to open port 443") 156 } 157 158 server.Handler = leManager.HTTPHandler(nil) 159 160 go proxySSL(tls.NewListener(sl, &tls.Config{ 161 GetCertificate: leManager.GetCertificate, 162 NextProtos: []string{"http/1.1"}, 163 })) 164 165 go func() { 166 if err := server.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { 167 fmt.Fprintln(os.Stderr, err) 168 } 169 }() 170 171 sc := make(chan os.Signal, 1) 172 173 signal.Notify(sc, os.Interrupt) 174 175 <-sc 176 177 signal.Stop(sc) 178 close(sc) 179 server.Shutdown(context.Background()) 180 sl.Close() 181 wg.Wait() 182 183 return nil 184 } 185