reverseproxy - unixconn/unixconn_test.go
1 package unixconn
2
3 import (
4 "errors"
5 "fmt"
6 "io"
7 "net"
8 "os"
9 "syscall"
10 "testing"
11 )
12
13 var (
14 lone, ltwo *net.TCPListener
15 pone, ptwo uint16
16 )
17
18 func TestMain(m *testing.T) {
19 fallback = 0
20 addr := new(net.TCPAddr)
21 var err error
22 if lone, err = net.ListenTCP("tcp", addr); err != nil {
23 m.Fatalf("unexpected error during setup (1): %q", err)
24 }
25 if ltwo, err = net.ListenTCP("tcp", addr); err != nil {
26 m.Fatalf("unexpected error during setup (2): %q", err)
27 }
28 if pone = getPort(lone.Addr().String()); pone == 0 {
29 m.Fatalf("invalid port number (1): %d", pone)
30 }
31 if ptwo = getPort(ltwo.Addr().String()); ptwo == 0 {
32 m.Fatalf("invalid port number (2): %d", ptwo)
33 }
34 }
35
36 func TestUnixConn(t *testing.T) {
37 fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
38 if err != nil {
39 t.Errorf("unexpected error creating socket pair: %s", err)
40 return
41 }
42 fconn, _ := net.FileConn(os.NewFile(uintptr(fds[0]), ""))
43 go testServerLoop(fconn.(*net.UnixConn))
44 fconn, _ = net.FileConn(os.NewFile(uintptr(fds[1]), ""))
45 uc = fconn.(*net.UnixConn)
46 listeningSockets = make(map[uint16]struct{})
47 defer uc.Close()
48 newSocket = make(chan ns)
49 go runListenLoop()
50 l, err := Listen("tcp", ":8080")
51 if err == nil {
52 t.Errorf("test 1: expecting \"error\", got: nil")
53 return
54 } else if err.Error() != "error" {
55 t.Errorf("test 1: expecting \"error\", got: %q", err)
56 return
57 } else if l != nil {
58 t.Error("test 1: expecting nil listener")
59 return
60 }
61 l, err = Listen("tcp", "80")
62 if err == nil {
63 t.Errorf("test 2: expecting \"error\", got: nil")
64 return
65 } else if !errors.Is(err, ErrInvalidAddress) {
66 t.Errorf("test 2: expecting ErrInvalidAddress, got: %q", err)
67 return
68 } else if l != nil {
69 t.Error("test 2: expecting nil listener")
70 return
71 }
72 pstr := fmt.Sprintf(":%d", pone)
73 l, err = Listen("tcp", pstr)
74 if err != nil {
75 t.Errorf("test 3: unexpected error: %s", err)
76 return
77 } else if l == nil {
78 t.Errorf("test 3: expecting non-nil Listener")
79 return
80 }
81 if net := l.Addr().Network(); net != "tcp" {
82 t.Errorf("test 4: expecting network \"tcp\", got: %q", net)
83 return
84 } else if addr := l.Addr().String(); addr != pstr {
85 t.Errorf("test 4: expecting address %q, got %q", pstr, addr)
86 return
87 }
88 c, err := l.Accept()
89 if err != nil {
90 t.Errorf("test 5: unexpected error: %s", err)
91 return
92 } else if c == nil {
93 t.Error("test 5: conn should not be nil")
94 return
95 }
96 var buf [32]byte
97 n, err := c.Read(buf[:])
98 if err != nil {
99 t.Errorf("test 6: unexpected error: %s", err)
100 return
101 } else if n != 3 {
102 t.Errorf("test 6: expecting to read 3 bytes, read %d: ", n)
103 return
104 } else if string(buf[:3]) != "BIG" {
105 t.Errorf("test 6: expecting to read \"BIG\", read: %q", buf[:3])
106 }
107 n, err = c.Read(buf[:])
108 if err != nil {
109 t.Errorf("test 7: unexpected error: %s", err)
110 return
111 } else if n != 4 {
112 t.Errorf("test 7: expecting to read 4 bytes, read %d: ", n)
113 return
114 } else if string(buf[:4]) != "data" {
115 t.Errorf("test 7: expecting to read \"data\", read: %q", buf[:3])
116 return
117 }
118 n, err = c.Read(buf[:])
119 if n != 0 {
120 t.Errorf("test 8: expecting to read no data, read: %q", buf[:n])
121 return
122 } else if !errors.Is(err, io.EOF) {
123 t.Errorf("test 8: expecting to EOF, got: %s", err)
124 return
125 }
126 var l2 net.Listener
127 pstr = fmt.Sprintf(":%d", ptwo)
128 l2, err = Listen("tcp", pstr)
129 if err != nil {
130 t.Errorf("test 9: unexpected error: %s", err)
131 return
132 } else if l2 == nil {
133 t.Errorf("test 9: expecting non-nil Listener")
134 return
135 }
136 if net := l2.Addr().Network(); net != "tcp" {
137 t.Errorf("test 10: expecting network \"tcp\", got: %q", net)
138 return
139 } else if addr := l2.Addr().String(); addr != pstr {
140 t.Errorf("test 10: expecting address %q, got %q", pstr, addr)
141 return
142 }
143 c, err = l2.Accept()
144 if err != nil {
145 t.Errorf("test 11: unexpected error: %s", err)
146 return
147 } else if c == nil {
148 t.Error("test 11: conn should not be nil")
149 return
150 }
151 err = l2.Close()
152 if err != nil {
153 t.Errorf("test 12: expecting nil error, got: %s", err)
154 }
155 err = l2.Close()
156 if !errors.Is(err, net.ErrClosed) {
157 t.Errorf("test 13: expecting net.ErrClosed, got: %s", err)
158 }
159 ct, err := l2.Accept()
160 if !errors.Is(err, net.ErrClosed) {
161 t.Errorf("test 14: expecting net.ErrClosed, got: %s", err)
162 return
163 } else if ct != nil {
164 t.Errorf("test 14: expecting nil conn, got: %v", ct)
165 }
166 ct, err = l.Accept()
167 if err != nil {
168 t.Errorf("test 15: unexpected error: %s", err)
169 } else if ct == nil {
170 t.Error("test 15: recieved nil conn when conn expected")
171 }
172 n, err = c.Read(buf[:])
173 if err != nil {
174 t.Errorf("test 16: unexpected error: %s", err)
175 return
176 } else if n != 5 {
177 t.Errorf("test 16: expecting to read 3 bytes, read %d: ", n)
178 return
179 } else if string(buf[:5]) != "HELLO" {
180 t.Errorf("test 16: expecting to read \"HELLO\", read: %q", buf[:5])
181 }
182 n, err = ct.Read(buf[:])
183 if err != nil {
184 t.Errorf("test 17: unexpected error: %s", err)
185 return
186 } else if n != 10 {
187 t.Errorf("test 17: expecting to read 10 bytes, read %d: ", n)
188 return
189 } else if string(buf[:10]) != "1234567890" {
190 t.Errorf("test 17: expecting to read \"1234567890\", read: %q", buf[:10])
191 }
192 n, err = c.Read(buf[:])
193 if err != nil {
194 t.Errorf("test 18: unexpected error: %s", err)
195 return
196 } else if n != 5 {
197 t.Errorf("test 18: expecting to read 5 bytes, read %d: ", n)
198 return
199 } else if string(buf[:5]) != "world" {
200 t.Errorf("test 18: expecting to read \"world\", read: %q", buf[:5])
201 return
202 }
203 n, err = ct.Read(buf[:])
204 if err != nil {
205 t.Errorf("test 19: unexpected error: %s", err)
206 return
207 } else if n != 10 {
208 t.Errorf("test 19: expecting to read 10 bytes, read %d: ", n)
209 return
210 } else if string(buf[:10]) != "0987654321" {
211 t.Errorf("test 19: expecting to read \"0987654321\", read: %q", buf[:3])
212 }
213 n, err = c.Read(buf[:])
214 if n != 0 {
215 t.Errorf("test 20: expecting to read no data, read: %q", buf[:n])
216 return
217 } else if !errors.Is(err, io.EOF) {
218 t.Errorf("test 20: expecting to EOF, got: %s", err)
219 return
220 }
221 err = ct.Close()
222 if err != nil {
223 t.Errorf("test 21: expecting nil error, got: %s", err)
224 }
225 n, err = ct.Read(buf[:])
226 if n != 0 {
227 t.Errorf("test 22: expecting to read no data, read: %q", buf[:n])
228 return
229 } else if !errors.Is(err, net.ErrClosed) {
230 t.Errorf("test 22: expecting to EOF, got: %s", err)
231 return
232 }
233 err = ct.Close()
234 if !errors.Is(err, net.ErrClosed) {
235 t.Errorf("test 23: expecting net.ErrClosed, got: %s", err)
236 }
237 }
238
239 func testServerLoop(conn *net.UnixConn) {
240 defer conn.Close()
241 buf := [...]byte{0, 0, 'e', 'r', 'r', 'o', 'r'}
242
243 // test 1
244 conn.ReadMsgUnix(buf[:2], nil)
245 if buf[0] != 0x90 || buf[1] != 0x1f {
246 conn.WriteMsgUnix(buf[:5], nil, nil)
247 return
248 }
249 conn.WriteMsgUnix(buf[:], nil, nil)
250
251 // test 3
252 conn.ReadMsgUnix(buf[:2], nil)
253 p := uint16(buf[1])<<8 | uint16(buf[0])
254 if p != pone {
255 conn.WriteMsgUnix(buf[:5], nil, nil)
256 return
257 }
258 conn.WriteMsgUnix(buf[:2], nil, nil)
259
260 go func() {
261 c, _ := net.DialTCP("tcp", nil, &net.TCPAddr{Port: int(pone)})
262 c.Write([]byte("data")) // test 7
263 c.Close()
264 }()
265
266 c, _ := lone.AcceptTCP() // test 5
267 transfer(conn, c, []byte("BIG")) // test 6
268
269 // test 9
270 conn.ReadMsgUnix(buf[:2], nil)
271 p = uint16(buf[1])<<8 | uint16(buf[0])
272 if p != ptwo {
273 conn.WriteMsgUnix(buf[:5], nil, nil)
274 return
275 }
276 conn.WriteMsgUnix(buf[:2], nil, nil)
277
278 // test 11
279 go func() {
280 c, _ := net.DialTCP("tcp", nil, &net.TCPAddr{Port: int(ptwo)})
281 c.Write([]byte("world")) // test 18
282 c.Close() // test 20
283 }()
284 c, _ = ltwo.AcceptTCP()
285 transfer(conn, c, []byte("HELLO")) // test 16
286
287 // test 12
288 conn.ReadMsgUnix(buf[:2], nil)
289 conn.WriteMsgUnix(buf[:2], nil, nil)
290
291 // test 15
292 go func() {
293 c, _ := net.DialTCP("tcp", nil, &net.TCPAddr{Port: int(pone)})
294 c.Write([]byte("0987654321")) // test 19
295 c.Close()
296 }()
297 c, _ = lone.AcceptTCP()
298 transfer(conn, c, []byte("1234567890")) // test 17
299 }
300
// transfer sends data over conn together with a duplicate of c's file
// descriptor, attached as SCM_RIGHTS ancillary data.
//
// c itself is left open; the caller remains responsible for closing it. The
// temporary *os.File obtained from c.File is closed before returning. This
// is safe because a descriptor passed with SCM_RIGHTS is referenced by the
// in-flight message itself (see unix(7)), so without the close every call
// would leak one descriptor.
//
// The returned error reports a failure to duplicate c's descriptor or to
// send the message. Call sites that use transfer as a plain statement still
// compile.
func transfer(conn *net.UnixConn, c *net.TCPConn, data []byte) error {
	f, err := c.File()
	if err != nil {
		return fmt.Errorf("duplicating connection descriptor: %w", err)
	}
	defer f.Close()
	if _, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(int(f.Fd())), nil); err != nil {
		return fmt.Errorf("sending descriptor: %w", err)
	}
	return nil
}
305