sozu_command_lib/scm_socket.rs

1use std::{
2    io::{IoSlice, IoSliceMut},
3    net::{AddrParseError, SocketAddr},
4    os::unix::{
5        io::{FromRawFd, IntoRawFd, RawFd},
6        net::UnixStream as StdUnixStream,
7    },
8};
9
10use mio::net::TcpListener;
11use nix::{cmsg_space, sys::socket};
12use prost::{DecodeError, Message};
13
14use crate::proto::command::ListenersCount;
15
/// Capacity, in file descriptors, of the receive buffer and of the control-message space
/// used by [`ScmSocket::receive_listeners`]: at most this many descriptors are received
/// in a single message.
pub const MAX_FDS_OUT: usize = 200;
/// Capacity, in bytes, of the buffer into which [`ScmSocket::receive_listeners`] receives
/// the encoded listener addresses.
pub const MAX_BYTES_OUT: usize = 4096;
18
19#[derive(thiserror::Error, Debug)]
20pub enum ScmSocketError {
21    #[error("could not set the blocking status of the unix stream to {blocking}: {error}")]
22    SetBlocking {
23        blocking: bool,
24        error: std::io::Error,
25    },
26    #[error("could not send message for SCM socket: {0}")]
27    Send(String),
28    #[error("could not receive message for SCM socket: {0}")]
29    Receive(String),
30    #[error("invalid char set: {0}")]
31    InvalidCharSet(String),
32    #[error("Could not deserialize utf8 string into listeners: {0}")]
33    ListenerParse(String),
34    #[error("Wrong socket address {address}: {error}")]
35    WrongSocketAddress {
36        address: String,
37        error: AddrParseError,
38    },
39    #[error("error decoding the protobuf format of the listeners: {0}")]
40    DecodeError(DecodeError),
41}
42
43/// A unix socket specialized for file descriptor passing
44#[derive(Clone, Debug, Serialize, Deserialize)]
45pub struct ScmSocket {
46    pub fd: RawFd,
47    pub blocking: bool,
48}
49
50impl ScmSocket {
51    /// Create a blocking SCM socket from a raw file descriptor (unsafe)
52    pub fn new(fd: RawFd) -> Result<Self, ScmSocketError> {
53        unsafe {
54            let stream = StdUnixStream::from_raw_fd(fd);
55            stream
56                .set_nonblocking(false)
57                .map_err(|error| ScmSocketError::SetBlocking {
58                    blocking: false,
59                    error,
60                })?;
61            let _dropped_fd = stream.into_raw_fd();
62        }
63
64        Ok(ScmSocket { fd, blocking: true })
65    }
66
67    /// Get the raw file descriptor of the scm channel
68    pub fn raw_fd(&self) -> i32 {
69        self.fd
70    }
71
72    /// Use the standard library (unsafe) to set the socket to blocking / unblocking
73    pub fn set_blocking(&mut self, blocking: bool) -> Result<(), ScmSocketError> {
74        if self.blocking == blocking {
75            return Ok(());
76        }
77        unsafe {
78            let stream = StdUnixStream::from_raw_fd(self.fd);
79            stream
80                .set_nonblocking(!blocking)
81                .map_err(|error| ScmSocketError::SetBlocking { blocking, error })?;
82            let _dropped_fd = stream.into_raw_fd();
83        }
84        self.blocking = blocking;
85        Ok(())
86    }
87
88    /// Send listeners (socket addresses and file descriptors) via an scm socket
89    pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
90        let listeners_count = ListenersCount {
91            http: listeners.http.iter().map(|t| t.0.to_string()).collect(),
92            tls: listeners.tls.iter().map(|t| t.0.to_string()).collect(),
93            tcp: listeners.tcp.iter().map(|t| t.0.to_string()).collect(),
94        };
95
96        let message = listeners_count.encode_length_delimited_to_vec();
97
98        let mut file_descriptors: Vec<RawFd> = Vec::new();
99
100        file_descriptors.extend(listeners.http.iter().map(|t| t.1));
101        file_descriptors.extend(listeners.tls.iter().map(|t| t.1));
102        file_descriptors.extend(listeners.tcp.iter().map(|t| t.1));
103
104        self.send_msg_and_fds(&message, &file_descriptors)
105    }
106
107    /// Receive and parse listeners (socket addresses and file descriptors) via an scm socket
108    pub fn receive_listeners(&self) -> Result<Listeners, ScmSocketError> {
109        let mut buf = vec![0; MAX_BYTES_OUT];
110
111        let mut received_fds: [RawFd; MAX_FDS_OUT] = [0; MAX_FDS_OUT];
112
113        let (size, file_descriptor_length) =
114            self.receive_msg_and_fds(&mut buf, &mut received_fds)?;
115
116        debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));
117
118        let listeners_count = ListenersCount::decode_length_delimited(&buf[..size])
119            .map_err(ScmSocketError::DecodeError)?;
120
121        let mut http_addresses = parse_addresses(&listeners_count.http)?;
122        let mut tls_addresses = parse_addresses(&listeners_count.tls)?;
123        let mut tcp_addresses = parse_addresses(&listeners_count.tcp)?;
124
125        let mut index = 0;
126        let len = listeners_count.http.len();
127        let mut http = Vec::new();
128        http.extend(
129            http_addresses
130                .drain(..)
131                .zip(received_fds[index..index + len].iter().cloned()),
132        );
133
134        index += len;
135        let len = listeners_count.tls.len();
136        let mut tls = Vec::new();
137        tls.extend(
138            tls_addresses
139                .drain(..)
140                .zip(received_fds[index..index + len].iter().cloned()),
141        );
142
143        index += len;
144        let mut tcp = Vec::new();
145        tcp.extend(
146            tcp_addresses
147                .drain(..)
148                .zip(received_fds[index..file_descriptor_length].iter().cloned()),
149        );
150
151        Ok(Listeners { http, tls, tcp })
152    }
153
154    /// Sends message and file descriptors separately. The file descriptors are summed up
155    /// in a ControlMessage.
156    fn send_msg_and_fds(&self, message: &[u8], fds: &[RawFd]) -> Result<(), ScmSocketError> {
157        let iov = [IoSlice::new(message)];
158        let flags = if self.blocking {
159            socket::MsgFlags::empty()
160        } else {
161            socket::MsgFlags::MSG_DONTWAIT
162        };
163
164        if fds.is_empty() {
165            debug!("{} send empty", self.fd);
166            socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)
167                .map_err(|error| ScmSocketError::Send(error.to_string()))?;
168            return Ok(());
169        };
170
171        let control_message = [socket::ControlMessage::ScmRights(fds)];
172        debug!("{} send with data", self.fd);
173        socket::sendmsg::<()>(self.fd, &iov, &control_message, flags, None)
174            .map_err(|error| ScmSocketError::Send(error.to_string()))?;
175        Ok(())
176    }
177
178    /// Parse the message and receives file descriptors separately via the ControlMessage
179    fn receive_msg_and_fds(
180        &self,
181        message: &mut [u8],
182        fds: &mut [RawFd],
183    ) -> Result<(usize, usize), ScmSocketError> {
184        let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
185        let mut iov = [IoSliceMut::new(message)];
186
187        let flags = if self.blocking {
188            socket::MsgFlags::empty()
189        } else {
190            socket::MsgFlags::MSG_DONTWAIT
191        };
192
193        let msg = socket::recvmsg::<()>(self.fd, &mut iov[..], Some(&mut cmsg), flags)
194            .map_err(|error| ScmSocketError::Receive(error.to_string()))?;
195
196        let mut fd_count = 0;
197        let received_fds = msg
198            .cmsgs()
199            .map_err(|error| ScmSocketError::Receive(error.to_string()))?
200            .filter_map(|cmsg| {
201                if let socket::ControlMessageOwned::ScmRights(s) = cmsg {
202                    Some(s)
203                } else {
204                    None
205                }
206            })
207            .flatten();
208        for (fd, place) in received_fds.zip(fds.iter_mut()) {
209            fd_count += 1;
210            *place = fd;
211        }
212        Ok((msg.bytes, fd_count))
213    }
214}
215
216/// Socket addresses and file descriptors of TCP sockets, needed by a Proxy to start listening
217#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
218pub struct Listeners {
219    pub http: Vec<(SocketAddr, RawFd)>,
220    pub tls: Vec<(SocketAddr, RawFd)>,
221    pub tcp: Vec<(SocketAddr, RawFd)>,
222}
223
224impl Listeners {
225    pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
226        self.http
227            .iter()
228            .position(|(front, _)| front == addr)
229            .map(|pos| self.http.remove(pos).1)
230    }
231
232    pub fn get_https(&mut self, addr: &SocketAddr) -> Option<RawFd> {
233        self.tls
234            .iter()
235            .position(|(front, _)| front == addr)
236            .map(|pos| self.tls.remove(pos).1)
237    }
238
239    pub fn get_tcp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
240        self.tcp
241            .iter()
242            .position(|(front, _)| front == addr)
243            .map(|pos| self.tcp.remove(pos).1)
244    }
245
246    /// Deactivate all listeners by closing their file descriptors
247    pub fn close(&self) {
248        for (_, ref fd) in &self.http {
249            unsafe {
250                let _ = TcpListener::from_raw_fd(*fd);
251            }
252        }
253
254        for (_, ref fd) in &self.tls {
255            unsafe {
256                let _ = TcpListener::from_raw_fd(*fd);
257            }
258        }
259
260        for (_, ref fd) in &self.tcp {
261            unsafe {
262                let _ = TcpListener::from_raw_fd(*fd);
263            }
264        }
265    }
266}
267
268fn parse_addresses(addresses: &[String]) -> Result<Vec<SocketAddr>, ScmSocketError> {
269    let mut parsed_addresses = Vec::new();
270    for address in addresses {
271        parsed_addresses.push(address.parse::<SocketAddr>().map_err(|error| {
272            ScmSocketError::WrongSocketAddress {
273                address: address.to_owned(),
274                error,
275            }
276        })?);
277    }
278    Ok(parsed_addresses)
279}
280
#[cfg(test)]
mod tests {

