1use std::{
11 io::{IoSlice, IoSliceMut},
12 net::{AddrParseError, SocketAddr},
13 os::unix::{
14 io::{FromRawFd, IntoRawFd, RawFd},
15 net::UnixStream as StdUnixStream,
16 },
17};
18
19use mio::net::TcpListener;
20use nix::{cmsg_space, sys::socket};
21use prost::{DecodeError, Message};
22
23use crate::proto::command::ListenersCount;
24
/// Maximum number of file descriptors accepted in one received SCM message.
///
/// Sizes both the descriptor array and the control-message buffer used by the receiving side,
/// and bounds the listener count a received manifest may announce.
pub const MAX_FDS_OUT: usize = 200;
/// Size in bytes of the buffer that receives the serialized listeners manifest.
pub const MAX_BYTES_OUT: usize = 4096;
27
/// Errors raised while configuring an SCM socket or exchanging listeners over it.
#[derive(thiserror::Error, Debug)]
pub enum ScmSocketError {
    /// Switching the unix stream between blocking and non-blocking mode failed.
    #[error("could not set the blocking status of the unix stream to {blocking}: {error}")]
    SetBlocking {
        /// The blocking mode that was requested.
        blocking: bool,
        /// The underlying I/O error from `set_nonblocking`.
        error: std::io::Error,
    },
    /// `sendmsg` failed; holds the stringified system error.
    #[error("could not send message for SCM socket: {0}")]
    Send(String),
    /// `recvmsg`, or reading its control messages, failed; holds the stringified error.
    #[error("could not receive message for SCM socket: {0}")]
    Receive(String),
    /// NOTE(review): not constructed in this part of the file.
    #[error("invalid char set: {0}")]
    InvalidCharSet(String),
    /// NOTE(review): not constructed in this part of the file.
    #[error("Could not deserialize utf8 string into listeners: {0}")]
    ListenerParse(String),
    /// A listener address string from the received manifest is not a valid `SocketAddr`.
    #[error("Wrong socket address {address}: {error}")]
    WrongSocketAddress {
        /// The offending address string, as received.
        address: String,
        /// Why parsing failed.
        error: AddrParseError,
    },
    /// The received bytes are not a valid length-delimited `ListenersCount` protobuf.
    #[error("error decoding the protobuf format of the listeners: {0}")]
    DecodeError(DecodeError),
    /// The manifest announces more listeners than descriptors were received, more than
    /// `max_fds`, or a count that overflows `usize` (reported as `total == usize::MAX`).
    #[error(
        "listeners count manifest is inconsistent with the SCM payload: \
         http={http}, tls={tls}, tcp={tcp} (sum={total}), fds_received={fds_received}, max_fds={max_fds}"
    )]
    ListenersCountInconsistent {
        http: usize,
        tls: usize,
        tcp: usize,
        total: usize,
        fds_received: usize,
        max_fds: usize,
    },
}
63
/// A unix stream socket, identified by its raw file descriptor, used to pass listener
/// sockets and their addresses to another process via `SCM_RIGHTS` control messages.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ScmSocket {
    /// Raw file descriptor of the unix stream socket.
    pub fd: RawFd,
    /// Blocking mode last applied by `new` / `set_blocking`; when `false`, sends and
    /// receives are issued with `MSG_DONTWAIT`.
    pub blocking: bool,
}
70
71impl ScmSocket {
72 pub fn new(fd: RawFd) -> Result<Self, ScmSocketError> {
74 unsafe {
79 let stream = StdUnixStream::from_raw_fd(fd);
80 stream
81 .set_nonblocking(false)
82 .map_err(|error| ScmSocketError::SetBlocking {
83 blocking: false,
84 error,
85 })?;
86 let _dropped_fd = stream.into_raw_fd();
87 }
88
89 Ok(ScmSocket { fd, blocking: true })
90 }
91
92 pub fn raw_fd(&self) -> i32 {
94 self.fd
95 }
96
97 pub fn set_blocking(&mut self, blocking: bool) -> Result<(), ScmSocketError> {
99 if self.blocking == blocking {
100 return Ok(());
101 }
102 unsafe {
107 let stream = StdUnixStream::from_raw_fd(self.fd);
108 stream
109 .set_nonblocking(!blocking)
110 .map_err(|error| ScmSocketError::SetBlocking { blocking, error })?;
111 let _dropped_fd = stream.into_raw_fd();
112 }
113 self.blocking = blocking;
114 Ok(())
115 }
116
117 pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
119 let listeners_count = ListenersCount {
120 http: listeners.http.iter().map(|t| t.0.to_string()).collect(),
121 tls: listeners.tls.iter().map(|t| t.0.to_string()).collect(),
122 tcp: listeners.tcp.iter().map(|t| t.0.to_string()).collect(),
123 };
124
125 let message = listeners_count.encode_length_delimited_to_vec();
126
127 let mut file_descriptors: Vec<RawFd> = Vec::new();
128
129 file_descriptors.extend(listeners.http.iter().map(|t| t.1));
130 file_descriptors.extend(listeners.tls.iter().map(|t| t.1));
131 file_descriptors.extend(listeners.tcp.iter().map(|t| t.1));
132
133 self.send_msg_and_fds(&message, &file_descriptors)
134 }
135
136 pub fn receive_listeners(&self) -> Result<Listeners, ScmSocketError> {
138 let mut buf = vec![0; MAX_BYTES_OUT];
139
140 let mut received_fds: [RawFd; MAX_FDS_OUT] = [0; MAX_FDS_OUT];
141
142 let (size, file_descriptor_length) =
143 self.receive_msg_and_fds(&mut buf, &mut received_fds)?;
144
145 debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));
146
147 let listeners_count = ListenersCount::decode_length_delimited(&buf[..size])
148 .map_err(ScmSocketError::DecodeError)?;
149
150 let http_len = listeners_count.http.len();
157 let tls_len = listeners_count.tls.len();
158 let tcp_len = listeners_count.tcp.len();
159 let total = http_len
160 .checked_add(tls_len)
161 .and_then(|s| s.checked_add(tcp_len))
162 .ok_or(ScmSocketError::ListenersCountInconsistent {
163 http: http_len,
164 tls: tls_len,
165 tcp: tcp_len,
166 total: usize::MAX,
167 fds_received: file_descriptor_length,
168 max_fds: MAX_FDS_OUT,
169 })?;
170 if total > MAX_FDS_OUT || total > file_descriptor_length {
171 return Err(ScmSocketError::ListenersCountInconsistent {
172 http: http_len,
173 tls: tls_len,
174 tcp: tcp_len,
175 total,
176 fds_received: file_descriptor_length,
177 max_fds: MAX_FDS_OUT,
178 });
179 }
180
181 let mut http_addresses = parse_addresses(&listeners_count.http)?;
182 let mut tls_addresses = parse_addresses(&listeners_count.tls)?;
183 let mut tcp_addresses = parse_addresses(&listeners_count.tcp)?;
184
185 let mut index = 0;
186 let len = http_len;
187 let mut http = Vec::new();
188 http.extend(
189 http_addresses
190 .drain(..)
191 .zip(received_fds[index..index + len].iter().cloned()),
192 );
193
194 index += len;
195 let len = tls_len;
196 let mut tls = Vec::new();
197 tls.extend(
198 tls_addresses
199 .drain(..)
200 .zip(received_fds[index..index + len].iter().cloned()),
201 );
202
203 index += len;
204 let len = tcp_len;
205 let mut tcp = Vec::new();
206 tcp.extend(
207 tcp_addresses
208 .drain(..)
209 .zip(received_fds[index..index + len].iter().cloned()),
210 );
211
212 Ok(Listeners { http, tls, tcp })
213 }
214
215 fn send_msg_and_fds(&self, message: &[u8], fds: &[RawFd]) -> Result<(), ScmSocketError> {
218 let iov = [IoSlice::new(message)];
219 let flags = if self.blocking {
220 socket::MsgFlags::empty()
221 } else {
222 socket::MsgFlags::MSG_DONTWAIT
223 };
224
225 if fds.is_empty() {
226 debug!("{} send empty", self.fd);
227 socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)
228 .map_err(|error| ScmSocketError::Send(error.to_string()))?;
229 return Ok(());
230 };
231
232 let control_message = [socket::ControlMessage::ScmRights(fds)];
233 debug!("{} send with data", self.fd);
234 socket::sendmsg::<()>(self.fd, &iov, &control_message, flags, None)
235 .map_err(|error| ScmSocketError::Send(error.to_string()))?;
236 Ok(())
237 }
238
239 fn receive_msg_and_fds(
241 &self,
242 message: &mut [u8],
243 fds: &mut [RawFd],
244 ) -> Result<(usize, usize), ScmSocketError> {
245 let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
246 let mut iov = [IoSliceMut::new(message)];
247
248 let flags = if self.blocking {
249 socket::MsgFlags::empty()
250 } else {
251 socket::MsgFlags::MSG_DONTWAIT
252 };
253
254 let msg = socket::recvmsg::<()>(self.fd, &mut iov[..], Some(&mut cmsg), flags)
255 .map_err(|error| ScmSocketError::Receive(error.to_string()))?;
256
257 let mut fd_count = 0;
258 let received_fds = msg
259 .cmsgs()
260 .map_err(|error| ScmSocketError::Receive(error.to_string()))?
261 .filter_map(|cmsg| {
262 if let socket::ControlMessageOwned::ScmRights(s) = cmsg {
263 Some(s)
264 } else {
265 None
266 }
267 })
268 .flatten();
269 for (fd, place) in received_fds.zip(fds.iter_mut()) {
270 fd_count += 1;
271 *place = fd;
272 }
273 Ok((msg.bytes, fd_count))
274 }
275}
276
/// Listener sockets grouped by protocol, each paired with its address.
///
/// `ScmSocket::send_listeners` transmits the descriptors in http, tls, tcp order, and
/// `ScmSocket::receive_listeners` splits them back up in that same order.
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
pub struct Listeners {
    /// HTTP listeners as `(address, raw file descriptor)` pairs.
    pub http: Vec<(SocketAddr, RawFd)>,
    /// TLS listeners (looked up via `get_https`) as `(address, raw file descriptor)` pairs.
    pub tls: Vec<(SocketAddr, RawFd)>,
    /// TCP listeners as `(address, raw file descriptor)` pairs.
    pub tcp: Vec<(SocketAddr, RawFd)>,
}
284
285impl Listeners {
286 pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
287 self.http
288 .iter()
289 .position(|(front, _)| front == addr)
290 .map(|pos| self.http.remove(pos).1)
291 }
292
293 pub fn get_https(&mut self, addr: &SocketAddr) -> Option<RawFd> {
294 self.tls
295 .iter()
296 .position(|(front, _)| front == addr)
297 .map(|pos| self.tls.remove(pos).1)
298 }
299
300 pub fn get_tcp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
301 self.tcp
302 .iter()
303 .position(|(front, _)| front == addr)
304 .map(|pos| self.tcp.remove(pos).1)
305 }
306
307 pub fn close(&self) {
309 for (_, fd) in &self.http {
310 unsafe {
314 let _ = TcpListener::from_raw_fd(*fd);
315 }
316 }
317
318 for (_, fd) in &self.tls {
319 unsafe {
323 let _ = TcpListener::from_raw_fd(*fd);
324 }
325 }
326
327 for (_, fd) in &self.tcp {
328 unsafe {
332 let _ = TcpListener::from_raw_fd(*fd);
333 }
334 }
335 }
336}
337
338fn parse_addresses(addresses: &[String]) -> Result<Vec<SocketAddr>, ScmSocketError> {
339 let mut parsed_addresses = Vec::new();
340 for address in addresses {
341 parsed_addresses.push(address.parse::<SocketAddr>().map_err(|error| {
342 ScmSocketError::WrongSocketAddress {
343 address: address.to_owned(),
344 error,
345 }
346 })?);
347 }
348 Ok(parsed_addresses)
349}
350
#[cfg(test)]
mod tests {

