1use std::{
11 io::{IoSlice, IoSliceMut},
12 net::{AddrParseError, SocketAddr},
13 os::unix::{
14 io::{FromRawFd, IntoRawFd, RawFd},
15 net::UnixStream as StdUnixStream,
16 },
17};
18
19use mio::net::{TcpListener, UdpSocket};
20use nix::{cmsg_space, sys::socket};
21use prost::{DecodeError, Message};
22
23use crate::proto::command::ListenersCount;
24
/// Maximum number of file descriptors exchanged in one listener transfer.
///
/// Sizes both the receive-side descriptor array and the `SCM_RIGHTS` control
/// buffer passed to `recvmsg`.
pub const MAX_FDS_OUT: usize = 200;
/// Size in bytes of the buffer the listener manifest is received into.
pub const MAX_BYTES_OUT: usize = 4096;
27
/// Errors raised while configuring an [`ScmSocket`] or exchanging listeners over it.
#[derive(thiserror::Error, Debug)]
pub enum ScmSocketError {
    /// Switching the socket between blocking and non-blocking mode failed.
    #[error("could not set the blocking status of the unix stream to {blocking}: {error}")]
    SetBlocking {
        /// The mode that was requested (`true` = blocking).
        blocking: bool,
        /// The underlying OS error.
        error: std::io::Error,
    },
    /// `sendmsg` failed; carries the stringified OS error.
    #[error("could not send message for SCM socket: {0}")]
    Send(String),
    /// `recvmsg` (or reading its control messages) failed; carries the stringified error.
    #[error("could not receive message for SCM socket: {0}")]
    Receive(String),
    // NOTE(review): not constructed anywhere in this section; confirm it is still used.
    #[error("invalid char set: {0}")]
    InvalidCharSet(String),
    // NOTE(review): not constructed anywhere in this section; confirm it is still used.
    #[error("Could not deserialize utf8 string into listeners: {0}")]
    ListenerParse(String),
    /// A listener address in the received manifest is not a valid `SocketAddr`.
    #[error("Wrong socket address {address}: {error}")]
    WrongSocketAddress {
        /// The offending address string, as received.
        address: String,
        /// Why it failed to parse.
        error: AddrParseError,
    },
    /// The received manifest is not a valid length-delimited protobuf message.
    #[error("error decoding the protobuf format of the listeners: {0}")]
    DecodeError(DecodeError),
    /// The manifest names more listeners than descriptors were received, or more
    /// than `max_fds`.
    ///
    /// There is no dedicated udp field: `tcp` holds the tcp + udp entry count.
    /// `total` is `usize::MAX` when the per-protocol counts overflow on addition.
    #[error(
        "listeners count manifest is inconsistent with the SCM payload: \
         http={http}, tls={tls}, tcp={tcp} (sum={total}), fds_received={fds_received}, max_fds={max_fds}"
    )]
    ListenersCountInconsistent {
        http: usize,
        tls: usize,
        tcp: usize,
        total: usize,
        fds_received: usize,
        max_fds: usize,
    },
}
63
/// A unix socket used to pass listener file descriptors as `SCM_RIGHTS`
/// ancillary data, together with a protobuf manifest of their addresses.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ScmSocket {
    /// Raw descriptor of the underlying unix socket.
    pub fd: RawFd,
    /// Cached blocking mode; when `false`, sends and receives pass `MSG_DONTWAIT`.
    pub blocking: bool,
}
70
71impl ScmSocket {
72 pub fn new(fd: RawFd) -> Result<Self, ScmSocketError> {
74 unsafe {
79 let stream = StdUnixStream::from_raw_fd(fd);
80 stream
81 .set_nonblocking(false)
82 .map_err(|error| ScmSocketError::SetBlocking {
83 blocking: false,
84 error,
85 })?;
86 let _dropped_fd = stream.into_raw_fd();
87 }
88
89 Ok(ScmSocket { fd, blocking: true })
90 }
91
92 pub fn raw_fd(&self) -> i32 {
94 self.fd
95 }
96
97 pub fn set_blocking(&mut self, blocking: bool) -> Result<(), ScmSocketError> {
99 if self.blocking == blocking {
100 return Ok(());
101 }
102 debug_assert_ne!(
104 self.blocking, blocking,
105 "set_blocking only reaches the syscall when the state actually changes"
106 );
107 let blocking_before = self.blocking;
108 unsafe {
113 let stream = StdUnixStream::from_raw_fd(self.fd);
114 stream
115 .set_nonblocking(!blocking)
116 .map_err(|error| ScmSocketError::SetBlocking { blocking, error })?;
117 let _dropped_fd = stream.into_raw_fd();
118 }
119 self.blocking = blocking;
120 debug_assert_eq!(
122 self.blocking, blocking,
123 "blocking flag must reflect the requested state after a successful set"
124 );
125 debug_assert_ne!(
126 self.blocking, blocking_before,
127 "blocking flag must have toggled across a real state change"
128 );
129 Ok(())
130 }
131
132 pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
134 let listeners_count = ListenersCount {
135 http: listeners.http.iter().map(|t| t.0.to_string()).collect(),
136 tls: listeners.tls.iter().map(|t| t.0.to_string()).collect(),
137 tcp: listeners.tcp.iter().map(|t| t.0.to_string()).collect(),
138 udp: listeners.udp.iter().map(|t| t.0.to_string()).collect(),
139 };
140
141 debug_assert_eq!(
147 listeners_count.http.len(),
148 listeners.http.len(),
149 "http manifest count must match the http listener table"
150 );
151 debug_assert_eq!(
152 listeners_count.tls.len(),
153 listeners.tls.len(),
154 "tls manifest count must match the tls listener table"
155 );
156 debug_assert_eq!(
157 listeners_count.tcp.len(),
158 listeners.tcp.len(),
159 "tcp manifest count must match the tcp listener table"
160 );
161
162 let message = listeners_count.encode_length_delimited_to_vec();
163
164 let mut file_descriptors: Vec<RawFd> = Vec::new();
165
166 file_descriptors.extend(listeners.http.iter().map(|t| t.1));
167 file_descriptors.extend(listeners.tls.iter().map(|t| t.1));
168 file_descriptors.extend(listeners.tcp.iter().map(|t| t.1));
169 file_descriptors.extend(listeners.udp.iter().map(|t| t.1));
170
171 let address_total =
175 listeners.http.len() + listeners.tls.len() + listeners.tcp.len() + listeners.udp.len();
176 debug_assert_eq!(
177 file_descriptors.len(),
178 address_total,
179 "the FD count sent must equal the total listener-address count (one FD per address)"
180 );
181
182 self.send_msg_and_fds(&message, &file_descriptors)
183 }
184
185 pub fn receive_listeners(&self) -> Result<Listeners, ScmSocketError> {
187 let mut buf = vec![0; MAX_BYTES_OUT];
188
189 let mut received_fds: [RawFd; MAX_FDS_OUT] = [0; MAX_FDS_OUT];
190
191 let (size, file_descriptor_length) =
192 self.receive_msg_and_fds(&mut buf, &mut received_fds)?;
193
194 debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));
195
196 let listeners_count = ListenersCount::decode_length_delimited(&buf[..size])
197 .map_err(ScmSocketError::DecodeError)?;
198
199 let http_len = listeners_count.http.len();
208 let tls_len = listeners_count.tls.len();
209 let tcp_len = listeners_count.tcp.len();
210 let udp_len = listeners_count.udp.len();
211 let total = http_len
212 .checked_add(tls_len)
213 .and_then(|s| s.checked_add(tcp_len))
214 .and_then(|s| s.checked_add(udp_len))
215 .ok_or(ScmSocketError::ListenersCountInconsistent {
216 http: http_len,
217 tls: tls_len,
218 tcp: tcp_len.saturating_add(udp_len),
219 total: usize::MAX,
220 fds_received: file_descriptor_length,
221 max_fds: MAX_FDS_OUT,
222 })?;
223 if total > MAX_FDS_OUT || total > file_descriptor_length {
224 return Err(ScmSocketError::ListenersCountInconsistent {
225 http: http_len,
226 tls: tls_len,
227 tcp: tcp_len.saturating_add(udp_len),
228 total,
229 fds_received: file_descriptor_length,
230 max_fds: MAX_FDS_OUT,
231 });
232 }
233
234 debug_assert_eq!(
240 total,
241 http_len + tls_len + tcp_len + udp_len,
242 "folded total must equal the sum of per-protocol counts"
243 );
244 debug_assert!(
245 total <= MAX_FDS_OUT,
246 "total FD slots must fit the fixed-size received_fds array before indexing"
247 );
248 debug_assert!(
249 total <= file_descriptor_length,
250 "manifest total must not exceed the FDs actually received"
251 );
252 debug_assert!(
253 total <= received_fds.len(),
254 "every (address, fd) zip below must stay within the received_fds array"
255 );
256
257 let mut http_addresses = parse_addresses(&listeners_count.http)?;
258 let mut tls_addresses = parse_addresses(&listeners_count.tls)?;
259 let mut tcp_addresses = parse_addresses(&listeners_count.tcp)?;
260 let mut udp_addresses = parse_addresses(&listeners_count.udp)?;
261
262 debug_assert_eq!(
265 http_addresses.len(),
266 http_len,
267 "parsed http address count must match the manifest count"
268 );
269 debug_assert_eq!(
270 tls_addresses.len(),
271 tls_len,
272 "parsed tls address count must match the manifest count"
273 );
274 debug_assert_eq!(
275 tcp_addresses.len(),
276 tcp_len,
277 "parsed tcp address count must match the manifest count"
278 );
279
280 let mut index = 0;
281 let len = http_len;
282 debug_assert!(
285 index + len <= total,
286 "http FD slice must lie within the reconciled total"
287 );
288 let mut http = Vec::new();
289 http.extend(
290 http_addresses
291 .drain(..)
292 .zip(received_fds[index..index + len].iter().cloned()),
293 );
294 debug_assert_eq!(
296 http.len(),
297 http_len,
298 "every http address must be paired with exactly one FD"
299 );
300
301 index += len;
302 let len = tls_len;
303 debug_assert!(
304 index + len <= total,
305 "tls FD slice must lie within the reconciled total"
306 );
307 let mut tls = Vec::new();
308 tls.extend(
309 tls_addresses
310 .drain(..)
311 .zip(received_fds[index..index + len].iter().cloned()),
312 );
313 debug_assert_eq!(
314 tls.len(),
315 tls_len,
316 "every tls address must be paired with exactly one FD"
317 );
318
319 index += len;
320 let len = tcp_len;
321 debug_assert!(
322 index + len <= total,
323 "tcp FD slice must lie within the reconciled total"
324 );
325 let mut tcp = Vec::new();
326 tcp.extend(
327 tcp_addresses
328 .drain(..)
329 .zip(received_fds[index..index + len].iter().cloned()),
330 );
331 debug_assert_eq!(
332 tcp.len(),
333 tcp_len,
334 "every tcp address must be paired with exactly one FD"
335 );
336
337 index += len;
338 let len = udp_len;
339 let mut udp = Vec::new();
340 udp.extend(
341 udp_addresses
342 .drain(..)
343 .zip(received_fds[index..index + len].iter().cloned()),
344 );
345 debug_assert_eq!(
346 udp.len(),
347 udp_len,
348 "every udp address must be paired with exactly one FD"
349 );
350
351 debug_assert_eq!(
354 index + len,
355 total,
356 "the (address, fd) zips must consume exactly the reconciled total of FD slots"
357 );
358 debug_assert_eq!(
359 http.len() + tls.len() + tcp.len() + udp.len(),
360 total,
361 "reconstructed listener count must equal the reconciled FD total"
362 );
363
364 Ok(Listeners {
365 http,
366 tls,
367 tcp,
368 udp,
369 })
370 }
371
372 fn send_msg_and_fds(&self, message: &[u8], fds: &[RawFd]) -> Result<(), ScmSocketError> {
375 let iov = [IoSlice::new(message)];
376 let flags = if self.blocking {
377 socket::MsgFlags::empty()
378 } else {
379 socket::MsgFlags::MSG_DONTWAIT
380 };
381
382 if fds.is_empty() {
383 debug!("{} send empty", self.fd);
384 socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)
385 .map_err(|error| ScmSocketError::Send(error.to_string()))?;
386 return Ok(());
387 };
388
389 let control_message = [socket::ControlMessage::ScmRights(fds)];
390 debug!("{} send with data", self.fd);
391 socket::sendmsg::<()>(self.fd, &iov, &control_message, flags, None)
392 .map_err(|error| ScmSocketError::Send(error.to_string()))?;
393 Ok(())
394 }
395
396 fn receive_msg_and_fds(
398 &self,
399 message: &mut [u8],
400 fds: &mut [RawFd],
401 ) -> Result<(usize, usize), ScmSocketError> {
402 let message_capacity = message.len();
405 let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
406 let mut iov = [IoSliceMut::new(message)];
407
408 let flags = if self.blocking {
409 socket::MsgFlags::empty()
410 } else {
411 socket::MsgFlags::MSG_DONTWAIT
412 };
413
414 let msg = socket::recvmsg::<()>(self.fd, &mut iov[..], Some(&mut cmsg), flags)
415 .map_err(|error| ScmSocketError::Receive(error.to_string()))?;
416
417 let fds_capacity = fds.len();
420 debug_assert!(
421 fds_capacity <= MAX_FDS_OUT,
422 "destination FD slice must not exceed the MAX_FDS_OUT cmsg space"
423 );
424 let mut fd_count = 0;
425 let received_fds = msg
426 .cmsgs()
427 .map_err(|error| ScmSocketError::Receive(error.to_string()))?
428 .filter_map(|cmsg| {
429 if let socket::ControlMessageOwned::ScmRights(s) = cmsg {
430 Some(s)
431 } else {
432 None
433 }
434 })
435 .flatten();
436 for (fd, place) in received_fds.zip(fds.iter_mut()) {
437 fd_count += 1;
438 *place = fd;
439 debug_assert!(
442 fd_count <= fds_capacity,
443 "received FD count must never exceed the destination array capacity"
444 );
445 }
446 debug_assert!(
449 fd_count <= fds_capacity,
450 "final received FD count must fit the destination array"
451 );
452 debug_assert!(
453 msg.bytes <= message_capacity,
454 "received byte count must not exceed the message buffer it was read into"
455 );
456 Ok((msg.bytes, fd_count))
457 }
458}
459
/// Listener sockets grouped by protocol; each entry pairs a socket address with
/// its raw file descriptor.
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
pub struct Listeners {
    pub http: Vec<(SocketAddr, RawFd)>,
    pub tls: Vec<(SocketAddr, RawFd)>,
    pub tcp: Vec<(SocketAddr, RawFd)>,
    /// Deserializes as empty when the field is absent from the input.
    #[serde(default)]
    pub udp: Vec<(SocketAddr, RawFd)>,
}
471
472impl Listeners {
473 pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
474 let before = self.http.len();
475 let pos = self.http.iter().position(|(front, _)| front == addr);
476 let result = pos.map(|pos| self.http.remove(pos).1);
477 debug_assert_eq!(
480 self.http.len(),
481 before - result.is_some() as usize,
482 "http listener table shrinks by exactly one iff an address matched"
483 );
484 debug_assert!(
485 result.is_none() || !self.http.iter().any(|(front, _)| front == addr),
486 "the matched http address must no longer be present after removal"
487 );
488 result
489 }
490
491 pub fn get_https(&mut self, addr: &SocketAddr) -> Option<RawFd> {
492 let before = self.tls.len();
493 let pos = self.tls.iter().position(|(front, _)| front == addr);
494 let result = pos.map(|pos| self.tls.remove(pos).1);
495 debug_assert_eq!(
496 self.tls.len(),
497 before - result.is_some() as usize,
498 "tls listener table shrinks by exactly one iff an address matched"
499 );
500 debug_assert!(
501 result.is_none() || !self.tls.iter().any(|(front, _)| front == addr),
502 "the matched tls address must no longer be present after removal"
503 );
504 result
505 }
506
507 pub fn get_tcp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
508 let before = self.tcp.len();
509 let pos = self.tcp.iter().position(|(front, _)| front == addr);
510 let result = pos.map(|pos| self.tcp.remove(pos).1);
511 debug_assert_eq!(
512 self.tcp.len(),
513 before - result.is_some() as usize,
514 "tcp listener table shrinks by exactly one iff an address matched"
515 );
516 debug_assert!(
517 result.is_none() || !self.tcp.iter().any(|(front, _)| front == addr),
518 "the matched tcp address must no longer be present after removal"
519 );
520 result
521 }
522
523 pub fn get_udp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
524 self.udp
525 .iter()
526 .position(|(front, _)| front == addr)
527 .map(|pos| self.udp.remove(pos).1)
528 }
529
530 pub fn close(&self) {
532 for (_, fd) in &self.http {
533 unsafe {
537 let _ = TcpListener::from_raw_fd(*fd);
538 }
539 }
540
541 for (_, fd) in &self.tls {
542 unsafe {
546 let _ = TcpListener::from_raw_fd(*fd);
547 }
548 }
549
550 for (_, fd) in &self.tcp {
551 unsafe {
555 let _ = TcpListener::from_raw_fd(*fd);
556 }
557 }
558
559 for (_, fd) in &self.udp {
560 unsafe {
566 let _ = UdpSocket::from_raw_fd(*fd);
567 }
568 }
569 }
570}
571
572fn parse_addresses(addresses: &[String]) -> Result<Vec<SocketAddr>, ScmSocketError> {
573 let mut parsed_addresses = Vec::new();
574 for address in addresses {
575 parsed_addresses.push(address.parse::<SocketAddr>().map_err(|error| {
576 ScmSocketError::WrongSocketAddress {
577 address: address.to_owned(),
578 error,
579 }
580 })?);
581 }
582 Ok(parsed_addresses)
583}
584
#[cfg(test)]
mod tests {

