1use log::*;
3use rand::{thread_rng, Rng};
4use socket2::{Domain, SockAddr, Socket, Type};
5use std::io::{self, Read, Write};
6use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
7use std::sync::mpsc::channel;
8use std::time::Duration;
9
10mod ip_echo_server;
11use ip_echo_server::IpEchoServerMessage;
12pub use ip_echo_server::{ip_echo_server, IpEchoServer};
13
/// A pair of UDP sockets (one for receiving, one for sending) plus an address.
pub struct UdpSocketPair {
    /// Address associated with the pair.
    /// NOTE(review): presumably the receiver's bound address — confirm at construction sites.
    pub addr: SocketAddr,
    /// Socket used for receiving datagrams.
    pub receiver: UdpSocket,
    /// Socket used for sending datagrams.
    pub sender: UdpSocket,
}

/// A `(start, end)` pair of port numbers describing a port range.
pub type PortRange = (u16, u16);
22
23fn ip_echo_server_request(
24 ip_echo_server_addr: &SocketAddr,
25 msg: IpEchoServerMessage,
26) -> Result<IpAddr, String> {
27 let mut data = Vec::new();
28
29 let timeout = Duration::new(5, 0);
30 TcpStream::connect_timeout(ip_echo_server_addr, timeout)
31 .and_then(|mut stream| {
32 let msg = bincode::serialize(&msg).expect("serialize IpEchoServerMessage");
33 stream.write_all(&msg)?;
34 stream.shutdown(std::net::Shutdown::Write)?;
35 stream
36 .set_read_timeout(Some(Duration::new(10, 0)))
37 .expect("set_read_timeout");
38 stream.read_to_end(&mut data)
39 })
40 .and_then(|_| {
41 bincode::deserialize(&data).map_err(|err| {
42 io::Error::new(
43 io::ErrorKind::Other,
44 format!("Failed to deserialize: {:?}", err),
45 )
46 })
47 })
48 .map_err(|err| err.to_string())
49}
50
51pub fn get_public_ip_addr(ip_echo_server_addr: &SocketAddr) -> Result<IpAddr, String> {
54 ip_echo_server_request(ip_echo_server_addr, IpEchoServerMessage::default())
55}
56
57pub fn verify_reachable_ports(
60 ip_echo_server_addr: &SocketAddr,
61 tcp_listeners: Vec<(u16, TcpListener)>,
62 udp_sockets: &[&UdpSocket],
63) {
64 let udp: Vec<(_, _)> = udp_sockets
65 .iter()
66 .map(|udp_socket| {
67 (
68 udp_socket.local_addr().unwrap().port(),
69 udp_socket.try_clone().expect("Unable to clone udp socket"),
70 )
71 })
72 .collect();
73
74 let udp_ports: Vec<_> = udp.iter().map(|x| x.0).collect();
75
76 info!(
77 "Checking that tcp ports {:?} and udp ports {:?} are reachable from {:?}",
78 tcp_listeners, udp_ports, ip_echo_server_addr
79 );
80
81 let tcp_ports: Vec<_> = tcp_listeners.iter().map(|(port, _)| *port).collect();
82 let _ = ip_echo_server_request(
83 ip_echo_server_addr,
84 IpEchoServerMessage::new(&tcp_ports, &udp_ports),
85 )
86 .map_err(|err| warn!("ip_echo_server request failed: {}", err));
87
88 for (port, tcp_listener) in tcp_listeners {
90 let (sender, receiver) = channel();
91 std::thread::spawn(move || {
92 debug!("Waiting for incoming connection on tcp/{}", port);
93 let _ = tcp_listener.incoming().next().expect("tcp incoming failed");
94 sender.send(()).expect("send failure");
95 });
96 receiver
97 .recv_timeout(Duration::from_secs(5))
98 .unwrap_or_else(|err| {
99 error!(
100 "Received no response at tcp/{}, check your port configuration: {}",
101 port, err
102 );
103 std::process::exit(1);
104 });
105 info!("tdp/{} is reachable", port);
106 }
107
108 for (port, udp_socket) in udp {
110 let (sender, receiver) = channel();
111 std::thread::spawn(move || {
112 let mut buf = [0; 1];
113 debug!("Waiting for incoming datagram on udp/{}", port);
114 let _ = udp_socket.recv(&mut buf).expect("udp recv failure");
115 sender.send(()).expect("send failure");
116 });
117 receiver
118 .recv_timeout(Duration::from_secs(5))
119 .unwrap_or_else(|err| {
120 error!(
121 "Received no response at udp/{}, check your port configuration: {}",
122 port, err
123 );
124 std::process::exit(1);
125 });
126 info!("udp/{} is reachable", port);
127 }
128}
129
/// Turns an optional command-line value into a socket address.
///
/// `optstr` may be a bare port number, which is combined with the IP of
/// `default_addr`, or a full `ip:port` socket address. If `optstr` is `None` or
/// matches neither form, `default_addr` is returned unchanged.
pub fn parse_port_or_addr(optstr: Option<&str>, default_addr: SocketAddr) -> SocketAddr {
    let text = match optstr {
        Some(text) => text,
        None => return default_addr,
    };

    // A bare port takes precedence over a full socket address.
    if let Ok(port) = text.parse::<u16>() {
        let mut with_port = default_addr;
        with_port.set_port(port);
        return with_port;
    }

    text.parse::<SocketAddr>().unwrap_or(default_addr)
}
145
146pub fn parse_port_range(port_range: &str) -> Option<PortRange> {
147 let ports: Vec<&str> = port_range.split('-').collect();
148 if ports.len() != 2 {
149 return None;
150 }
151
152 let start_port = ports[0].parse();
153 let end_port = ports[1].parse();
154
155 if start_port.is_err() || end_port.is_err() {
156 return None;
157 }
158 let start_port = start_port.unwrap();
159 let end_port = end_port.unwrap();
160 if end_port < start_port {
161 return None;
162 }
163 Some((start_port, end_port))
164}
165
/// Resolves `host` (a hostname or IP literal; the lookup uses port 0) and returns
/// the first IP address produced.
///
/// # Errors
///
/// Returns the resolver's error message. If resolution succeeds but yields no
/// addresses, returns `"Unable to resolve host: <host>"`.
pub fn parse_host(host: &str) -> Result<IpAddr, String> {
    let mut resolved = match (host, 0).to_socket_addrs() {
        Ok(addrs) => addrs,
        Err(err) => return Err(err.to_string()),
    };
    match resolved.next() {
        Some(socket_address) => Ok(socket_address.ip()),
        None => Err(format!("Unable to resolve host: {}", host)),
    }
}
178
/// Resolves a `host:port` string and returns the first socket address produced.
///
/// # Errors
///
/// Returns the resolver's error message (for example when the port is missing).
/// If resolution succeeds but yields no addresses, returns
/// `"Unable to resolve host: <host_port>"`.
pub fn parse_host_port(host_port: &str) -> Result<SocketAddr, String> {
    host_port
        .to_socket_addrs()
        .map_err(|err| err.to_string())?
        .next()
        .ok_or_else(|| format!("Unable to resolve host: {}", host_port))
}
190
191pub fn is_host_port(string: String) -> Result<(), String> {
192 parse_host_port(&string)?;
193 Ok(())
194}
195
/// Creates an unbound IPv4 UDP socket.
///
/// On Windows, no address or port reuse options are set, so `_reuseaddr` is ignored.
#[cfg(windows)]
fn udp_socket(_reuseaddr: bool) -> io::Result<Socket> {
    Socket::new(Domain::ipv4(), Type::dgram(), None)
}
201
202#[cfg(not(windows))]
203fn udp_socket(reuseaddr: bool) -> io::Result<Socket> {
204 use nix::sys::socket::setsockopt;
205 use nix::sys::socket::sockopt::{ReuseAddr, ReusePort};
206 use std::os::unix::io::AsRawFd;
207
208 let sock = Socket::new(Domain::ipv4(), Type::dgram(), None)?;
209 let sock_fd = sock.as_raw_fd();
210
211 if reuseaddr {
212 setsockopt(sock_fd, ReusePort, &true).ok();
214 setsockopt(sock_fd, ReuseAddr, &true).ok();
215 }
216
217 Ok(sock)
218}
219
220pub fn bind_common_in_range(range: PortRange) -> io::Result<(u16, (UdpSocket, TcpListener))> {
222 let (start, end) = range;
223 let mut tries_left = end - start;
224 let mut rand_port = thread_rng().gen_range(start, end);
225 loop {
226 match bind_common(rand_port, false) {
227 Ok((sock, listener)) => {
228 break Result::Ok((sock.local_addr().unwrap().port(), (sock, listener)));
229 }
230 Err(err) => {
231 if tries_left == 0 {
232 return Err(err);
233 }
234 }
235 }
236 rand_port += 1;
237 if rand_port == end {
238 rand_port = start;
239 }
240 tries_left -= 1;
241 }
242}
243
244pub fn bind_in_range(range: PortRange) -> io::Result<(u16, UdpSocket)> {
245 let sock = udp_socket(false)?;
246
247 let (start, end) = range;
248 let mut tries_left = end - start;
249 let mut rand_port = thread_rng().gen_range(start, end);
250 loop {
251 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), rand_port);
252
253 match sock.bind(&SockAddr::from(addr)) {
254 Ok(_) => {
255 let sock = sock.into_udp_socket();
256 break Result::Ok((sock.local_addr().unwrap().port(), sock));
257 }
258 Err(err) => {
259 if tries_left == 0 {
260 return Err(err);
261 }
262 }
263 }
264 rand_port += 1;
265 if rand_port == end {
266 rand_port = start;
267 }
268 tries_left -= 1;
269 }
270}
271
272pub fn multi_bind_in_range(range: PortRange, mut num: usize) -> io::Result<(u16, Vec<UdpSocket>)> {
274 if cfg!(windows) && num != 1 {
275 warn!(
277 "multi_bind_in_range() only supports 1 socket in windows ({} requested)",
278 num
279 );
280 num = 1;
281 }
282 let mut sockets = Vec::with_capacity(num);
283
284 let port = {
285 let (port, _) = bind_in_range(range)?;
286 port
287 }; for _ in 0..num {
290 sockets.push(bind_to(port, true)?);
291 }
292 Ok((port, sockets))
293}
294
295pub fn bind_to(port: u16, reuseaddr: bool) -> io::Result<UdpSocket> {
296 let sock = udp_socket(reuseaddr)?;
297
298 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
299
300 match sock.bind(&SockAddr::from(addr)) {
301 Ok(_) => Result::Ok(sock.into_udp_socket()),
302 Err(err) => Err(err),
303 }
304}
305
306pub fn bind_common(port: u16, reuseaddr: bool) -> io::Result<(UdpSocket, TcpListener)> {
308 let sock = udp_socket(reuseaddr)?;
309
310 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
311 let sock_addr = SockAddr::from(addr);
312 match sock.bind(&sock_addr) {
313 Ok(_) => match TcpListener::bind(&addr) {
314 Ok(listener) => Result::Ok((sock.into_udp_socket(), listener)),
315 Err(err) => Err(err),
316 },
317 Err(err) => Err(err),
318 }
319}
320
321pub fn find_available_port_in_range(range: PortRange) -> io::Result<u16> {
322 let (start, end) = range;
323 let mut tries_left = end - start;
324 let mut rand_port = thread_rng().gen_range(start, end);
325 loop {
326 match TcpListener::bind(SocketAddr::new(
327 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
328 rand_port,
329 )) {
330 Ok(_) => {
331 break Ok(rand_port);
332 }
333 Err(err) => {
334 if tries_left == 0 {
335 return Err(err);
336 }
337 }
338 }
339 rand_port += 1;
340 if rand_port == end {
341 rand_port = start;
342 }
343 tries_left -= 1;
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that bind real sockets each use their own port range. The test harness
    // runs tests in parallel, so overlapping ranges would make them race for ports.

    #[test]
    fn test_parse_port_or_addr() {
        let default_addr = SocketAddr::from(([1, 2, 3, 4], 1));

        let p1 = parse_port_or_addr(Some("9000"), default_addr);
        assert_eq!(p1.port(), 9000);
        assert_eq!(p1.ip(), default_addr.ip());

        let p2 = parse_port_or_addr(Some("127.0.0.1:7000"), default_addr);
        assert_eq!(p2.port(), 7000);

        let p3 = parse_port_or_addr(Some("hi there"), default_addr);
        assert_eq!(p3.port(), 1);

        let p4 = parse_port_or_addr(None, default_addr);
        assert_eq!(p4.port(), 1);
    }

    #[test]
    fn test_parse_port_range() {
        assert_eq!(parse_port_range("garbage"), None);
        assert_eq!(parse_port_range("1-"), None);
        assert_eq!(parse_port_range("1-2"), Some((1, 2)));
        assert_eq!(parse_port_range("1-2-3"), None);
        assert_eq!(parse_port_range("2-1"), None);
    }

    #[test]
    fn test_parse_host() {
        parse_host("localhost:1234").unwrap_err();
        parse_host("localhost").unwrap();
        parse_host("127.0.0.0:1234").unwrap_err();
        parse_host("127.0.0.0").unwrap();
    }

    #[test]
    fn test_parse_host_port() {
        parse_host_port("localhost:1234").unwrap();
        parse_host_port("localhost").unwrap_err();
        parse_host_port("127.0.0.0:1234").unwrap();
        parse_host_port("127.0.0.0").unwrap_err();
    }

    // Ports 2000-2110.
    #[test]
    fn test_bind() {
        assert_eq!(bind_in_range((2000, 2001)).unwrap().0, 2000);
        let x = bind_to(2002, true).unwrap();
        let y = bind_to(2002, true).unwrap();
        assert_eq!(
            x.local_addr().unwrap().port(),
            y.local_addr().unwrap().port()
        );
        let (port, v) = multi_bind_in_range((2010, 2110), 10).unwrap();
        for sock in &v {
            assert_eq!(port, sock.local_addr().unwrap().port());
        }
    }

    // An empty range is rejected before any port is bound.
    #[test]
    #[should_panic]
    fn test_bind_in_range_nil() {
        let _ = bind_in_range((2000, 2000));
    }

    // Ports 3000-3050.
    #[test]
    fn test_find_available_port_in_range() {
        assert_eq!(find_available_port_in_range((3000, 3001)).unwrap(), 3000);
        let port = find_available_port_in_range((3000, 3050)).unwrap();
        assert!(3000 <= port && port < 3050);
    }

    // Ports 3100-3150: kept disjoint from test_find_available_port_in_range, which
    // asserts that port 3000 specifically is free.
    #[test]
    fn test_bind_common_in_range() {
        let (port, _) = bind_common_in_range((3100, 3150)).unwrap();
        assert!(3100 <= port && port < 3150);
    }
}