    use std::{net::SocketAddr, os::unix::prelude::AsRawFd, str::FromStr};

    use mio::net::UnixStream as MioUnixStream;

    use super::*;

    #[test]
    fn create_block_unblock_an_scm_socket() {
        let (nonblocking_stream, _) =
            MioUnixStream::pair().expect("Could not create a pair of unix streams");
        let raw_file_descriptor = nonblocking_stream.into_raw_fd();

        let scm_socket = ScmSocket::new(raw_file_descriptor);
        assert!(scm_socket.is_ok());

        let mut scm_socket = scm_socket.unwrap();

        assert!(scm_socket.set_blocking(true).is_ok());
        assert!(scm_socket.set_blocking(false).is_ok());
    }

    fn socket_addr_from_str(str: &str) -> SocketAddr {
        SocketAddr::from_str(str)
            .unwrap_or_else(|_| panic!("failed to create socket address from string slice {str}"))
    }

    #[test]
    fn send_and_receive_empty_listeners() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");

        let listeners = Listeners::default();

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        assert_eq!(listeners, received_listeners);
    }

    #[test]
    fn send_and_receive_socket_addresses() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        println!("unix stream pair: {stream_1:?} and {stream_2:?}");
        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");

        println!("sending socket: {sending_scm_socket:?}");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        println!("receiving socket: {receiving_scm_socket:?}");

        let (http_socket1, http_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tcp_socket1, tcp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tls_socket1, tls_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let listeners = Listeners {
            http: vec![
                (
                    socket_addr_from_str("127.0.1.1:8080"),
                    http_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.1.2:8080"),
                    http_socket2.as_raw_fd(),
                ),
            ],
            tcp: vec![
                (
                    socket_addr_from_str("127.0.2.1:8080"),
                    tcp_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.2.2:8080"),
                    tcp_socket2.as_raw_fd(),
                ),
            ],
            tls: vec![
                (
                    socket_addr_from_str("127.0.3.1:8443"),
                    tls_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.3.2:8443"),
                    tls_socket2.as_raw_fd(),
                ),
            ],
        };

        println!("self.fd: {}", sending_scm_socket.fd);
        println!("listeners to send: {listeners:#?}");

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        // Every address must come back in the same section, in the same order, and each
        // section must have exactly as many entries as were sent.
        for (sent, received) in [
            (&listeners.http, &received_listeners.http),
            (&listeners.tls, &received_listeners.tls),
            (&listeners.tcp, &received_listeners.tcp),
        ] {
            let sent_addresses: Vec<&SocketAddr> =
                sent.iter().map(|(address, _)| address).collect();
            let received_addresses: Vec<&SocketAddr> =
                received.iter().map(|(address, _)| address).collect();
            assert_eq!(sent_addresses, received_addresses);
        }

        // The received descriptors belong to this test; close them so they do not leak.
        received_listeners.close();
    }

    /// A manifest announcing more listeners than descriptors actually attached must be
    /// rejected with `ListenersCountInconsistent` rather than read past the received fds.
    #[test]
    fn rejects_listeners_count_with_more_entries_than_fds() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let sender = ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
        let receiver = ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        let bogus = ListenersCount {
            http: vec![
                "127.0.0.1:80".to_string(),
                "127.0.0.2:80".to_string(),
                "127.0.0.3:80".to_string(),
            ],
            tls: vec![],
            tcp: vec![],
        };
        let payload = bogus.encode_length_delimited_to_vec();
        sender
            .send_msg_and_fds(&payload, &[])
            .expect("manual send_msg_and_fds with zero fds must succeed at the syscall layer");

        match receiver.receive_listeners() {
            Err(ScmSocketError::ListenersCountInconsistent {
                http,
                tls,
                tcp,
                total,
                fds_received,
                max_fds,
            }) => {
                assert_eq!(http, 3);
                assert_eq!(tls, 0);
                assert_eq!(tcp, 0);
                assert_eq!(total, 3);
                assert_eq!(fds_received, 0);
                assert_eq!(max_fds, MAX_FDS_OUT);
            }
            other => panic!(
                "expected ListenersCountInconsistent, got {other:?}\n\
                 NOTE: a panic / OOM here means the SCM bounds check was reverted",
            ),
        }
    }
}