1 package reverseproxy 2 3 import ( 4 "bytes" 5 "net" 6 "os" 7 "syscall" 8 "testing" 9 "time" 10 ) 11 12 func TestUnix(t *testing.T) { 13 fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) 14 if err != nil { 15 t.Fatalf("unexpected error: %s", err) 16 } 17 nf := os.NewFile(uintptr(fds[0]), "") 18 fconn, err := net.FileConn(nf) 19 if err := nf.Close(); err != nil { 20 t.Fatalf("unexpected error: %s", err) 21 } 22 if err != nil { 23 t.Fatalf("unexpected error: %s", err) 24 } 25 u := &UnixCmd{ 26 conn: fconn.(*net.UnixConn), 27 open: make(map[uint16]*Port), 28 } 29 go u.runCmdLoop(testServiceA{make(testService)}) 30 nf = os.NewFile(uintptr(fds[1]), "") 31 fconn, err = net.FileConn(nf) 32 if err := nf.Close(); err != nil { 33 t.Fatalf("unexpected error: %s", err) 34 } 35 if err != nil { 36 t.Fatalf("unexpected error: %s", err) 37 } 38 conn := fconn.(*net.UnixConn) 39 var ( 40 buf [1024]byte 41 oob = make([]byte, syscall.CmsgLen(4)) 42 ) 43 n, _, err := conn.WriteMsgUnix(buf[:2], nil, nil) 44 if err != nil { 45 t.Errorf("test 1: unexpected error: %s", err) 46 return 47 } else if n != 2 { 48 t.Errorf("test 1: expecting to write 2 bytes, wrote %d", n) 49 return 50 } 51 n, _, _, _, err = conn.ReadMsgUnix(buf[:], oob) 52 if err != nil { 53 t.Errorf("test 2: unexpected error: %s", err) 54 return 55 } else if n <= 2 { 56 t.Errorf("test 2: expecting to read more than 2 bytes, read %d", n) 57 } else if pr := uint16(buf[0]) | (uint16(buf[1]) << 8); pr != 0 { 58 t.Errorf("test 2: expecting to read port 0, got %d", pr) 59 return 60 } else if string(buf[2:n]) != "cannot register on port 0" { 61 t.Errorf("test 2: expecting ErrInvalidPort, got %q", buf[2:n]) 62 return 63 } 64 pa := getUnusedPort() 65 buf[0] = uint8(pa) 66 buf[1] = uint8(pa >> 8) 67 n, _, err = conn.WriteMsgUnix(buf[:2], nil, nil) 68 if err != nil { 69 t.Errorf("test 3: unexpected error: %s", err) 70 return 71 } else if n != 2 { 72 t.Errorf("test 3: expecting to write 2 bytes, wrote %d", n) 73 return 74 } 75 n, _, _, _, err = conn.ReadMsgUnix(buf[:], oob) 76 if err != nil { 77 t.Errorf("test 4: unexpected error: %s", err) 78 return 79 } else if n != 2 { 80 t.Errorf("test 4: expecting to read 2 bytes, read %d", n) 81 } else if pr := uint16(buf[0]) | (uint16(buf[1]) << 8); pr != pa { 82 t.Errorf("test 4: expecting to read port %d, got %d", pa, pr) 83 return 84 } 85 nc, err := net.DialTCP("tcp", nil, &net.TCPAddr{Port: int(pa)}) 86 if err != nil { 87 t.Fatalf("unexpected error: %s", err) 88 } 89 data := tlsServerName(aDomain) 90 nc.Write(data) 91 n, oobn, _, _, err := conn.ReadMsgUnix(buf[:], oob) 92 if err != nil { 93 t.Errorf("test 5: unexpected error: %s", err) 94 return 95 } else if !bytes.Equal(buf[:n], data) { 96 t.Errorf("test 5: expecting to read TLS header %v, got %v", data, buf[:n]) 97 return 98 } 99 msg, _ := syscall.ParseSocketControlMessage(oob[:oobn]) 100 fd, _ := syscall.ParseUnixRights(&msg[0]) 101 nf = os.NewFile(uintptr(fd[0]), "") 102 cn, err := net.FileConn(nf) 103 if err != nil { 104 t.Fatalf("unexpected error: %s", err) 105 } 106 nf.Close() 107 addr := cn.LocalAddr().(*net.TCPAddr) 108 if addr.Port != int(pa) { 109 t.Errorf("test 6: expecting port %d, got %d", pa, addr.Port) 110 return 111 } 112 nc.Write([]byte("TEST")) 113 nc.Close() 114 if n, err := cn.Read(buf[:]); err != nil { 115 t.Errorf("test 7: unexpected error: %s", err) 116 return 117 } else if string(buf[:n]) != "TEST" { 118 t.Errorf("test 7: expecting to read \"TEST\", read %q", buf[:n]) 119 return 120 } 121 buf[0] = uint8(pa) 122 buf[1] = uint8(pa >> 8) 123 n, _, err = conn.WriteMsgUnix(buf[:2], nil, nil) 124 if err != nil { 125 t.Errorf("test 8: unexpected error: %s", err) 126 return 127 } else if n != 2 { 128 t.Errorf("test 8: expecting to write 2 bytes, wrote %d", n) 129 return 130 } 131 time.Sleep(time.Second) 132 l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: int(pa)}) 133 if err != nil { 134 t.Errorf("test 9: unexpected error: %s", err) 135 return 136 } 137 l.Close() 138 } 139