Skip to main content

sozu_command_lib/
scm_socket.rs

1//! SCM_RIGHTS socket for FD passing between master and workers.
2//!
3//! Wraps a `SeqPacket` unix socket and uses the `nix` SCM_RIGHTS helpers
4//! to ship listener/accept FDs across the master ↔ worker boundary at
//! startup and across hot upgrades. The blocking-mode helpers
//! (`ScmSocket::new`, `set_blocking`) borrow an FD without taking ownership;
//! the listener teardown path (`Listeners::close`) intentionally takes
//! ownership through `TcpListener::from_raw_fd` (`UdpSocket::from_raw_fd` for
//! UDP listeners) so the FD is closed by drop.
9
10use std::{
11    io::{IoSlice, IoSliceMut},
12    net::{AddrParseError, SocketAddr},
13    os::unix::{
14        io::{FromRawFd, IntoRawFd, RawFd},
15        net::UnixStream as StdUnixStream,
16    },
17};
18
19use mio::net::{TcpListener, UdpSocket};
20use nix::{cmsg_space, sys::socket};
21use prost::{DecodeError, Message};
22
23use crate::proto::command::ListenersCount;
24
/// Maximum number of file descriptors a single `receive_listeners` call can
/// accept: it sizes both the SCM control-message buffer (`cmsg_space!`) and
/// the fixed `[RawFd; MAX_FDS_OUT]` destination array.
pub const MAX_FDS_OUT: usize = 200;
/// Size in bytes of the buffer `receive_listeners` reads the encoded
/// listeners manifest into.
pub const MAX_BYTES_OUT: usize = 4096;
27
/// Errors raised while configuring the SCM channel or exchanging listeners
/// over it.
#[derive(thiserror::Error, Debug)]
pub enum ScmSocketError {
    /// Switching the unix socket between blocking and non-blocking mode failed.
    #[error("could not set the blocking status of the unix stream to {blocking}: {error}")]
    SetBlocking {
        /// Blocking mode reported for the failed attempt.
        blocking: bool,
        /// Underlying OS error from `set_nonblocking`.
        error: std::io::Error,
    },
    /// Sending a message over the SCM channel failed; the string says why.
    #[error("could not send message for SCM socket: {0}")]
    Send(String),
    /// Receiving a message, or reading its control messages, failed.
    #[error("could not receive message for SCM socket: {0}")]
    Receive(String),
    /// Not constructed in this file; presumably raised by callers elsewhere
    /// in the crate — verify before relying on its meaning.
    #[error("invalid char set: {0}")]
    InvalidCharSet(String),
    /// Not constructed in this file; presumably raised by callers elsewhere
    /// in the crate — verify before relying on its meaning.
    #[error("Could not deserialize utf8 string into listeners: {0}")]
    ListenerParse(String),
    /// A manifest entry could not be parsed as a `SocketAddr`.
    #[error("Wrong socket address {address}: {error}")]
    WrongSocketAddress {
        /// The offending manifest entry, verbatim.
        address: String,
        error: AddrParseError,
    },
    /// The length-delimited protobuf `ListenersCount` manifest could not be
    /// decoded.
    #[error("error decoding the protobuf format of the listeners: {0}")]
    DecodeError(DecodeError),
    /// The manifest declares more listeners than the FDs that actually
    /// arrived, or more than `MAX_FDS_OUT`.
    #[error(
        "listeners count manifest is inconsistent with the SCM payload: \
         http={http}, tls={tls}, tcp={tcp} (sum={total}), fds_received={fds_received}, max_fds={max_fds}"
    )]
    ListenersCountInconsistent {
        http: usize,
        tls: usize,
        /// TCP and UDP entry counts folded together (saturating add).
        tcp: usize,
        /// Sum of all per-protocol counts; `usize::MAX` when that sum
        /// overflowed.
        total: usize,
        /// Number of FDs actually received over SCM_RIGHTS.
        fds_received: usize,
        max_fds: usize,
    },
}
63
/// A unix socket specialized for file descriptor passing
///
/// Only the raw descriptor is stored. The type derives `Clone`, so clones
/// share the same descriptor; no `Drop` impl closing it is visible here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ScmSocket {
    /// Raw descriptor of the underlying unix socket.
    pub fd: RawFd,
    /// Whether the socket is in blocking mode. When `false`, sends and
    /// receives pass `MSG_DONTWAIT`.
    pub blocking: bool,
}
70
71impl ScmSocket {
72    /// Create a blocking SCM socket from a raw file descriptor (unsafe)
73    pub fn new(fd: RawFd) -> Result<Self, ScmSocketError> {
74        // SAFETY: `fd` is borrowed for the duration of this block. We wrap it
75        // in a `StdUnixStream` to call `set_nonblocking`, then immediately
76        // release ownership again with `into_raw_fd` so the descriptor is
77        // not closed by `Drop`. The caller retains ownership of `fd`.
78        unsafe {
79            let stream = StdUnixStream::from_raw_fd(fd);
80            stream
81                .set_nonblocking(false)
82                .map_err(|error| ScmSocketError::SetBlocking {
83                    blocking: false,
84                    error,
85                })?;
86            let _dropped_fd = stream.into_raw_fd();
87        }
88
89        Ok(ScmSocket { fd, blocking: true })
90    }
91
92    /// Get the raw file descriptor of the scm channel
93    pub fn raw_fd(&self) -> i32 {
94        self.fd
95    }
96
97    /// Use the standard library (unsafe) to set the socket to blocking / unblocking
98    pub fn set_blocking(&mut self, blocking: bool) -> Result<(), ScmSocketError> {
99        if self.blocking == blocking {
100            return Ok(());
101        }
102        // Past the idempotent early-return the flag must actually be flipping.
103        debug_assert_ne!(
104            self.blocking, blocking,
105            "set_blocking only reaches the syscall when the state actually changes"
106        );
107        let blocking_before = self.blocking;
108        // SAFETY: `self.fd` is borrowed for the duration of this block. We wrap
109        // it in a `StdUnixStream` to call `set_nonblocking`, then immediately
110        // release ownership with `into_raw_fd` so the descriptor is not closed
111        // by `Drop`. `ScmSocket` retains the original ownership.
112        unsafe {
113            let stream = StdUnixStream::from_raw_fd(self.fd);
114            stream
115                .set_nonblocking(!blocking)
116                .map_err(|error| ScmSocketError::SetBlocking { blocking, error })?;
117            let _dropped_fd = stream.into_raw_fd();
118        }
119        self.blocking = blocking;
120        // The flag landed on the requested value and genuinely toggled.
121        debug_assert_eq!(
122            self.blocking, blocking,
123            "blocking flag must reflect the requested state after a successful set"
124        );
125        debug_assert_ne!(
126            self.blocking, blocking_before,
127            "blocking flag must have toggled across a real state change"
128        );
129        Ok(())
130    }
131
132    /// Send listeners (socket addresses and file descriptors) via an scm socket
133    pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
134        let listeners_count = ListenersCount {
135            http: listeners.http.iter().map(|t| t.0.to_string()).collect(),
136            tls: listeners.tls.iter().map(|t| t.0.to_string()).collect(),
137            tcp: listeners.tcp.iter().map(|t| t.0.to_string()).collect(),
138            udp: listeners.udp.iter().map(|t| t.0.to_string()).collect(),
139        };
140
141        // The manifest is built 1:1 from the listener tables; each address slot
142        // ships exactly one FD, so the per-protocol counts must agree on both
143        // sides of the wire. The receiver reconstructs (address, fd) pairs by
144        // zipping these counts against the FD array — drift here is the exact
145        // bug class `command_channel_security_tests` guards.
146        debug_assert_eq!(
147            listeners_count.http.len(),
148            listeners.http.len(),
149            "http manifest count must match the http listener table"
150        );
151        debug_assert_eq!(
152            listeners_count.tls.len(),
153            listeners.tls.len(),
154            "tls manifest count must match the tls listener table"
155        );
156        debug_assert_eq!(
157            listeners_count.tcp.len(),
158            listeners.tcp.len(),
159            "tcp manifest count must match the tcp listener table"
160        );
161
162        let message = listeners_count.encode_length_delimited_to_vec();
163
164        let mut file_descriptors: Vec<RawFd> = Vec::new();
165
166        file_descriptors.extend(listeners.http.iter().map(|t| t.1));
167        file_descriptors.extend(listeners.tls.iter().map(|t| t.1));
168        file_descriptors.extend(listeners.tcp.iter().map(|t| t.1));
169        file_descriptors.extend(listeners.udp.iter().map(|t| t.1));
170
171        // The FD vector must reconcile with the address totals: one descriptor
172        // per listener, folded http+tls+tcp+udp. If these disagree the receiver
173        // would zip mismatched (address, fd) pairs.
174        let address_total =
175            listeners.http.len() + listeners.tls.len() + listeners.tcp.len() + listeners.udp.len();
176        debug_assert_eq!(
177            file_descriptors.len(),
178            address_total,
179            "the FD count sent must equal the total listener-address count (one FD per address)"
180        );
181
182        self.send_msg_and_fds(&message, &file_descriptors)
183    }
184
185    /// Receive and parse listeners (socket addresses and file descriptors) via an scm socket
186    pub fn receive_listeners(&self) -> Result<Listeners, ScmSocketError> {
187        let mut buf = vec![0; MAX_BYTES_OUT];
188
189        let mut received_fds: [RawFd; MAX_FDS_OUT] = [0; MAX_FDS_OUT];
190
191        let (size, file_descriptor_length) =
192            self.receive_msg_and_fds(&mut buf, &mut received_fds)?;
193
194        debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));
195
196        let listeners_count = ListenersCount::decode_length_delimited(&buf[..size])
197            .map_err(ScmSocketError::DecodeError)?;
198
199        // Validate the manifest before indexing into the fixed-size FD array.
200        // The peer-controlled `listeners_count.{http,tls,tcp,udp}` lists are
201        // matched 1:1 with `received_fds` slots; without these bounds checks
202        // a peer that declared more entries than MAX_FDS_OUT or more entries
203        // than FDs actually arrived would panic the worker on
204        // `received_fds[index..index + len]`. `udp` is folded into the total
205        // and the inconsistency-error `tcp` slot for diagnostics rather than
206        // widening the error variant.
207        let http_len = listeners_count.http.len();
208        let tls_len = listeners_count.tls.len();
209        let tcp_len = listeners_count.tcp.len();
210        let udp_len = listeners_count.udp.len();
211        let total = http_len
212            .checked_add(tls_len)
213            .and_then(|s| s.checked_add(tcp_len))
214            .and_then(|s| s.checked_add(udp_len))
215            .ok_or(ScmSocketError::ListenersCountInconsistent {
216                http: http_len,
217                tls: tls_len,
218                tcp: tcp_len.saturating_add(udp_len),
219                total: usize::MAX,
220                fds_received: file_descriptor_length,
221                max_fds: MAX_FDS_OUT,
222            })?;
223        if total > MAX_FDS_OUT || total > file_descriptor_length {
224            return Err(ScmSocketError::ListenersCountInconsistent {
225                http: http_len,
226                tls: tls_len,
227                tcp: tcp_len.saturating_add(udp_len),
228                total,
229                fds_received: file_descriptor_length,
230                max_fds: MAX_FDS_OUT,
231            });
232        }
233
234        // Past the consistency guard, the folded total reconciles with both the
235        // fixed FD-array bound and the number of FDs that actually arrived.
236        // These are the invariants that keep every `received_fds[index..index+len]`
237        // slice below in bounds — a malformed manifest already returned an error
238        // and never reaches here.
239        debug_assert_eq!(
240            total,
241            http_len + tls_len + tcp_len + udp_len,
242            "folded total must equal the sum of per-protocol counts"
243        );
244        debug_assert!(
245            total <= MAX_FDS_OUT,
246            "total FD slots must fit the fixed-size received_fds array before indexing"
247        );
248        debug_assert!(
249            total <= file_descriptor_length,
250            "manifest total must not exceed the FDs actually received"
251        );
252        debug_assert!(
253            total <= received_fds.len(),
254            "every (address, fd) zip below must stay within the received_fds array"
255        );
256
257        let mut http_addresses = parse_addresses(&listeners_count.http)?;
258        let mut tls_addresses = parse_addresses(&listeners_count.tls)?;
259        let mut tcp_addresses = parse_addresses(&listeners_count.tcp)?;
260        let mut udp_addresses = parse_addresses(&listeners_count.udp)?;
261
262        // Each parsed address list maps 1:1 onto a contiguous FD slice; the
263        // counts must survive `parse_addresses` unchanged.
264        debug_assert_eq!(
265            http_addresses.len(),
266            http_len,
267            "parsed http address count must match the manifest count"
268        );
269        debug_assert_eq!(
270            tls_addresses.len(),
271            tls_len,
272            "parsed tls address count must match the manifest count"
273        );
274        debug_assert_eq!(
275            tcp_addresses.len(),
276            tcp_len,
277            "parsed tcp address count must match the manifest count"
278        );
279
280        let mut index = 0;
281        let len = http_len;
282        // Each FD slice end must stay within the validated total (and thus the
283        // array); pair-assert the slice window before every zip.
284        debug_assert!(
285            index + len <= total,
286            "http FD slice must lie within the reconciled total"
287        );
288        let mut http = Vec::new();
289        http.extend(
290            http_addresses
291                .drain(..)
292                .zip(received_fds[index..index + len].iter().cloned()),
293        );
294        // Each address was wrapped with exactly one FD.
295        debug_assert_eq!(
296            http.len(),
297            http_len,
298            "every http address must be paired with exactly one FD"
299        );
300
301        index += len;
302        let len = tls_len;
303        debug_assert!(
304            index + len <= total,
305            "tls FD slice must lie within the reconciled total"
306        );
307        let mut tls = Vec::new();
308        tls.extend(
309            tls_addresses
310                .drain(..)
311                .zip(received_fds[index..index + len].iter().cloned()),
312        );
313        debug_assert_eq!(
314            tls.len(),
315            tls_len,
316            "every tls address must be paired with exactly one FD"
317        );
318
319        index += len;
320        let len = tcp_len;
321        debug_assert!(
322            index + len <= total,
323            "tcp FD slice must lie within the reconciled total"
324        );
325        let mut tcp = Vec::new();
326        tcp.extend(
327            tcp_addresses
328                .drain(..)
329                .zip(received_fds[index..index + len].iter().cloned()),
330        );
331        debug_assert_eq!(
332            tcp.len(),
333            tcp_len,
334            "every tcp address must be paired with exactly one FD"
335        );
336
337        index += len;
338        let len = udp_len;
339        let mut udp = Vec::new();
340        udp.extend(
341            udp_addresses
342                .drain(..)
343                .zip(received_fds[index..index + len].iter().cloned()),
344        );
345        debug_assert_eq!(
346            udp.len(),
347            udp_len,
348            "every udp address must be paired with exactly one FD"
349        );
350
351        // The reconstructed tables consume every FD slot the manifest declared:
352        // the final cursor lands exactly on the folded total (http+tls+tcp+udp).
353        debug_assert_eq!(
354            index + len,
355            total,
356            "the (address, fd) zips must consume exactly the reconciled total of FD slots"
357        );
358        debug_assert_eq!(
359            http.len() + tls.len() + tcp.len() + udp.len(),
360            total,
361            "reconstructed listener count must equal the reconciled FD total"
362        );
363
364        Ok(Listeners {
365            http,
366            tls,
367            tcp,
368            udp,
369        })
370    }
371
372    /// Sends message and file descriptors separately. The file descriptors are summed up
373    /// in a ControlMessage.
374    fn send_msg_and_fds(&self, message: &[u8], fds: &[RawFd]) -> Result<(), ScmSocketError> {
375        let iov = [IoSlice::new(message)];
376        let flags = if self.blocking {
377            socket::MsgFlags::empty()
378        } else {
379            socket::MsgFlags::MSG_DONTWAIT
380        };
381
382        if fds.is_empty() {
383            debug!("{} send empty", self.fd);
384            socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)
385                .map_err(|error| ScmSocketError::Send(error.to_string()))?;
386            return Ok(());
387        };
388
389        let control_message = [socket::ControlMessage::ScmRights(fds)];
390        debug!("{} send with data", self.fd);
391        socket::sendmsg::<()>(self.fd, &iov, &control_message, flags, None)
392            .map_err(|error| ScmSocketError::Send(error.to_string()))?;
393        Ok(())
394    }
395
396    /// Parse the message and receives file descriptors separately via the ControlMessage
397    fn receive_msg_and_fds(
398        &self,
399        message: &mut [u8],
400        fds: &mut [RawFd],
401    ) -> Result<(usize, usize), ScmSocketError> {
402        // Snapshot the buffer length before `message` is borrowed mutably by
403        // `iov`; the received byte count is asserted against it below.
404        let message_capacity = message.len();
405        let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
406        let mut iov = [IoSliceMut::new(message)];
407
408        let flags = if self.blocking {
409            socket::MsgFlags::empty()
410        } else {
411            socket::MsgFlags::MSG_DONTWAIT
412        };
413
414        let msg = socket::recvmsg::<()>(self.fd, &mut iov[..], Some(&mut cmsg), flags)
415            .map_err(|error| ScmSocketError::Receive(error.to_string()))?;
416
417        // The destination slice is the receiver's fixed `[RawFd; MAX_FDS_OUT]`
418        // array; the zip below cannot write past it.
419        let fds_capacity = fds.len();
420        debug_assert!(
421            fds_capacity <= MAX_FDS_OUT,
422            "destination FD slice must not exceed the MAX_FDS_OUT cmsg space"
423        );
424        let mut fd_count = 0;
425        let received_fds = msg
426            .cmsgs()
427            .map_err(|error| ScmSocketError::Receive(error.to_string()))?
428            .filter_map(|cmsg| {
429                if let socket::ControlMessageOwned::ScmRights(s) = cmsg {
430                    Some(s)
431                } else {
432                    None
433                }
434            })
435            .flatten();
436        for (fd, place) in received_fds.zip(fds.iter_mut()) {
437            fd_count += 1;
438            *place = fd;
439            // The zip is bounded by `fds.iter_mut()`, so each wrap stays within
440            // the destination array — never write past `fds_capacity`.
441            debug_assert!(
442                fd_count <= fds_capacity,
443                "received FD count must never exceed the destination array capacity"
444            );
445        }
446        // Post-condition: the reported count reconciles with the destination
447        // bound and the byte count is within the message buffer we handed in.
448        debug_assert!(
449            fd_count <= fds_capacity,
450            "final received FD count must fit the destination array"
451        );
452        debug_assert!(
453            msg.bytes <= message_capacity,
454            "received byte count must not exceed the message buffer it was read into"
455        );
456        Ok((msg.bytes, fd_count))
457    }
458}
459
/// Socket addresses and file descriptors of listening sockets, needed by a
/// Proxy to start listening. The transport is fd-type-agnostic: `udp` carries
/// `UdpSocket` fds, the others carry `TcpListener` fds.
///
/// Each table is a list of `(address, fd)` pairs. `ScmSocket::send_listeners`
/// ships the FDs in field order: http, tls, tcp, udp.
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
pub struct Listeners {
    /// HTTP listeners.
    pub http: Vec<(SocketAddr, RawFd)>,
    /// TLS (HTTPS) listeners; `get_https` takes entries from this table.
    pub tls: Vec<(SocketAddr, RawFd)>,
    /// TCP listeners.
    pub tcp: Vec<(SocketAddr, RawFd)>,
    /// UDP sockets. `#[serde(default)]` lets deserialization accept input
    /// without a `udp` field (yielding an empty list) — presumably so data
    /// serialized before UDP support still loads; confirm.
    #[serde(default)]
    pub udp: Vec<(SocketAddr, RawFd)>,
}
471
472impl Listeners {
473    pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
474        let before = self.http.len();
475        let pos = self.http.iter().position(|(front, _)| front == addr);
476        let result = pos.map(|pos| self.http.remove(pos).1);
477        // Exactly the matched entry is removed: the length drops by one iff a
478        // match was found, and the removed address is truly gone.
479        debug_assert_eq!(
480            self.http.len(),
481            before - result.is_some() as usize,
482            "http listener table shrinks by exactly one iff an address matched"
483        );
484        debug_assert!(
485            result.is_none() || !self.http.iter().any(|(front, _)| front == addr),
486            "the matched http address must no longer be present after removal"
487        );
488        result
489    }
490
491    pub fn get_https(&mut self, addr: &SocketAddr) -> Option<RawFd> {
492        let before = self.tls.len();
493        let pos = self.tls.iter().position(|(front, _)| front == addr);
494        let result = pos.map(|pos| self.tls.remove(pos).1);
495        debug_assert_eq!(
496            self.tls.len(),
497            before - result.is_some() as usize,
498            "tls listener table shrinks by exactly one iff an address matched"
499        );
500        debug_assert!(
501            result.is_none() || !self.tls.iter().any(|(front, _)| front == addr),
502            "the matched tls address must no longer be present after removal"
503        );
504        result
505    }
506
507    pub fn get_tcp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
508        let before = self.tcp.len();
509        let pos = self.tcp.iter().position(|(front, _)| front == addr);
510        let result = pos.map(|pos| self.tcp.remove(pos).1);
511        debug_assert_eq!(
512            self.tcp.len(),
513            before - result.is_some() as usize,
514            "tcp listener table shrinks by exactly one iff an address matched"
515        );
516        debug_assert!(
517            result.is_none() || !self.tcp.iter().any(|(front, _)| front == addr),
518            "the matched tcp address must no longer be present after removal"
519        );
520        result
521    }
522
523    pub fn get_udp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
524        self.udp
525            .iter()
526            .position(|(front, _)| front == addr)
527            .map(|pos| self.udp.remove(pos).1)
528    }
529
530    /// Deactivate all listeners by closing their file descriptors
531    pub fn close(&self) {
532        for (_, fd) in &self.http {
533            // SAFETY: `*fd` is owned by this `ScmListeners` table and is
534            // about to be closed by the binding's `Drop` (intentional
535            // close-by-drop). No other reference to the descriptor survives.
536            unsafe {
537                let _ = TcpListener::from_raw_fd(*fd);
538            }
539        }
540
541        for (_, fd) in &self.tls {
542            // SAFETY: `*fd` is owned by this `ScmListeners` table and is
543            // about to be closed by the binding's `Drop` (intentional
544            // close-by-drop). No other reference to the descriptor survives.
545            unsafe {
546                let _ = TcpListener::from_raw_fd(*fd);
547            }
548        }
549
550        for (_, fd) in &self.tcp {
551            // SAFETY: `*fd` is owned by this `ScmListeners` table and is
552            // about to be closed by the binding's `Drop` (intentional
553            // close-by-drop). No other reference to the descriptor survives.
554            unsafe {
555                let _ = TcpListener::from_raw_fd(*fd);
556            }
557        }
558
559        for (_, fd) in &self.udp {
560            // SAFETY: `*fd` is owned by this `ScmListeners` table and is
561            // about to be closed by the binding's `Drop` (intentional
562            // close-by-drop). No other reference to the descriptor survives.
563            // UDP listeners are `UdpSocket` fds, so take ownership through the
564            // matching wrapper.
565            unsafe {
566                let _ = UdpSocket::from_raw_fd(*fd);
567            }
568        }
569    }
570}
571
572fn parse_addresses(addresses: &[String]) -> Result<Vec<SocketAddr>, ScmSocketError> {
573    let mut parsed_addresses = Vec::new();
574    for address in addresses {
575        parsed_addresses.push(address.parse::<SocketAddr>().map_err(|error| {
576            ScmSocketError::WrongSocketAddress {
577                address: address.to_owned(),
578                error,
579            }
580        })?);
581    }
582    Ok(parsed_addresses)
583}
584
#[cfg(test)]
mod tests {

