use std::{
io::{IoSlice, IoSliceMut},
net::SocketAddr,
os::unix::{
io::{FromRawFd, IntoRawFd, RawFd},
net::UnixStream as StdUnixStream,
},
str::from_utf8,
};
use mio::net::TcpListener;
use nix::{cmsg_space, sys::socket};
use serde_json;
/// Number of file descriptor slots reserved on the receiving side: it sizes both the
/// array that collects received descriptors and the control-message buffer.
pub const MAX_FDS_OUT: usize = 200;
/// Size, in bytes, of the buffer that receives the serialized listener message.
pub const MAX_BYTES_OUT: usize = 4096;
#[derive(thiserror::Error, Debug)]
pub enum ScmSocketError {
#[error("could not set the blocking status of the unix stream to {blocking}: {error}")]
SetBlocking {
blocking: bool,
error: std::io::Error,
},
#[error("could not send message per SCM socket: {0}")]
Send(String),
#[error("could not send message per SCM socket: {0}")]
Receive(String),
#[error("invalid char set: {0}")]
InvalidCharSet(String),
#[error("Could not deserialize utf8 string into listeners: {0}")]
ListenerParse(String),
}
/// Handle on a unix socket used to exchange listener addresses and their file
/// descriptors (sent as `ScmRights` control messages).
///
/// The struct only stores the raw descriptor; nothing in this type closes it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ScmSocket {
    /// Raw file descriptor of the unix socket.
    pub fd: RawFd,
    /// Blocking mode last applied through this handle; when `false`, sends and
    /// receives are issued with `MSG_DONTWAIT`.
    pub blocking: bool,
}
impl ScmSocket {
    /// Wraps an existing unix socket descriptor and switches it to blocking mode.
    ///
    /// The descriptor is never closed by this function, including when the mode
    /// change fails.
    ///
    /// # Errors
    /// Returns [`ScmSocketError::SetBlocking`] (with `blocking: true`, the mode that
    /// was requested) if the socket could not be put in blocking mode.
    pub fn new(fd: RawFd) -> Result<Self, ScmSocketError> {
        Self::apply_blocking(fd, true)?;
        Ok(ScmSocket { fd, blocking: true })
    }

    /// Returns the raw file descriptor of the underlying socket.
    pub fn raw_fd(&self) -> i32 {
        self.fd
    }

    /// Switches the socket to blocking (`true`) or non-blocking (`false`) mode.
    ///
    /// Does nothing when the requested mode is the one already recorded in
    /// `self.blocking`.
    ///
    /// # Errors
    /// Returns [`ScmSocketError::SetBlocking`] if the mode could not be changed; the
    /// descriptor stays open and `self.blocking` is left untouched in that case.
    pub fn set_blocking(&mut self, blocking: bool) -> Result<(), ScmSocketError> {
        if self.blocking == blocking {
            return Ok(());
        }
        Self::apply_blocking(self.fd, blocking)?;
        self.blocking = blocking;
        Ok(())
    }

    /// Applies the requested blocking mode to `fd` without taking ownership of it.
    fn apply_blocking(fd: RawFd, blocking: bool) -> Result<(), ScmSocketError> {
        // SAFETY: assumes `fd` is an open unix socket supplied by the caller. The
        // temporary stream hands the descriptor back through `into_raw_fd` below
        // before it can be dropped, so it is never closed here.
        let stream = unsafe { StdUnixStream::from_raw_fd(fd) };
        let outcome = stream.set_nonblocking(!blocking);
        // Release ownership BEFORE propagating any error: returning early while
        // `stream` is still alive would drop it and close the caller's descriptor.
        let _ = stream.into_raw_fd();
        outcome.map_err(|error| ScmSocketError::SetBlocking { blocking, error })
    }