    use std::{net::SocketAddr, os::unix::prelude::AsRawFd, str::FromStr};

    use mio::net::UnixStream as MioUnixStream;

    use super::*;

    #[test]
    fn create_block_unblock_an_scm_socket() {
        let (nonblocking_stream, _) =
            MioUnixStream::pair().expect("Could not create a pair of unix streams");
        let raw_file_descriptor = nonblocking_stream.into_raw_fd();

        let scm_socket = ScmSocket::new(raw_file_descriptor);
        assert!(scm_socket.is_ok());

        let mut scm_socket = scm_socket.unwrap();

        assert!(scm_socket.set_blocking(true).is_ok());
        assert!(scm_socket.set_blocking(false).is_ok());
    }

    fn socket_addr_from_str(str: &str) -> SocketAddr {
        SocketAddr::from_str(str)
            .unwrap_or_else(|_| panic!("failed to create socket address from string slice {str}"))
    }

    /// Collects the addresses of one listener table, in order, dropping the descriptors.
    fn addresses_of(listeners: &[(SocketAddr, RawFd)]) -> Vec<SocketAddr> {
        listeners.iter().map(|(address, _)| *address).collect()
    }

    #[test]
    fn send_and_receive_empty_listeners() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");

        let listeners = Listeners::default();

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        assert_eq!(listeners, received_listeners);
    }

    #[test]
    fn send_and_receive_socket_addresses() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
        let receiving_scm_socket =
            ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        let (http_socket1, http_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tcp_socket1, tcp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tls_socket1, tls_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (udp_socket1, udp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let listeners = Listeners {
            http: vec![
                (
                    socket_addr_from_str("127.0.1.1:8080"),
                    http_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.1.2:8080"),
                    http_socket2.as_raw_fd(),
                ),
            ],
            tcp: vec![
                (
                    socket_addr_from_str("127.0.2.1:8080"),
                    tcp_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.2.2:8080"),
                    tcp_socket2.as_raw_fd(),
                ),
            ],
            tls: vec![
                (
                    socket_addr_from_str("127.0.3.1:8443"),
                    tls_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.3.2:8443"),
                    tls_socket2.as_raw_fd(),
                ),
            ],
            udp: vec![
                (
                    socket_addr_from_str("127.0.4.1:5353"),
                    udp_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.4.2:5353"),
                    udp_socket2.as_raw_fd(),
                ),
            ],
        };

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        // The receiving side gets newly installed descriptor numbers, so only the
        // addresses, and their per-protocol order, are expected to round-trip.
        assert_eq!(
            addresses_of(&listeners.http),
            addresses_of(&received_listeners.http)
        );
        assert_eq!(
            addresses_of(&listeners.tls),
            addresses_of(&received_listeners.tls)
        );
        assert_eq!(
            addresses_of(&listeners.tcp),
            addresses_of(&received_listeners.tcp)
        );
        assert_eq!(
            addresses_of(&listeners.udp),
            addresses_of(&received_listeners.udp)
        );

        // The received descriptors belong to this test; close them so they don't leak.
        received_listeners.close();
    }

    #[test]
    fn rejects_listeners_count_with_more_entries_than_fds() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let sender = ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
        let receiver = ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        let bogus = ListenersCount {
            http: vec![
                "127.0.0.1:80".to_string(),
                "127.0.0.2:80".to_string(),
                "127.0.0.3:80".to_string(),
            ],
            tls: vec![],
            tcp: vec![],
            udp: vec![],
        };
        let payload = bogus.encode_length_delimited_to_vec();
        sender
            .send_msg_and_fds(&payload, &[])
            .expect("manual send_msg_and_fds with zero fds must succeed at the syscall layer");

        match receiver.receive_listeners() {
            Err(ScmSocketError::ListenersCountInconsistent {
                http,
                tls,
                tcp,
                total,
                fds_received,
                max_fds,
            }) => {
                assert_eq!(http, 3);
                assert_eq!(tls, 0);
                assert_eq!(tcp, 0);
                assert_eq!(total, 3);
                assert_eq!(fds_received, 0);
                assert_eq!(max_fds, MAX_FDS_OUT);
            }
            other => panic!(
                "expected ListenersCountInconsistent, got {other:?}\n\
                 NOTE: a panic / OOM here means the SCM bounds check was reverted",
            ),
        }
    }
}