1 package main 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "flag" 9 "fmt" 10 "io" 11 "net" 12 "net/http" 13 "os" 14 "os/signal" 15 "sync" 16 17 "golang.org/x/crypto/acme/autocert" 18 "vimagination.zapto.org/reverseproxy/unixconn" 19 ) 20 21 const bufSize = 1<<16 + 16 22 23 var ( 24 forwardHeader bool 25 headerPool = sync.Pool{ 26 New: func() interface{} { 27 return &[bufSize]byte{} 28 }, 29 } 30 eol = []byte{'\r', '\n'} 31 forward = []byte{'\r', '\n', 'F', 'o', 'r', 'w', 'a', 'r', 'd', 'e', 'd', ':', ' ', 'f', 'o', 'r', '='} 32 proxy string 33 wg sync.WaitGroup 34 ) 35 36 type serverNames []string 37 38 func (s *serverNames) String() string { 39 return "" 40 } 41 42 func (s *serverNames) Set(serverName string) error { 43 *s = append(*s, serverName) 44 return nil 45 } 46 47 func copyConn(a io.Writer, b io.Reader) { 48 io.Copy(a, b) 49 wg.Done() 50 } 51 52 func proxyConn(c net.Conn) { 53 defer wg.Done() 54 pc, err := net.Dial("tcp", proxy) 55 if err != nil { 56 c.Close() 57 return 58 } 59 if forwardHeader { 60 buf := headerPool.Get().(*[bufSize]byte) 61 n := 0 62 l := 0 63 for { 64 m, err := c.Read(buf[n:]) 65 n += m 66 if l = bytes.Index(buf[:n], eol); l >= 0 { 67 pc.Write(buf[:l]) 68 pc.Write(forward) 69 io.WriteString(pc, c.RemoteAddr().String()) 70 break 71 } 72 if err != nil { 73 return 74 } 75 } 76 pc.Write(buf[l:n]) 77 for p := range buf[:n] { 78 buf[p] = 0 79 } 80 headerPool.Put(buf) 81 } 82 wg.Add(2) 83 go copyConn(c, pc) 84 go copyConn(pc, c) 85 } 86 87 func proxySSL(l net.Listener) { 88 wg.Add(1) 89 for { 90 c, err := l.Accept() 91 if err != nil { 92 wg.Done() 93 return 94 } 95 wg.Add(1) 96 go proxyConn(c) 97 } 98 } 99 100 func main() { 101 if err := run(); err != nil { 102 fmt.Fprintf(os.Stderr, "error: %s", err) 103 } 104 } 105 106 func run() error { 107 var ( 108 serverNames serverNames 109 server http.Server 110 ) 111 flag.Var(&serverNames, "s", "server name(s) for TLS") 112 flag.StringVar(&proxy, "p", "", "proxy address") 113 flag.BoolVar(&forwardHeader, "f", false, "add forward headers") 114 flag.Parse() 115 if len(serverNames) == 0 { 116 return errors.New("need server name") 117 } 118 if proxy == "" { 119 return errors.New("need proxy address") 120 } 121 leManager := &autocert.Manager{ 122 Prompt: autocert.AcceptTOS, 123 Cache: autocert.DirCache("./certcache/"), 124 HostPolicy: autocert.HostWhitelist(serverNames...), 125 } 126 l, err := unixconn.Listen("tcp", ":80") 127 if err != nil { 128 return errors.New("unable to open port 80") 129 } 130 sl, err := unixconn.Listen("tcp", ":443") 131 if err != nil { 132 return errors.New("unable to open port 443") 133 } 134 server.Handler = leManager.HTTPHandler(nil) 135 go proxySSL(tls.NewListener(sl, &tls.Config{ 136 GetCertificate: leManager.GetCertificate, 137 NextProtos: []string{"http/1.1"}, 138 })) 139 go func() { 140 if err := server.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { 141 fmt.Fprintln(os.Stderr, err) 142 } 143 }() 144 145 sc := make(chan os.Signal, 1) 146 signal.Notify(sc, os.Interrupt) 147 <-sc 148 signal.Stop(sc) 149 close(sc) 150 server.Shutdown(context.Background()) 151 sl.Close() 152 wg.Wait() 153 return nil 154 } 155