    use std::{net::SocketAddr, os::unix::prelude::AsRawFd, str::FromStr};

    use mio::net::UnixStream as MioUnixStream;

    use super::*;

    #[test]
    fn create_block_unblock_an_scm_socket() {
        let (nonblocking_stream, _) =
            MioUnixStream::pair().expect("Could not create a pair of unix streams");
        let raw_file_descriptor = nonblocking_stream.into_raw_fd();

        let scm_socket = ScmSocket::new(raw_file_descriptor);
        assert!(scm_socket.is_ok());

        let mut scm_socket = scm_socket.unwrap();
        assert!(scm_socket.blocking, "a freshly created scm socket is blocking");
        assert_eq!(scm_socket.raw_fd(), raw_file_descriptor);

        // Same mode: no-op, flag unchanged.
        assert!(scm_socket.set_blocking(true).is_ok());
        assert!(scm_socket.blocking);

        // Real toggles update the flag each way.
        assert!(scm_socket.set_blocking(false).is_ok());
        assert!(!scm_socket.blocking);
        assert!(scm_socket.set_blocking(true).is_ok());
        assert!(scm_socket.blocking);
    }

    fn socket_addr_from_str(str: &str) -> SocketAddr {
        SocketAddr::from_str(str)
            .unwrap_or_else(|_| panic!("failed to create socket address from string slice {str}"))
    }

