solana_netutil/
lib.rs

1//! The `netutil` module assists with networking
2use log::*;
3use rand::{thread_rng, Rng};
4use socket2::{Domain, SockAddr, Socket, Type};
5use std::io::{self, Read, Write};
6use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
7use std::sync::mpsc::channel;
8use std::time::Duration;
9
10mod ip_echo_server;
11use ip_echo_server::IpEchoServerMessage;
12pub use ip_echo_server::{ip_echo_server, IpEchoServer};
13
/// A data type representing a public UDP socket.
#[derive(Debug)]
pub struct UdpSocketPair {
    /// Public address of the socket.
    pub addr: SocketAddr,
    /// Locally bound socket that can receive from the public address.
    pub receiver: UdpSocket,
    /// Locally bound socket to send via the public address.
    pub sender: UdpSocket,
}

/// A port range `(start, end)`: `start` is included, `end` is excluded by the `*_in_range`
/// helpers in this module.
pub type PortRange = (u16, u16);
22
23fn ip_echo_server_request(
24    ip_echo_server_addr: &SocketAddr,
25    msg: IpEchoServerMessage,
26) -> Result<IpAddr, String> {
27    let mut data = Vec::new();
28
29    let timeout = Duration::new(5, 0);
30    TcpStream::connect_timeout(ip_echo_server_addr, timeout)
31        .and_then(|mut stream| {
32            let msg = bincode::serialize(&msg).expect("serialize IpEchoServerMessage");
33            stream.write_all(&msg)?;
34            stream.shutdown(std::net::Shutdown::Write)?;
35            stream
36                .set_read_timeout(Some(Duration::new(10, 0)))
37                .expect("set_read_timeout");
38            stream.read_to_end(&mut data)
39        })
40        .and_then(|_| {
41            bincode::deserialize(&data).map_err(|err| {
42                io::Error::new(
43                    io::ErrorKind::Other,
44                    format!("Failed to deserialize: {:?}", err),
45                )
46            })
47        })
48        .map_err(|err| err.to_string())
49}
50
51/// Determine the public IP address of this machine by asking an ip_echo_server at the given
52/// address
53pub fn get_public_ip_addr(ip_echo_server_addr: &SocketAddr) -> Result<IpAddr, String> {
54    ip_echo_server_request(ip_echo_server_addr, IpEchoServerMessage::default())
55}
56
57// Aborts the process if any of the provided TCP/UDP ports are not reachable by the machine at
58// `ip_echo_server_addr`
59pub fn verify_reachable_ports(
60    ip_echo_server_addr: &SocketAddr,
61    tcp_listeners: Vec<(u16, TcpListener)>,
62    udp_sockets: &[&UdpSocket],
63) {
64    let udp: Vec<(_, _)> = udp_sockets
65        .iter()
66        .map(|udp_socket| {
67            (
68                udp_socket.local_addr().unwrap().port(),
69                udp_socket.try_clone().expect("Unable to clone udp socket"),
70            )
71        })
72        .collect();
73
74    let udp_ports: Vec<_> = udp.iter().map(|x| x.0).collect();
75
76    info!(
77        "Checking that tcp ports {:?} and udp ports {:?} are reachable from {:?}",
78        tcp_listeners, udp_ports, ip_echo_server_addr
79    );
80
81    let tcp_ports: Vec<_> = tcp_listeners.iter().map(|(port, _)| *port).collect();
82    let _ = ip_echo_server_request(
83        ip_echo_server_addr,
84        IpEchoServerMessage::new(&tcp_ports, &udp_ports),
85    )
86    .map_err(|err| warn!("ip_echo_server request failed: {}", err));
87
88    // Wait for a connection to open on each TCP port
89    for (port, tcp_listener) in tcp_listeners {
90        let (sender, receiver) = channel();
91        std::thread::spawn(move || {
92            debug!("Waiting for incoming connection on tcp/{}", port);
93            let _ = tcp_listener.incoming().next().expect("tcp incoming failed");
94            sender.send(()).expect("send failure");
95        });
96        receiver
97            .recv_timeout(Duration::from_secs(5))
98            .unwrap_or_else(|err| {
99                error!(
100                    "Received no response at tcp/{}, check your port configuration: {}",
101                    port, err
102                );
103                std::process::exit(1);
104            });
105        info!("tdp/{} is reachable", port);
106    }
107
108    // Wait for a datagram to arrive at each UDP port
109    for (port, udp_socket) in udp {
110        let (sender, receiver) = channel();
111        std::thread::spawn(move || {
112            let mut buf = [0; 1];
113            debug!("Waiting for incoming datagram on udp/{}", port);
114            let _ = udp_socket.recv(&mut buf).expect("udp recv failure");
115            sender.send(()).expect("send failure");
116        });
117        receiver
118            .recv_timeout(Duration::from_secs(5))
119            .unwrap_or_else(|err| {
120                error!(
121                    "Received no response at udp/{}, check your port configuration: {}",
122                    port, err
123                );
124                std::process::exit(1);
125            });
126        info!("udp/{} is reachable", port);
127    }
128}
129
/// Interprets `optstr` as either a bare port or a complete socket address.
///
/// * A bare port (e.g. `"9000"`) is applied to `default_addr`, replacing its port.
/// * A full address (e.g. `"127.0.0.1:7000"`) is returned as written.
/// * `None`, or text that is neither, yields `default_addr` unchanged.
pub fn parse_port_or_addr(optstr: Option<&str>, default_addr: SocketAddr) -> SocketAddr {
    let text = match optstr {
        Some(text) => text,
        None => return default_addr,
    };

    if let Ok(port) = text.parse::<u16>() {
        let mut with_port = default_addr;
        with_port.set_port(port);
        return with_port;
    }

    text.parse::<SocketAddr>().unwrap_or(default_addr)
}
145
146pub fn parse_port_range(port_range: &str) -> Option<PortRange> {
147    let ports: Vec<&str> = port_range.split('-').collect();
148    if ports.len() != 2 {
149        return None;
150    }
151
152    let start_port = ports[0].parse();
153    let end_port = ports[1].parse();
154
155    if start_port.is_err() || end_port.is_err() {
156        return None;
157    }
158    let start_port = start_port.unwrap();
159    let end_port = end_port.unwrap();
160    if end_port < start_port {
161        return None;
162    }
163    Some((start_port, end_port))
164}
165
/// Resolves `host` (a hostname or IP literal, without a port) and returns the first IP
/// address the resolver produces.
///
/// # Errors
/// Returns the resolver's error message. If resolution succeeds with no addresses, returns
/// `"Unable to resolve host: <host>"`.
pub fn parse_host(host: &str) -> Result<IpAddr, String> {
    let mut resolved = (host, 0)
        .to_socket_addrs()
        .map_err(|err| err.to_string())?;
    match resolved.next() {
        Some(socket_addr) => Ok(socket_addr.ip()),
        None => Err(format!("Unable to resolve host: {}", host)),
    }
}
178
/// Resolves `host_port` (e.g. `"localhost:8001"` or `"127.0.0.1:8001"`) and returns the
/// first socket address the resolver produces.
///
/// # Errors
/// Returns the resolver's error message (for example, when the port is missing). If
/// resolution succeeds with no addresses, returns `"Unable to resolve host: <host_port>"`.
pub fn parse_host_port(host_port: &str) -> Result<SocketAddr, String> {
    host_port
        .to_socket_addrs()
        .map_err(|err| err.to_string())?
        .next()
        .ok_or_else(|| format!("Unable to resolve host: {}", host_port))
}
190
191pub fn is_host_port(string: String) -> Result<(), String> {
192    parse_host_port(&string)?;
193    Ok(())
194}
195
#[cfg(windows)]
fn udp_socket(_reuseaddr: bool) -> io::Result<Socket> {
    // No reuse options are applied on Windows. The flag is accepted only so this signature
    // matches the non-Windows implementation.
    Socket::new(Domain::ipv4(), Type::dgram(), None)
}
201
202#[cfg(not(windows))]
203fn udp_socket(reuseaddr: bool) -> io::Result<Socket> {
204    use nix::sys::socket::setsockopt;
205    use nix::sys::socket::sockopt::{ReuseAddr, ReusePort};
206    use std::os::unix::io::AsRawFd;
207
208    let sock = Socket::new(Domain::ipv4(), Type::dgram(), None)?;
209    let sock_fd = sock.as_raw_fd();
210
211    if reuseaddr {
212        // best effort, i.e. ignore errors here, we'll get the failure in caller
213        setsockopt(sock_fd, ReusePort, &true).ok();
214        setsockopt(sock_fd, ReuseAddr, &true).ok();
215    }
216
217    Ok(sock)
218}
219
220// Find a port in the given range that is available for both TCP and UDP
221pub fn bind_common_in_range(range: PortRange) -> io::Result<(u16, (UdpSocket, TcpListener))> {
222    let (start, end) = range;
223    let mut tries_left = end - start;
224    let mut rand_port = thread_rng().gen_range(start, end);
225    loop {
226        match bind_common(rand_port, false) {
227            Ok((sock, listener)) => {
228                break Result::Ok((sock.local_addr().unwrap().port(), (sock, listener)));
229            }
230            Err(err) => {
231                if tries_left == 0 {
232                    return Err(err);
233                }
234            }
235        }
236        rand_port += 1;
237        if rand_port == end {
238            rand_port = start;
239        }
240        tries_left -= 1;
241    }
242}
243
244pub fn bind_in_range(range: PortRange) -> io::Result<(u16, UdpSocket)> {
245    let sock = udp_socket(false)?;
246
247    let (start, end) = range;
248    let mut tries_left = end - start;
249    let mut rand_port = thread_rng().gen_range(start, end);
250    loop {
251        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), rand_port);
252
253        match sock.bind(&SockAddr::from(addr)) {
254            Ok(_) => {
255                let sock = sock.into_udp_socket();
256                break Result::Ok((sock.local_addr().unwrap().port(), sock));
257            }
258            Err(err) => {
259                if tries_left == 0 {
260                    return Err(err);
261                }
262            }
263        }
264        rand_port += 1;
265        if rand_port == end {
266            rand_port = start;
267        }
268        tries_left -= 1;
269    }
270}
271
272// binds many sockets to the same port in a range
273pub fn multi_bind_in_range(range: PortRange, mut num: usize) -> io::Result<(u16, Vec<UdpSocket>)> {
274    if cfg!(windows) && num != 1 {
275        // See https://github.com/solana-labs/solana/issues/4607
276        warn!(
277            "multi_bind_in_range() only supports 1 socket in windows ({} requested)",
278            num
279        );
280        num = 1;
281    }
282    let mut sockets = Vec::with_capacity(num);
283
284    let port = {
285        let (port, _) = bind_in_range(range)?;
286        port
287    }; // drop the probe, port should be available... briefly.
288
289    for _ in 0..num {
290        sockets.push(bind_to(port, true)?);
291    }
292    Ok((port, sockets))
293}
294
295pub fn bind_to(port: u16, reuseaddr: bool) -> io::Result<UdpSocket> {
296    let sock = udp_socket(reuseaddr)?;
297
298    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
299
300    match sock.bind(&SockAddr::from(addr)) {
301        Ok(_) => Result::Ok(sock.into_udp_socket()),
302        Err(err) => Err(err),
303    }
304}
305
306// binds both a UdpSocket and a TcpListener
307pub fn bind_common(port: u16, reuseaddr: bool) -> io::Result<(UdpSocket, TcpListener)> {
308    let sock = udp_socket(reuseaddr)?;
309
310    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
311    let sock_addr = SockAddr::from(addr);
312    match sock.bind(&sock_addr) {
313        Ok(_) => match TcpListener::bind(&addr) {
314            Ok(listener) => Result::Ok((sock.into_udp_socket(), listener)),
315            Err(err) => Err(err),
316        },
317        Err(err) => Err(err),
318    }
319}
320
321pub fn find_available_port_in_range(range: PortRange) -> io::Result<u16> {
322    let (start, end) = range;
323    let mut tries_left = end - start;
324    let mut rand_port = thread_rng().gen_range(start, end);
325    loop {
326        match TcpListener::bind(SocketAddr::new(
327            IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
328            rand_port,
329        )) {
330            Ok(_) => {
331                break Ok(rand_port);
332            }
333            Err(err) => {
334                if tries_left == 0 {
335                    return Err(err);
336                }
337            }
338        }
339        rand_port += 1;
340        if rand_port == end {
341            rand_port = start;
342        }
343        tries_left -= 1;
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_port_or_addr() {
        let default_addr = SocketAddr::from(([1, 2, 3, 4], 1));
        let p1 = parse_port_or_addr(Some("9000"), default_addr);
        assert_eq!(p1.port(), 9000);
        let p2 = parse_port_or_addr(Some("127.0.0.1:7000"), default_addr);
        assert_eq!(p2.port(), 7000);
        let p3 = parse_port_or_addr(Some("hi there"), default_addr);
        assert_eq!(p3.port(), 1);
        let p4 = parse_port_or_addr(None, default_addr);
        assert_eq!(p4.port(), 1);
    }

    #[test]
    fn test_parse_port_range() {
        assert_eq!(parse_port_range("garbage"), None);
        assert_eq!(parse_port_range("1-"), None);
        assert_eq!(parse_port_range("1-2"), Some((1, 2)));
        assert_eq!(parse_port_range("1-2-3"), None);
        assert_eq!(parse_port_range("2-1"), None);
    }

    #[test]
    fn test_parse_host() {
        parse_host("localhost:1234").unwrap_err();
        parse_host("localhost").unwrap();
        parse_host("127.0.0.0:1234").unwrap_err();
        parse_host("127.0.0.0").unwrap();
    }

    #[test]
    fn test_parse_host_port() {
        parse_host_port("localhost:1234").unwrap();
        parse_host_port("localhost").unwrap_err();
        parse_host_port("127.0.0.0:1234").unwrap();
        parse_host_port("127.0.0.0").unwrap_err();
    }

    #[test]
    fn test_bind() {
        assert_eq!(bind_in_range((2000, 2001)).unwrap().0, 2000);
        let x = bind_to(2002, true).unwrap();
        let y = bind_to(2002, true).unwrap();
        assert_eq!(
            x.local_addr().unwrap().port(),
            y.local_addr().unwrap().port()
        );
        let (port, v) = multi_bind_in_range((2010, 2110), 10).unwrap();
        for sock in &v {
            assert_eq!(port, sock.local_addr().unwrap().port());
        }
    }

    #[test]
    #[should_panic]
    fn test_bind_in_range_nil() {
        let _ = bind_in_range((2000, 2000));
    }

    #[test]
    fn test_find_available_port_in_range() {
        assert_eq!(find_available_port_in_range((3000, 3001)).unwrap(), 3000);
        let port = find_available_port_in_range((3000, 3050)).unwrap();
        assert!(3000 <= port && port < 3050);
    }

    #[test]
    fn test_bind_common_in_range() {
        // Tests run in parallel: keep this range disjoint from the one probed by
        // test_find_available_port_in_range, which requires port 3000 to be free.
        let (port, _) = bind_common_in_range((3100, 3150)).unwrap();
        assert!(3100 <= port && port < 3150);
    }
}