    use super::*;
    use mio::net::UnixStream as MioUnixStream;
    use std::{collections::HashSet, net::SocketAddr, os::unix::prelude::AsRawFd, str::FromStr};

    fn socket_addr_from_str(str: &str) -> SocketAddr {
        SocketAddr::from_str(str)
            .unwrap_or_else(|_| panic!("failed to create socket address from string slice {str}"))
    }

    /// Addresses of a listener list, in order.
    fn addresses(listeners: &[(SocketAddr, RawFd)]) -> Vec<SocketAddr> {
        listeners.iter().map(|(address, _)| *address).collect()
    }

    /// Every file descriptor held by `listeners`: http, then tls, then tcp.
    fn all_fds(listeners: &Listeners) -> impl Iterator<Item = RawFd> + '_ {
        listeners
            .http
            .iter()
            .chain(&listeners.tls)
            .chain(&listeners.tcp)
            .map(|(_, fd)| *fd)
    }

    #[test]
    fn create_block_unblock_an_scm_socket() {
        // The mio streams keep ownership of their descriptors and close them on drop.
        let (nonblocking_stream, _peer) =
            MioUnixStream::pair().expect("Could not create a pair of unix streams");

        let mut scm_socket =
            ScmSocket::new(nonblocking_stream.as_raw_fd()).expect("Could not create scm socket");
        assert!(scm_socket.blocking);

        assert!(scm_socket.set_blocking(true).is_ok());
        assert!(scm_socket.blocking);
        assert!(scm_socket.set_blocking(false).is_ok());
        assert!(!scm_socket.blocking);
    }

    #[test]
    fn send_and_receive_empty_listeners() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.as_raw_fd()).expect("Could not create scm socket");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");

        let listeners = Listeners::default();

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        assert_eq!(listeners, received_listeners);
    }

    #[test]
    fn send_and_receive_socket_addresses() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.as_raw_fd()).expect("Could not create scm socket");
        let receiving_scm_socket =
            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");

        // We have to provide actual file descriptors, even if they will all be changed in the takeover
        let (http_socket1, http_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tcp_socket1, tcp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tls_socket1, tls_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let listeners = Listeners {
            http: vec![
                (socket_addr_from_str("127.0.1.1:8080"), http_socket1.as_raw_fd()),
                (socket_addr_from_str("127.0.1.2:8080"), http_socket2.as_raw_fd()),
            ],
            tcp: vec![
                (socket_addr_from_str("127.0.2.1:8080"), tcp_socket1.as_raw_fd()),
                (socket_addr_from_str("127.0.2.2:8080"), tcp_socket2.as_raw_fd()),
            ],
            tls: vec![
                (socket_addr_from_str("127.0.3.1:8443"), tls_socket1.as_raw_fd()),
                (socket_addr_from_str("127.0.3.2:8443"), tls_socket2.as_raw_fd()),
            ],
        };

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        // Every address arrives, in order, in the right category.
        assert_eq!(addresses(&listeners.http), addresses(&received_listeners.http));
        assert_eq!(addresses(&listeners.tls), addresses(&received_listeners.tls));
        assert_eq!(addresses(&listeners.tcp), addresses(&received_listeners.tcp));

        // Each address got its own descriptor, newly allocated on reception (the sent
        // ones are still open, so the received ones cannot reuse their numbers).
        let sent_fds: HashSet<RawFd> = all_fds(&listeners).collect();
        let received_fds: HashSet<RawFd> = all_fds(&received_listeners).collect();
        assert_eq!(received_fds.len(), 6);
        assert!(received_fds.is_disjoint(&sent_fds));

        received_listeners.close();
    }

    #[test]
    fn get_listeners_by_address() {
        let http_address = socket_addr_from_str("127.0.0.1:80");
        let tls_address = socket_addr_from_str("127.0.0.1:443");
        let tcp_address = socket_addr_from_str("127.0.0.1:5432");

        // Fake descriptor numbers: nothing here is ever closed.
        let mut listeners = Listeners {
            http: vec![(http_address, 10)],
            tls: vec![(tls_address, 11)],
            tcp: vec![(tcp_address, 12)],
        };

        assert_eq!(listeners.get_http(&tls_address), None);
        assert_eq!(listeners.get_http(&http_address), Some(10));
        assert_eq!(listeners.get_http(&http_address), None);
        assert_eq!(listeners.get_https(&tls_address), Some(11));
        assert_eq!(listeners.get_tcp(&tcp_address), Some(12));
        assert_eq!(listeners, Listeners::default());
    }
}