    /// Addresses of a listener table, in order, without their FDs (FD numbers
    /// legitimately differ on the receiving side).
    fn addresses_of(table: &[(SocketAddr, RawFd)]) -> Vec<SocketAddr> {
        table.iter().map(|(address, _)| *address).collect()
    }

    #[test]
    fn send_and_receive_empty_listeners() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");

        let listeners = Listeners::default();

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        assert_eq!(listeners, received_listeners);
    }

    #[test]
    fn send_and_receive_socket_addresses() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        println!("unix stream pair: {stream_1:?} and {stream_2:?}");
        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");

        println!("sending socket: {sending_scm_socket:?}");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        println!("receiving socket: {receiving_scm_socket:?}");

        // We have to provide actual file descriptors, even if they will all be changed in the takeover
        let (http_socket1, http_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tcp_socket1, tcp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tls_socket1, tls_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (udp_socket1, udp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let listeners = Listeners {
            http: vec![
                (
                    socket_addr_from_str("127.0.1.1:8080"),
                    http_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.1.2:8080"),
                    http_socket2.as_raw_fd(),
                ),
            ],
            tcp: vec![
                (
                    socket_addr_from_str("127.0.2.1:8080"),
                    tcp_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.2.2:8080"),
                    tcp_socket2.as_raw_fd(),
                ),
            ],
            tls: vec![
                (
                    socket_addr_from_str("127.0.3.1:8443"),
                    tls_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.3.2:8443"),
                    tls_socket2.as_raw_fd(),
                ),
            ],
            udp: vec![
                (
                    socket_addr_from_str("127.0.4.1:5353"),
                    udp_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.4.2:5353"),
                    udp_socket2.as_raw_fd(),
                ),
            ],
        };

        println!("self.fd: {}", sending_scm_socket.fd);
        println!("listeners to send: {listeners:#?}");

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        // Every table must round-trip with all of its addresses, in order.
        assert_eq!(
            addresses_of(&listeners.http),
            addresses_of(&received_listeners.http)
        );
        assert_eq!(
            addresses_of(&listeners.tls),
            addresses_of(&received_listeners.tls)
        );
        assert_eq!(
            addresses_of(&listeners.tcp),
            addresses_of(&received_listeners.tcp)
        );
        assert_eq!(
            addresses_of(&listeners.udp),
            addresses_of(&received_listeners.udp)
        );

        // The received FDs are fresh duplicates installed by the kernel; release
        // them so the test does not leak descriptors.
        received_listeners.close();
    }

    #[test]
    fn parse_addresses_reports_the_offending_entry() {
        let parsed = parse_addresses(&["127.0.0.1:80".to_string(), "[::1]:443".to_string()])
            .expect("valid addresses must parse");
        assert_eq!(
            parsed,
            vec![
                socket_addr_from_str("127.0.0.1:80"),
                socket_addr_from_str("[::1]:443")
            ]
        );

        match parse_addresses(&["127.0.0.1:80".to_string(), "not-an-address".to_string()]) {
            Err(ScmSocketError::WrongSocketAddress { address, .. }) => {
                assert_eq!(address, "not-an-address");
            }
            other => panic!("expected WrongSocketAddress, got {other:?}"),
        }
    }

    /// Regression: a malformed `ListenersCount` whose entry counts do not
    /// match the number of file descriptors received over SCM must be
    /// rejected with `ListenersCountInconsistent`, never panic the worker
    /// on `received_fds[index..index + len]`.
    ///
    /// Without the bounds check, a peer that declares more addresses than
    /// `MAX_FDS_OUT` (or more than the FDs that actually arrived) crashes
    /// the receiving worker on out-of-bounds array indexing.
    #[test]
    fn rejects_listeners_count_with_more_entries_than_fds() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let sender = ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
        let receiver = ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        // Declare three HTTP entries but ship zero file descriptors.
        let bogus = ListenersCount {
            http: vec![
                "127.0.0.1:80".to_string(),
                "127.0.0.2:80".to_string(),
                "127.0.0.3:80".to_string(),
            ],
            tls: vec![],
            tcp: vec![],
            udp: vec![],
        };
        let payload = bogus.encode_length_delimited_to_vec();
        sender
            .send_msg_and_fds(&payload, &[])
            .expect("manual send_msg_and_fds with zero fds must succeed at the syscall layer");

        match receiver.receive_listeners() {
            Err(ScmSocketError::ListenersCountInconsistent {
                http,
                tls,
                tcp,
                total,
                fds_received,
                max_fds,
            }) => {
                assert_eq!(http, 3);
                assert_eq!(tls, 0);
                assert_eq!(tcp, 0);
                assert_eq!(total, 3);
                assert_eq!(fds_received, 0);
                assert_eq!(max_fds, MAX_FDS_OUT);
            }
            other => panic!(
                "expected ListenersCountInconsistent, got {other:?}\n\
                 NOTE: a panic / OOM here means the SCM bounds check was reverted",
            ),
        }
    }
}