solana_streamer/
sendmmsg.rs

//! The `sendmmsg` module sends batches of UDP datagrams: on Linux via the `sendmmsg()` system call, elsewhere by calling `send_to` once per packet.
2
3#[cfg(target_os = "linux")]
4use {
5    crate::msghdr::create_msghdr,
6    itertools::izip,
7    libc::{iovec, mmsghdr, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t},
8    std::{
9        mem::{self, MaybeUninit},
10        os::unix::io::AsRawFd,
11        ptr,
12    },
13};
14use {
15    solana_transaction_error::TransportError,
16    std::{
17        borrow::Borrow,
18        io,
19        net::{SocketAddr, UdpSocket},
20    },
21    thiserror::Error,
22};
23
24#[derive(Debug, Error)]
25pub enum SendPktsError {
26    /// IO Error during send: first error, num failed packets
27    #[error("IO Error, some packets could not be sent")]
28    IoError(io::Error, usize),
29}
30
31impl From<SendPktsError> for TransportError {
32    fn from(err: SendPktsError) -> Self {
33        Self::Custom(format!("{err:?}"))
34    }
35}
36
// The type and lifetime constraints are overspecified to match 'linux' code.
/// Sends every `(payload, destination)` pair with one `send_to` call each.
///
/// Every packet is attempted even after a failure. On failure, returns
/// `SendPktsError::IoError` holding the first error seen and the number of
/// packets whose `send_to` failed.
#[cfg(not(target_os = "linux"))]
pub fn batch_send<'a, S, T: 'a + ?Sized>(
    sock: &UdpSocket,
    packets: impl IntoIterator<Item = (&'a T, S), IntoIter: ExactSizeIterator>,
) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    &'a T: AsRef<[u8]>,
{
    let (first_err, num_failed): (Option<io::Error>, usize) = packets
        .into_iter()
        .filter_map(|(pkt, dest)| {
            let addr: &SocketAddr = dest.borrow();
            sock.send_to(pkt.as_ref(), addr).err()
        })
        // Keep only the earliest error, but count every failure.
        .fold((None, 0), |(first, count), err| (first.or(Some(err)), count + 1));

    match first_err {
        Some(err) => Err(SendPktsError::IoError(err, num_failed)),
        None => Ok(()),
    }
}
64
65#[cfg(target_os = "linux")]
66fn mmsghdr_for_packet(
67    packet: &[u8],
68    dest: &SocketAddr,
69    iov: &mut MaybeUninit<iovec>,
70    addr: &mut MaybeUninit<sockaddr_storage>,
71    hdr: &mut MaybeUninit<mmsghdr>,
72) {
73    const SIZE_OF_SOCKADDR_IN: usize = mem::size_of::<sockaddr_in>();
74    const SIZE_OF_SOCKADDR_IN6: usize = mem::size_of::<sockaddr_in6>();
75    const SIZE_OF_SOCKADDR_STORAGE: usize = mem::size_of::<sockaddr_storage>();
76    const SOCKADDR_IN_PADDING: usize = SIZE_OF_SOCKADDR_STORAGE - SIZE_OF_SOCKADDR_IN;
77    const SOCKADDR_IN6_PADDING: usize = SIZE_OF_SOCKADDR_STORAGE - SIZE_OF_SOCKADDR_IN6;
78
79    iov.write(iovec {
80        iov_base: packet.as_ptr() as *mut libc::c_void,
81        iov_len: packet.len(),
82    });
83
84    let msg_namelen = match dest {
85        SocketAddr::V4(socket_addr_v4) => {
86            let ptr: *mut sockaddr_in = addr.as_mut_ptr() as *mut _;
87            unsafe {
88                ptr::write(
89                    ptr,
90                    *nix::sys::socket::SockaddrIn::from(*socket_addr_v4).as_ref(),
91                );
92                // Zero the remaining bytes after sockaddr_in
93                ptr::write_bytes(
94                    (ptr as *mut u8).add(SIZE_OF_SOCKADDR_IN),
95                    0,
96                    SOCKADDR_IN_PADDING,
97                );
98            }
99            SIZE_OF_SOCKADDR_IN as socklen_t
100        }
101        SocketAddr::V6(socket_addr_v6) => {
102            let ptr: *mut sockaddr_in6 = addr.as_mut_ptr() as *mut _;
103            unsafe {
104                ptr::write(
105                    ptr,
106                    *nix::sys::socket::SockaddrIn6::from(*socket_addr_v6).as_ref(),
107                );
108                // Zero the remaining bytes after sockaddr_in6
109                ptr::write_bytes(
110                    (ptr as *mut u8).add(SIZE_OF_SOCKADDR_IN6),
111                    0,
112                    SOCKADDR_IN6_PADDING,
113                );
114            }
115            SIZE_OF_SOCKADDR_IN6 as socklen_t
116        }
117    };
118
119    let msg_hdr = create_msghdr(addr, msg_namelen, iov);
120
121    hdr.write(mmsghdr {
122        msg_len: 0,
123        msg_hdr,
124    });
125}
126
127#[cfg(target_os = "linux")]
128fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPktsError> {
129    let sock_fd = sock.as_raw_fd();
130    let mut total_sent = 0;
131    let mut erropt = None;
132
133    let mut pkts = &mut *hdrs;
134    while !pkts.is_empty() {
135        let npkts = match unsafe { libc::sendmmsg(sock_fd, &mut pkts[0], pkts.len() as u32, 0) } {
136            -1 => {
137                if erropt.is_none() {
138                    erropt = Some(io::Error::last_os_error());
139                }
140                // skip over the failing packet
141                1_usize
142            }
143            n => {
144                // if we fail to send all packets we advance to the failing
145                // packet and retry in order to capture the error code
146                total_sent += n as usize;
147                n as usize
148            }
149        };
150        pkts = &mut pkts[npkts..];
151    }
152
153    if let Some(err) = erropt {
154        Err(SendPktsError::IoError(err, hdrs.len() - total_sent))
155    } else {
156        Ok(())
157    }
158}
159
160#[cfg(target_os = "linux")]
161const MAX_IOV: usize = libc::UIO_MAXIOV as usize;
162
163#[cfg(target_os = "linux")]
164fn batch_send_max_iov<'a, S, T: 'a + ?Sized>(
165    sock: &UdpSocket,
166    packets: impl IntoIterator<Item = (&'a T, S), IntoIter: ExactSizeIterator>,
167) -> Result<(), SendPktsError>
168where
169    S: Borrow<SocketAddr>,
170    &'a T: AsRef<[u8]>,
171{
172    let packets = packets.into_iter();
173    let num_packets = packets.len();
174    debug_assert!(num_packets <= MAX_IOV);
175
176    let mut iovs = [MaybeUninit::uninit(); MAX_IOV];
177    let mut addrs = [MaybeUninit::uninit(); MAX_IOV];
178    let mut hdrs = [MaybeUninit::uninit(); MAX_IOV];
179
180    // izip! will iterate packets.len() times, leaving hdrs, iovs, and addrs initialized only up to packets.len()
181    for ((pkt, dest), hdr, iov, addr) in izip!(packets, &mut hdrs, &mut iovs, &mut addrs) {
182        mmsghdr_for_packet(pkt.as_ref(), dest.borrow(), iov, addr, hdr);
183    }
184
185    // SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
186    // guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
187    let hdrs_slice =
188        unsafe { std::slice::from_raw_parts_mut(hdrs.as_mut_ptr() as *mut mmsghdr, num_packets) };
189
190    let result = sendmmsg_retry(sock, hdrs_slice);
191
192    // SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
193    // guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
194    for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(num_packets) {
195        unsafe {
196            hdr.assume_init_drop();
197            iov.assume_init_drop();
198            addr.assume_init_drop();
199        }
200    }
201
202    result
203}
204
205// Need &'a to ensure that raw packet pointers obtained in mmsghdr_for_packet
206// stay valid.
207#[cfg(target_os = "linux")]
208pub fn batch_send<'a, S, T: 'a + ?Sized>(
209    sock: &UdpSocket,
210    packets: impl IntoIterator<Item = (&'a T, S), IntoIter: ExactSizeIterator>,
211) -> Result<(), SendPktsError>
212where
213    S: Borrow<SocketAddr>,
214    &'a T: AsRef<[u8]>,
215{
216    let mut packets = packets.into_iter();
217    loop {
218        let chunk = packets.by_ref().take(MAX_IOV);
219        if chunk.len() == 0 {
220            break;
221        }
222        batch_send_max_iov(sock, chunk)?;
223    }
224    Ok(())
225}
226
227pub fn multi_target_send<S, T>(
228    sock: &UdpSocket,
229    packet: T,
230    dests: &[S],
231) -> Result<(), SendPktsError>
232where
233    S: Borrow<SocketAddr>,
234    T: AsRef<[u8]>,
235{
236    let dests = dests.iter().map(Borrow::borrow);
237    let pkts = dests.map(|addr| (&packet, addr));
238    batch_send(sock, pkts)
239}
240
#[cfg(test)]
mod tests {
    use {
        crate::{
            packet::Packet,
            recvmmsg::recv_mmsg,
            sendmmsg::{batch_send, multi_target_send, SendPktsError},
        },
        assert_matches::assert_matches,
        solana_net_utils::{bind_to_localhost, bind_to_unspecified},
        solana_packet::PACKET_DATA_SIZE,
        std::{
            io::ErrorKind,
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
        },
    };

    /// Builds `count` zero-filled payloads of `PACKET_DATA_SIZE` bytes each.
    fn zeroed_payloads(count: usize) -> Vec<Vec<u8>> {
        (0..count).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect()
    }

    /// Receives into a 32-packet buffer and asserts exactly `expected` arrived.
    fn assert_received(reader: &UdpSocket, expected: usize) {
        let mut buf = vec![Packet::default(); 32];
        let received = recv_mmsg(reader, &mut buf[..]).unwrap();
        assert_eq!(expected, received);
    }

    /// Asserts `res` failed with `PermissionDenied` and `expected_failed` failures.
    fn assert_permission_denied(res: Result<(), SendPktsError>, expected_failed: usize) {
        match res {
            Ok(()) => panic!(),
            Err(SendPktsError::IoError(ioerror, num_failed)) => {
                assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied);
                assert_eq!(num_failed, expected_failed);
            }
        }
    }

    #[test]
    pub fn test_send_mmsg_one_dest() {
        let reader = bind_to_localhost().expect("bind");
        let addr = reader.local_addr().unwrap();
        let sender = bind_to_localhost().expect("bind");

        let payloads = zeroed_payloads(32);
        let batch: Vec<_> = payloads.iter().map(|p| (&p[..], &addr)).collect();
        assert_eq!(batch_send(&sender, batch).ok(), Some(()));

        assert_received(&reader, 32);
    }

    #[test]
    pub fn test_send_mmsg_multi_dest() {
        let reader = bind_to_localhost().expect("bind");
        let addr = reader.local_addr().unwrap();

        let reader2 = bind_to_localhost().expect("bind");
        let addr2 = reader2.local_addr().unwrap();

        let sender = bind_to_localhost().expect("bind");

        // First half goes to `reader`, second half to `reader2`.
        let payloads = zeroed_payloads(32);
        let batch: Vec<_> = payloads
            .iter()
            .enumerate()
            .map(|(i, p)| (&p[..], if i < 16 { &addr } else { &addr2 }))
            .collect();
        assert_eq!(batch_send(&sender, batch).ok(), Some(()));

        assert_received(&reader, 16);
        assert_received(&reader2, 16);
    }

    #[test]
    pub fn test_multicast_msg() {
        let readers: Vec<_> = (0..4).map(|_| bind_to_localhost().expect("bind")).collect();
        let addrs: Vec<_> = readers.iter().map(|r| r.local_addr().unwrap()).collect();
        let sender = bind_to_localhost().expect("bind");

        let packet = Packet::default();
        let dests: Vec<&SocketAddr> = addrs.iter().collect();
        let sent = multi_target_send(&sender, packet.data(..).unwrap(), &dests[..]).ok();
        assert_eq!(sent, Some(()));

        for reader in &readers {
            assert_received(reader, 1);
        }
    }

    #[test]
    fn test_intermediate_failures_mismatched_bind() {
        let payloads = zeroed_payloads(3);
        let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
        let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4];
        let packet_refs: Vec<_> = payloads
            .iter()
            .zip(&dest_refs)
            .map(|(p, &dest)| (&p[..], dest))
            .collect();

        let sender = bind_to_unspecified().expect("bind");
        let res = batch_send(&sender, packet_refs);
        assert_matches!(res, Err(SendPktsError::IoError(_, /*num_failed*/ 1)));
        let res = multi_target_send(&sender, &payloads[0], &dest_refs);
        assert_matches!(res, Err(SendPktsError::IoError(_, /*num_failed*/ 1)));
    }

    #[test]
    fn test_intermediate_failures_unreachable_address() {
        let payloads = zeroed_payloads(5);
        let local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
        let sender = bind_to_unspecified().expect("bind");

        // (destination per packet, expected number of failed packets)
        let cases = [
            // intermediate failures
            ([&local, &broadcast, &local, &broadcast, &local], 2),
            // leading and trailing failures
            ([&broadcast, &local, &broadcast, &local, &broadcast], 3),
            // consecutive intermediate failures
            ([&local, &local, &broadcast, &broadcast, &local], 2),
        ];

        // batch_send: one distinct payload per destination.
        for (dests, expected_failed) in &cases {
            let batch: Vec<_> = payloads
                .iter()
                .zip(dests)
                .map(|(p, &dest)| (&p[..], dest))
                .collect();
            assert_permission_denied(batch_send(&sender, batch), *expected_failed);
        }

        // multi_target_send: same payload to every destination.
        for (dests, expected_failed) in &cases[..2] {
            assert_permission_denied(
                multi_target_send(&sender, &payloads[0], dests),
                *expected_failed,
            );
        }
    }
}