    /// Sends the listeners to the peer: the socket addresses travel as a JSON
    /// payload, the file descriptors as an `ScmRights` control message, ordered
    /// http, then tls, then tcp.
    ///
    /// # Errors
    /// Returns [`ScmSocketError::Send`] if the addresses cannot be serialized or if
    /// the underlying `sendmsg` call fails.
    pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
        let listeners_count = ListenersCount {
            http: listeners.http.iter().map(|t| t.0).collect(),
            tls: listeners.tls.iter().map(|t| t.0).collect(),
            tcp: listeners.tcp.iter().map(|t| t.0).collect(),
        };
        // A serialization failure is reported instead of silently sending an empty
        // payload that the receiver could not parse.
        let message = serde_json::to_string(&listeners_count)
            .map(String::into_bytes)
            .map_err(|error| ScmSocketError::Send(error.to_string()))?;
        // Same order as the address lists so the receiver can pair them back up.
        let file_descriptors: Vec<RawFd> = listeners
            .http
            .iter()
            .chain(&listeners.tls)
            .chain(&listeners.tcp)
            .map(|t| t.1)
            .collect();
        self.send_msg_and_fds(&message, &file_descriptors)
    }

    /// Receives listeners sent by [`ScmSocket::send_listeners`].
    ///
    /// Addresses are paired with the received file descriptors in http, tls, tcp
    /// order. Only descriptors that were actually received are used: if the peer
    /// announces more addresses than descriptors arrived, the surplus addresses are
    /// dropped rather than being paired with a placeholder descriptor.
    ///
    /// # Errors
    /// - [`ScmSocketError::Receive`] if the underlying `recvmsg` call fails,
    /// - [`ScmSocketError::InvalidCharSet`] if the payload is not valid UTF-8,
    /// - [`ScmSocketError::ListenerParse`] if the payload is not the expected JSON.
    pub fn receive_listeners(&self) -> Result<Listeners, ScmSocketError> {
        let mut buf = vec![0; MAX_BYTES_OUT];
        let mut received_fds: [RawFd; MAX_FDS_OUT] = [0; MAX_FDS_OUT];
        let (size, file_descriptor_length) =
            self.receive_msg_and_fds(&mut buf, &mut received_fds)?;
        debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));
        let raw_listener_list = from_utf8(&buf[..size])
            .map_err(|utf8_error| ScmSocketError::InvalidCharSet(utf8_error.to_string()))?;
        let listeners_count = serde_json::from_str::<ListenersCount>(raw_listener_list)
            .map_err(|error| ScmSocketError::ListenerParse(error.to_string()))?;

        // Walk the received descriptors exactly once. The address counts come from
        // the peer, so they must never be used to index the fixed-size array: that
        // could panic or expose the zero-initialised slots as if they were real
        // descriptors.
        let mut fds = received_fds[..file_descriptor_length].iter().copied();
        // `zip` pulls from the address iterator first, so each call consumes at most
        // one descriptor per address and leaves the rest for the next protocol.
        let http: Vec<_> = listeners_count
            .http
            .into_iter()
            .zip(fds.by_ref())
            .collect();
        let tls: Vec<_> = listeners_count
            .tls
            .into_iter()
            .zip(fds.by_ref())
            .collect();
        let tcp: Vec<_> = listeners_count.tcp.into_iter().zip(fds).collect();
        Ok(Listeners { http, tls, tcp })
    }

    /// Sends `message` over the socket, attaching `fds` as an `ScmRights` control
    /// message when there are any. Uses `MSG_DONTWAIT` when the socket is recorded
    /// as non-blocking.
    fn send_msg_and_fds(&self, message: &[u8], fds: &[RawFd]) -> Result<(), ScmSocketError> {
        let iov = [IoSlice::new(message)];
        let flags = if self.blocking {
            socket::MsgFlags::empty()
        } else {
            socket::MsgFlags::MSG_DONTWAIT
        };
        if fds.is_empty() {
            // No descriptors: send the payload without any control message.
            debug!("{} send empty", self.fd);
            socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)
                .map_err(|error| ScmSocketError::Send(error.to_string()))?;
            return Ok(());
        };
        let control_message = [socket::ControlMessage::ScmRights(fds)];
        debug!("{} send with data", self.fd);
        socket::sendmsg::<()>(self.fd, &iov, &control_message, flags, None)
            .map_err(|error| ScmSocketError::Send(error.to_string()))?;
        Ok(())
    }

    /// Receives one message into `message` and copies the descriptors carried by
    /// `ScmRights` control messages into `fds`.
    ///
    /// Returns `(bytes_received, descriptors_stored)`. Descriptors beyond the
    /// capacity of `fds` are not stored (and are not closed here).
    fn receive_msg_and_fds(
        &self,
        message: &mut [u8],
        fds: &mut [RawFd],
    ) -> Result<(usize, usize), ScmSocketError> {
        let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
        let mut iov = [IoSliceMut::new(message)];
        let flags = if self.blocking {
            socket::MsgFlags::empty()
        } else {
            socket::MsgFlags::MSG_DONTWAIT
        };
        let msg = socket::recvmsg::<()>(self.fd, &mut iov[..], Some(&mut cmsg), flags)
            .map_err(|error| ScmSocketError::Receive(error.to_string()))?;
        let mut fd_count = 0;
        // Keep only the ScmRights control messages and flatten their descriptors.
        let received_fds = msg
            .cmsgs()
            .filter_map(|cmsg| {
                if let socket::ControlMessageOwned::ScmRights(s) = cmsg {
                    Some(s)
                } else {
                    None
                }
            })
            .flatten();
        for (fd, place) in received_fds.zip(fds.iter_mut()) {
            fd_count += 1;
            *place = fd;
        }
        Ok((msg.bytes, fd_count))
    }
}
/// Listener sockets grouped by protocol; each entry pairs the address a listener is
/// bound to with the raw file descriptor of its socket.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Listeners {
    /// HTTP listeners.
    pub http: Vec<(SocketAddr, RawFd)>,
    /// TLS listeners (looked up through `get_https`).
    pub tls: Vec<(SocketAddr, RawFd)>,
    /// TCP listeners.
    pub tcp: Vec<(SocketAddr, RawFd)>,
}
/// Address-only view of [`Listeners`], serialized as the JSON payload of the message;
/// the matching file descriptors are sent separately as a control message.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct ListenersCount {
    /// Addresses of the HTTP listeners.
    pub http: Vec<SocketAddr>,
    /// Addresses of the TLS listeners.
    pub tls: Vec<SocketAddr>,
    /// Addresses of the TCP listeners.
    pub tcp: Vec<SocketAddr>,
}
impl Listeners {
    /// Removes the first entry of `entries` bound to `addr` and yields its descriptor.
    fn take_fd(entries: &mut Vec<(SocketAddr, RawFd)>, addr: &SocketAddr) -> Option<RawFd> {
        let index = entries.iter().position(|(bound, _)| bound == addr)?;
        Some(entries.remove(index).1)
    }

