Skip to main content

sozu_command_lib/
scm_socket.rs

1//! SCM_RIGHTS socket for FD passing between master and workers.
2//!
3//! Wraps a `SeqPacket` unix socket and uses the `nix` SCM_RIGHTS helpers
4//! to ship listener/accept FDs across the master ↔ worker boundary at
5//! startup and across hot upgrades. The borrowed-FD wrappers
6//! (`set_blocking`) hold an FD without taking ownership; the listener
7//! teardown paths intentionally take ownership through
8//! `TcpListener::from_raw_fd` so the FD is closed by drop.
9
10use std::{
11    io::{IoSlice, IoSliceMut},
12    net::{AddrParseError, SocketAddr},
13    os::unix::{
14        io::{FromRawFd, IntoRawFd, RawFd},
15        net::UnixStream as StdUnixStream,
16    },
17};
18
19use mio::net::TcpListener;
20use nix::{cmsg_space, sys::socket};
21use prost::{DecodeError, Message};
22
23use crate::proto::command::ListenersCount;
24
/// Maximum number of file descriptors handled per SCM message: it sizes both
/// the receive-side descriptor array and the ancillary-data buffer.
pub const MAX_FDS_OUT: usize = 200;
/// Size, in bytes, of the buffer that receives the serialized listener manifest.
pub const MAX_BYTES_OUT: usize = 4096;
27
28#[derive(thiserror::Error, Debug)]
29pub enum ScmSocketError {
30    #[error("could not set the blocking status of the unix stream to {blocking}: {error}")]
31    SetBlocking {
32        blocking: bool,
33        error: std::io::Error,
34    },
35    #[error("could not send message for SCM socket: {0}")]
36    Send(String),
37    #[error("could not receive message for SCM socket: {0}")]
38    Receive(String),
39    #[error("invalid char set: {0}")]
40    InvalidCharSet(String),
41    #[error("Could not deserialize utf8 string into listeners: {0}")]
42    ListenerParse(String),
43    #[error("Wrong socket address {address}: {error}")]
44    WrongSocketAddress {
45        address: String,
46        error: AddrParseError,
47    },
48    #[error("error decoding the protobuf format of the listeners: {0}")]
49    DecodeError(DecodeError),
50    #[error(
51        "listeners count manifest is inconsistent with the SCM payload: \
52         http={http}, tls={tls}, tcp={tcp} (sum={total}), fds_received={fds_received}, max_fds={max_fds}"
53    )]
54    ListenersCountInconsistent {
55        http: usize,
56        tls: usize,
57        tcp: usize,
58        total: usize,
59        fds_received: usize,
60        max_fds: usize,
61    },
62}
63
64/// A unix socket specialized for file descriptor passing
65#[derive(Clone, Debug, Serialize, Deserialize)]
66pub struct ScmSocket {
67    pub fd: RawFd,
68    pub blocking: bool,
69}
70
71impl ScmSocket {
72    /// Create a blocking SCM socket from a raw file descriptor (unsafe)
73    pub fn new(fd: RawFd) -> Result<Self, ScmSocketError> {
74        // SAFETY: `fd` is borrowed for the duration of this block. We wrap it
75        // in a `StdUnixStream` to call `set_nonblocking`, then immediately
76        // release ownership again with `into_raw_fd` so the descriptor is
77        // not closed by `Drop`. The caller retains ownership of `fd`.
78        unsafe {
79            let stream = StdUnixStream::from_raw_fd(fd);
80            stream
81                .set_nonblocking(false)
82                .map_err(|error| ScmSocketError::SetBlocking {
83                    blocking: false,
84                    error,
85                })?;
86            let _dropped_fd = stream.into_raw_fd();
87        }
88
89        Ok(ScmSocket { fd, blocking: true })
90    }
91
92    /// Get the raw file descriptor of the scm channel
93    pub fn raw_fd(&self) -> i32 {
94        self.fd
95    }
96
97    /// Use the standard library (unsafe) to set the socket to blocking / unblocking
98    pub fn set_blocking(&mut self, blocking: bool) -> Result<(), ScmSocketError> {
99        if self.blocking == blocking {
100            return Ok(());
101        }
102        // SAFETY: `self.fd` is borrowed for the duration of this block. We wrap
103        // it in a `StdUnixStream` to call `set_nonblocking`, then immediately
104        // release ownership with `into_raw_fd` so the descriptor is not closed
105        // by `Drop`. `ScmSocket` retains the original ownership.
106        unsafe {
107            let stream = StdUnixStream::from_raw_fd(self.fd);
108            stream
109                .set_nonblocking(!blocking)
110                .map_err(|error| ScmSocketError::SetBlocking { blocking, error })?;
111            let _dropped_fd = stream.into_raw_fd();
112        }
113        self.blocking = blocking;
114        Ok(())
115    }
116
117    /// Send listeners (socket addresses and file descriptors) via an scm socket
118    pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
119        let listeners_count = ListenersCount {
120            http: listeners.http.iter().map(|t| t.0.to_string()).collect(),
121            tls: listeners.tls.iter().map(|t| t.0.to_string()).collect(),
122            tcp: listeners.tcp.iter().map(|t| t.0.to_string()).collect(),
123        };
124
125        let message = listeners_count.encode_length_delimited_to_vec();
126
127        let mut file_descriptors: Vec<RawFd> = Vec::new();
128
129        file_descriptors.extend(listeners.http.iter().map(|t| t.1));
130        file_descriptors.extend(listeners.tls.iter().map(|t| t.1));
131        file_descriptors.extend(listeners.tcp.iter().map(|t| t.1));
132
133        self.send_msg_and_fds(&message, &file_descriptors)
134    }
135
136    /// Receive and parse listeners (socket addresses and file descriptors) via an scm socket
137    pub fn receive_listeners(&self) -> Result<Listeners, ScmSocketError> {
138        let mut buf = vec![0; MAX_BYTES_OUT];
139
140        let mut received_fds: [RawFd; MAX_FDS_OUT] = [0; MAX_FDS_OUT];
141
142        let (size, file_descriptor_length) =
143            self.receive_msg_and_fds(&mut buf, &mut received_fds)?;
144
145        debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));
146
147        let listeners_count = ListenersCount::decode_length_delimited(&buf[..size])
148            .map_err(ScmSocketError::DecodeError)?;
149
150        // Validate the manifest before indexing into the fixed-size FD array.
151        // The peer-controlled `listeners_count.{http,tls,tcp}` lists are
152        // matched 1:1 with `received_fds` slots; without these bounds checks
153        // a peer that declared more entries than MAX_FDS_OUT or more entries
154        // than FDs actually arrived would panic the worker on
155        // `received_fds[index..index + len]`.
156        let http_len = listeners_count.http.len();
157        let tls_len = listeners_count.tls.len();
158        let tcp_len = listeners_count.tcp.len();
159        let total = http_len
160            .checked_add(tls_len)
161            .and_then(|s| s.checked_add(tcp_len))
162            .ok_or(ScmSocketError::ListenersCountInconsistent {
163                http: http_len,
164                tls: tls_len,
165                tcp: tcp_len,
166                total: usize::MAX,
167                fds_received: file_descriptor_length,
168                max_fds: MAX_FDS_OUT,
169            })?;
170        if total > MAX_FDS_OUT || total > file_descriptor_length {
171            return Err(ScmSocketError::ListenersCountInconsistent {
172                http: http_len,
173                tls: tls_len,
174                tcp: tcp_len,
175                total,
176                fds_received: file_descriptor_length,
177                max_fds: MAX_FDS_OUT,
178            });
179        }
180
181        let mut http_addresses = parse_addresses(&listeners_count.http)?;
182        let mut tls_addresses = parse_addresses(&listeners_count.tls)?;
183        let mut tcp_addresses = parse_addresses(&listeners_count.tcp)?;
184
185        let mut index = 0;
186        let len = http_len;
187        let mut http = Vec::new();
188        http.extend(
189            http_addresses
190                .drain(..)
191                .zip(received_fds[index..index + len].iter().cloned()),
192        );
193
194        index += len;
195        let len = tls_len;
196        let mut tls = Vec::new();
197        tls.extend(
198            tls_addresses
199                .drain(..)
200                .zip(received_fds[index..index + len].iter().cloned()),
201        );
202
203        index += len;
204        let len = tcp_len;
205        let mut tcp = Vec::new();
206        tcp.extend(
207            tcp_addresses
208                .drain(..)
209                .zip(received_fds[index..index + len].iter().cloned()),
210        );
211
212        Ok(Listeners { http, tls, tcp })
213    }
214
215    /// Sends message and file descriptors separately. The file descriptors are summed up
216    /// in a ControlMessage.
217    fn send_msg_and_fds(&self, message: &[u8], fds: &[RawFd]) -> Result<(), ScmSocketError> {
218        let iov = [IoSlice::new(message)];
219        let flags = if self.blocking {
220            socket::MsgFlags::empty()
221        } else {
222            socket::MsgFlags::MSG_DONTWAIT
223        };
224
225        if fds.is_empty() {
226            debug!("{} send empty", self.fd);
227            socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)
228                .map_err(|error| ScmSocketError::Send(error.to_string()))?;
229            return Ok(());
230        };
231
232        let control_message = [socket::ControlMessage::ScmRights(fds)];
233        debug!("{} send with data", self.fd);
234        socket::sendmsg::<()>(self.fd, &iov, &control_message, flags, None)
235            .map_err(|error| ScmSocketError::Send(error.to_string()))?;
236        Ok(())
237    }
238
239    /// Parse the message and receives file descriptors separately via the ControlMessage
240    fn receive_msg_and_fds(
241        &self,
242        message: &mut [u8],
243        fds: &mut [RawFd],
244    ) -> Result<(usize, usize), ScmSocketError> {
245        let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
246        let mut iov = [IoSliceMut::new(message)];
247
248        let flags = if self.blocking {
249            socket::MsgFlags::empty()
250        } else {
251            socket::MsgFlags::MSG_DONTWAIT
252        };
253
254        let msg = socket::recvmsg::<()>(self.fd, &mut iov[..], Some(&mut cmsg), flags)
255            .map_err(|error| ScmSocketError::Receive(error.to_string()))?;
256
257        let mut fd_count = 0;
258        let received_fds = msg
259            .cmsgs()
260            .map_err(|error| ScmSocketError::Receive(error.to_string()))?
261            .filter_map(|cmsg| {
262                if let socket::ControlMessageOwned::ScmRights(s) = cmsg {
263                    Some(s)
264                } else {
265                    None
266                }
267            })
268            .flatten();
269        for (fd, place) in received_fds.zip(fds.iter_mut()) {
270            fd_count += 1;
271            *place = fd;
272        }
273        Ok((msg.bytes, fd_count))
274    }
275}
276
277/// Socket addresses and file descriptors of TCP sockets, needed by a Proxy to start listening
278#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
279pub struct Listeners {
280    pub http: Vec<(SocketAddr, RawFd)>,
281    pub tls: Vec<(SocketAddr, RawFd)>,
282    pub tcp: Vec<(SocketAddr, RawFd)>,
283}
284
285impl Listeners {
286    pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
287        self.http
288            .iter()
289            .position(|(front, _)| front == addr)
290            .map(|pos| self.http.remove(pos).1)
291    }
292
293    pub fn get_https(&mut self, addr: &SocketAddr) -> Option<RawFd> {
294        self.tls
295            .iter()
296            .position(|(front, _)| front == addr)
297            .map(|pos| self.tls.remove(pos).1)
298    }
299
300    pub fn get_tcp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
301        self.tcp
302            .iter()
303            .position(|(front, _)| front == addr)
304            .map(|pos| self.tcp.remove(pos).1)
305    }
306
307    /// Deactivate all listeners by closing their file descriptors
308    pub fn close(&self) {
309        for (_, fd) in &self.http {
310            // SAFETY: `*fd` is owned by this `ScmListeners` table and is
311            // about to be closed by the binding's `Drop` (intentional
312            // close-by-drop). No other reference to the descriptor survives.
313            unsafe {
314                let _ = TcpListener::from_raw_fd(*fd);
315            }
316        }
317
318        for (_, fd) in &self.tls {
319            // SAFETY: `*fd` is owned by this `ScmListeners` table and is
320            // about to be closed by the binding's `Drop` (intentional
321            // close-by-drop). No other reference to the descriptor survives.
322            unsafe {
323                let _ = TcpListener::from_raw_fd(*fd);
324            }
325        }
326
327        for (_, fd) in &self.tcp {
328            // SAFETY: `*fd` is owned by this `ScmListeners` table and is
329            // about to be closed by the binding's `Drop` (intentional
330            // close-by-drop). No other reference to the descriptor survives.
331            unsafe {
332                let _ = TcpListener::from_raw_fd(*fd);
333            }
334        }
335    }
336}
337
338fn parse_addresses(addresses: &[String]) -> Result<Vec<SocketAddr>, ScmSocketError> {
339    let mut parsed_addresses = Vec::new();
340    for address in addresses {
341        parsed_addresses.push(address.parse::<SocketAddr>().map_err(|error| {
342            ScmSocketError::WrongSocketAddress {
343                address: address.to_owned(),
344                error,
345            }
346        })?);
347    }
348    Ok(parsed_addresses)
349}
350
351#[cfg(test)]
352mod tests {
353
354    use std::{net::SocketAddr, os::unix::prelude::AsRawFd, str::FromStr};
355
356    use mio::net::UnixStream as MioUnixStream;
357
358    use super::*;
359
360    #[test]
361    fn create_block_unblock_an_scm_socket() {
362        let (nonblocking_stream, _) =
363            MioUnixStream::pair().expect("Could not create a pair of unix streams");
364        let raw_file_descriptor = nonblocking_stream.into_raw_fd();
365
366        let scm_socket = ScmSocket::new(raw_file_descriptor);
367        assert!(scm_socket.is_ok());
368
369        let mut scm_socket = scm_socket.unwrap();
370
371        assert!(scm_socket.set_blocking(true).is_ok());
372        assert!(scm_socket.set_blocking(false).is_ok());
373    }
374
375    fn socket_addr_from_str(str: &str) -> SocketAddr {
376        SocketAddr::from_str(str)
377            .unwrap_or_else(|_| panic!("failed to create socket address from string slice {str}"))
378    }
379
380    #[test]
381    fn send_and_receive_empty_listeners() {
382        let (stream_1, stream_2) =
383            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
384
385        let sending_scm_socket =
386            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
387
388        let receiving_scm_socket =
389            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");
390
391        let listeners = Listeners::default();
392
393        sending_scm_socket
394            .send_listeners(&listeners)
395            .expect("Could not send listeners");
396
397        let received_listeners = receiving_scm_socket
398            .receive_listeners()
399            .expect("Could not receive listeners");
400
401        assert_eq!(listeners, received_listeners);
402    }
403
404    #[test]
405    fn send_and_receive_socket_addresses() {
406        let (stream_1, stream_2) =
407            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
408
409        println!("unix stream pair: {stream_1:?} and {stream_2:?}");
410        let sending_scm_socket =
411            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
412
413        println!("sending socket: {sending_scm_socket:?}");
414
415        let receiving_scm_socket =
416            ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");
417
418        println!("receiving socket: {receiving_scm_socket:?}");
419
420        // We have to provide actual file descriptors, even if they will all be changed in the takeover
421        let (http_socket1, http_socket2) =
422            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
423        let (tcp_socket1, tcp_socket2) =
424            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
425        let (tls_socket1, tls_socket2) =
426            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
427
428        let listeners = Listeners {
429            http: vec![
430                (
431                    socket_addr_from_str("127.0.1.1:8080"),
432                    http_socket1.as_raw_fd(),
433                ),
434                (
435                    socket_addr_from_str("127.0.1.2:8080"),
436                    http_socket2.as_raw_fd(),
437                ),
438            ],
439            tcp: vec![
440                (
441                    socket_addr_from_str("127.0.2.1:8080"),
442                    tcp_socket1.as_raw_fd(),
443                ),
444                (
445                    socket_addr_from_str("127.0.2.2:8080"),
446                    tcp_socket2.as_raw_fd(),
447                ),
448            ],
449            tls: vec![
450                (
451                    socket_addr_from_str("127.0.3.1:8443"),
452                    tls_socket1.as_raw_fd(),
453                ),
454                (
455                    socket_addr_from_str("127.0.3.2:8443"),
456                    tls_socket2.as_raw_fd(),
457                ),
458            ],
459        };
460
461        println!("self.fd: {}", sending_scm_socket.fd);
462        println!("listeners to send: {listeners:#?}");
463
464        sending_scm_socket
465            .send_listeners(&listeners)
466            .expect("Could not send listeners");
467
468        let received_listeners = receiving_scm_socket
469            .receive_listeners()
470            .expect("Could not receive listeners");
471
472        assert_eq!(listeners.http[0].0, received_listeners.http[0].0);
473    }
474
475    /// Regression: a malformed `ListenersCount` whose entry counts do not
476    /// match the number of file descriptors received over SCM must be
477    /// rejected with `ListenersCountInconsistent`, never panic the worker
478    /// on `received_fds[index..index + len]`.
479    ///
480    /// Without the bounds check, a peer that declares more addresses than
481    /// `MAX_FDS_OUT` (or more than the FDs that actually arrived) crashes
482    /// the receiving worker on out-of-bounds array indexing.
483    #[test]
484    fn rejects_listeners_count_with_more_entries_than_fds() {
485        let (stream_1, stream_2) =
486            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
487        let sender = ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
488        let receiver = ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");
489
490        // Declare three HTTP entries but ship zero file descriptors.
491        let bogus = ListenersCount {
492            http: vec![
493                "127.0.0.1:80".to_string(),
494                "127.0.0.2:80".to_string(),
495                "127.0.0.3:80".to_string(),
496            ],
497            tls: vec![],
498            tcp: vec![],
499        };
500        let payload = bogus.encode_length_delimited_to_vec();
501        sender
502            .send_msg_and_fds(&payload, &[])
503            .expect("manual send_msg_and_fds with zero fds must succeed at the syscall layer");
504
505        match receiver.receive_listeners() {
506            Err(ScmSocketError::ListenersCountInconsistent {
507                http,
508                tls,
509                tcp,
510                total,
511                fds_received,
512                max_fds,
513            }) => {
514                assert_eq!(http, 3);
515                assert_eq!(tls, 0);
516                assert_eq!(tcp, 0);
517                assert_eq!(total, 3);
518                assert_eq!(fds_received, 0);
519                assert_eq!(max_fds, MAX_FDS_OUT);
520            }
521            other => panic!(
522                "expected ListenersCountInconsistent, got {other:?}\n\
523                 NOTE: a panic / OOM here means the SCM bounds check was reverted",
524            ),
525        }
526    }
527}