1 package reverseproxy 2 3 import ( 4 "bytes" 5 "fmt" 6 "net" 7 "os" 8 "testing" 9 ) 10 11 const ( 12 aDomain = "aaa.com" 13 bDomain = "bbb.com" 14 ) 15 16 type testData struct { 17 buf []byte 18 conn *net.TCPConn 19 } 20 21 type testService chan testData 22 23 func (t testService) Transfer(buf []byte, conn *net.TCPConn) { 24 t <- testData{append(make([]byte, 0, len(buf)), buf...), conn} 25 } 26 27 func (t testService) Active() bool { 28 return false 29 } 30 31 type testServiceA struct { 32 testService 33 } 34 35 func (testServiceA) MatchService(service string) bool { 36 return service == aDomain 37 } 38 39 type testServiceB struct { 40 testService 41 } 42 43 func (testServiceB) MatchService(service string) bool { 44 return service == bDomain 45 } 46 47 func getUnusedPort() uint16 { 48 l, err := net.ListenTCP("tcp", nil) 49 if err != nil { 50 return 0 51 } 52 p := uint16(l.Addr().(*net.TCPAddr).Port) 53 l.Close() 54 return p 55 } 56 57 func TestListener(t *testing.T) { 58 sync := make(chan struct{}) 59 pa := getUnusedPort() 60 sa := make(testService) 61 p, err := addPort(pa, testServiceA{sa}) 62 if err != nil { 63 t.Errorf("unexpected error: %s", err) 64 return 65 } 66 const firstSend = "GET / HTTP/1.1\r\nHost: " + aDomain + "\r\n\r\n" 67 go func() { 68 c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", pa)) 69 if err != nil { 70 fmt.Println(err) 71 os.Exit(1) 72 } 73 c.Write([]byte(firstSend)) 74 <-sync 75 c.Write([]byte{127}) 76 c.Close() 77 }() 78 data := <-sa 79 sync <- struct{}{} 80 if string(data.buf) != firstSend { 81 t.Errorf("test 1: expecting buf to equal %q, got %q", firstSend, data.buf) 82 return 83 } 84 var buf [32]byte 85 n, err := data.conn.Read(buf[:]) 86 if err != nil { 87 t.Errorf("test 2: unexpected error: %s", err) 88 return 89 } else if n != 1 { 90 t.Errorf("test 2: expecting to read 1 byte, read %d", n) 91 return 92 } else if buf[0] != 127 { 93 t.Errorf("test 2: expecting to read 127, read %d", buf[0]) 94 return 95 } 96 err = data.conn.Close() 97 if err != nil { 98 t.Errorf("test 3: unexpected error: %s", err) 99 return 100 } 101 sb := make(testService) 102 q, err := addPort(pa, testServiceB{sb}) 103 if err != nil { 104 t.Fatalf("unexpected error: %s", err) 105 } 106 const secondSend = "GET / HTTP/1.1\r\nHost: " + bDomain + "\r\n\r\n" 107 go func() { 108 c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", pa)) 109 if err != nil { 110 fmt.Println(err) 111 os.Exit(1) 112 } 113 c.Write([]byte(secondSend)) 114 <-sync 115 c.Write([]byte{255, 127}) 116 c.Close() 117 }() 118 data = <-sb 119 sync <- struct{}{} 120 if string(data.buf) != secondSend { 121 t.Errorf("test 4: expecting buf to equal %q, got %q", secondSend, data.buf) 122 return 123 } 124 n, err = data.conn.Read(buf[:]) 125 if err != nil { 126 t.Errorf("test 5: unexpected error: %s", err) 127 return 128 } else if n != 2 { 129 t.Errorf("test 5: expecting to read 1 byte, read %d", n) 130 return 131 } else if buf[0] != 255 || buf[1] != 127 { 132 t.Errorf("test 5: expecting to read 255, 127, read %v", buf[:2]) 133 return 134 } 135 err = data.conn.Close() 136 if err != nil { 137 t.Errorf("test 6: unexpected error: %s", err) 138 return 139 } 140 go func() { 141 c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", pa)) 142 if err != nil { 143 fmt.Println(err) 144 os.Exit(1) 145 } 146 c.Write([]byte(firstSend)) 147 <-sync 148 c.Write([]byte{1, 2, 3}) 149 c.Close() 150 }() 151 go func() { 152 <-sync 153 c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", pa)) 154 if err != nil { 155 fmt.Println(err) 156 os.Exit(1) 157 } 158 c.Write([]byte(secondSend)) 159 <-sync 160 c.Write([]byte{4, 5, 6, 7}) 161 c.Close() 162 }() 163 data = <-sa 164 sync <- struct{}{} 165 sync <- struct{}{} 166 dataB := <-sb 167 sync <- struct{}{} 168 if string(data.buf) != firstSend { 169 t.Errorf("test 7: expecting buf to equal %q, got %q", firstSend, data.buf) 170 return 171 } 172 if string(dataB.buf) != secondSend { 173 t.Errorf("test 8: expecting buf to equal %q, got %q", secondSend, dataB.buf) 174 return 175 } 176 n, err = data.conn.Read(buf[:]) 177 if err != nil { 178 t.Errorf("test 9: unexpected error: %s", err) 179 return 180 } else if n != 3 { 181 t.Errorf("test 9: expecting to read 1 byte, read %d", n) 182 return 183 } else if !bytes.Equal(buf[:3], []byte{1, 2, 3}) { 184 t.Errorf("test 9: expecting to read 1, 2, 3, read %v", buf[:3]) 185 return 186 } 187 n, err = dataB.conn.Read(buf[:]) 188 if err != nil { 189 t.Errorf("test 10: unexpected error: %s", err) 190 return 191 } else if n != 4 { 192 t.Errorf("test 10: expecting to read 1 byte, read %d", n) 193 return 194 } else if !bytes.Equal(buf[:4], []byte{4, 5, 6, 7}) { 195 t.Errorf("test 10: expecting to read 4, 5, 6, 7, read %v", buf[:4]) 196 return 197 } 198 err = data.conn.Close() 199 if err != nil { 200 t.Errorf("test 11: unexpected error: %s", err) 201 return 202 } 203 err = dataB.conn.Close() 204 if err != nil { 205 t.Errorf("test 12: unexpected error: %s", err) 206 return 207 } 208 p.Close() 209 tlsData := tlsServerName(bDomain) 210 go func() { 211 c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", pa)) 212 if err != nil { 213 fmt.Println(err) 214 os.Exit(1) 215 } 216 c.Write(tlsData) 217 <-sync 218 c.Write([]byte{3, 2, 1, 0}) 219 c.Close() 220 }() 221 dataB = <-sb 222 sync <- struct{}{} 223 if !bytes.Equal(dataB.buf, tlsData) { 224 t.Errorf("test 13: expected to read TLS Header, read %v", dataB.buf) 225 return 226 } 227 n, err = dataB.conn.Read(buf[:]) 228 if err != nil { 229 t.Errorf("test 14: unexpected error: %s", err) 230 return 231 } else if n != 4 { 232 t.Errorf("test 14: expecting to read 1 byte, read %d", n) 233 return 234 } else if !bytes.Equal(buf[:4], []byte{3, 2, 1, 0}) { 235 t.Errorf("test 14: expecting to read 3, 2, 1, 0, read %v", buf[:4]) 236 return 237 } 238 err = dataB.conn.Close() 239 if err != nil { 240 t.Errorf("test 15: unexpected error: %s", err) 241 return 242 } 243 q.Close() 244 l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: int(pa)}) 245 if err != nil { 246 t.Errorf("test 13: unexpected error: %s", err) 247 return 248 } 249 l.Close() 250 } 251