    /// Takes the HTTP listener bound to `addr` out of the list, returning its
    /// descriptor, or `None` when no HTTP listener has that address.
    pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
        Self::take_fd(&mut self.http, addr)
    }

    /// Takes the TLS listener bound to `addr` out of the list, returning its
    /// descriptor, or `None` when no TLS listener has that address.
    pub fn get_https(&mut self, addr: &SocketAddr) -> Option<RawFd> {
        Self::take_fd(&mut self.tls, addr)
    }

    /// Takes the TCP listener bound to `addr` out of the list, returning its
    /// descriptor, or `None` when no TCP listener has that address.
    pub fn get_tcp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
        Self::take_fd(&mut self.tcp, addr)
    }

    /// Closes every descriptor still held, in http, tls, tcp order.
    ///
    /// Each descriptor is wrapped in a `TcpListener` that is dropped right away.
    /// The entries themselves stay in the lists, since this only borrows `self`.
    pub fn close(&self) {
        let remaining = self.http.iter().chain(&self.tls).chain(&self.tcp);
        for &(_, fd) in remaining {
            // SAFETY: assumes every stored descriptor is an open socket that nothing
            // else will use or close afterwards -- TODO confirm against callers.
            let listener = unsafe { TcpListener::from_raw_fd(fd) };
            drop(listener);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mio::net::UnixStream as MioUnixStream;
    use std::{net::SocketAddr, os::unix::prelude::AsRawFd, str::FromStr};

    /// Parses a socket address, panicking with a readable message on failure.
    fn socket_addr_from_str(str: &str) -> SocketAddr {
        SocketAddr::from_str(str)
            .unwrap_or_else(|_| panic!("failed to create socket address from string slice {str}"))
    }

    /// Projects a listener list onto its addresses; received descriptors are
    /// duplicates with different numbers, so only addresses can be compared.
    fn addresses(listeners: &[(SocketAddr, RawFd)]) -> Vec<SocketAddr> {
        listeners.iter().map(|(addr, _)| *addr).collect()
    }

    #[test]
    fn create_block_unblock_an_scm_socket() {
        // Keep the peer alive for the whole test.
        let (nonblocking_stream, _peer) =
            MioUnixStream::pair().expect("Could not create a pair of unix streams");
        let raw_file_descriptor = nonblocking_stream.into_raw_fd();

        let mut scm_socket =
            ScmSocket::new(raw_file_descriptor).expect("Could not create scm socket");
        assert!(scm_socket.blocking, "a new scm socket should be blocking");

        // Exercise both real transitions: `new` already left the socket blocking, so
        // asking for blocking first would be a no-op.
        assert!(scm_socket.set_blocking(false).is_ok());
        assert!(!scm_socket.blocking);
        assert!(scm_socket.set_blocking(true).is_ok());
        assert!(scm_socket.blocking);
    }

    #[test]
    fn send_and_receive_empty_listeners() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
        let receiving_scm_socket =
            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");

        let listeners = Listeners {
            http: vec![],
            tcp: vec![],
            tls: vec![],
        };

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");
        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        assert_eq!(listeners, received_listeners);
    }

    #[test]
    fn send_and_receive_socket_addresses() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");
        let receiving_scm_socket =
            ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        // Any open descriptors will do as payload; unix stream pairs are cheap.
        let (http_socket1, http_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tcp_socket1, tcp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tls_socket1, tls_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let listeners = Listeners {
            http: vec![
                (
                    socket_addr_from_str("127.0.1.1:8080"),
                    http_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.1.2:8080"),
                    http_socket2.as_raw_fd(),
                ),
            ],
            tcp: vec![
                (
                    socket_addr_from_str("127.0.2.1:8080"),
                    tcp_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.2.2:8080"),
                    tcp_socket2.as_raw_fd(),
                ),
            ],
            tls: vec![
                (
                    socket_addr_from_str("127.0.3.1:8443"),
                    tls_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.3.2:8443"),
                    tls_socket2.as_raw_fd(),
                ),
            ],
        };

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");
        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        // Every address of every protocol must round-trip, in order, each paired
        // with a descriptor.
        assert_eq!(
            addresses(&listeners.http),
            addresses(&received_listeners.http)
        );
        assert_eq!(
            addresses(&listeners.tls),
            addresses(&received_listeners.tls)
        );
        assert_eq!(
            addresses(&listeners.tcp),
            addresses(&received_listeners.tcp)
        );
    }
}