1use std::{
2 io::{IoSlice, IoSliceMut},
3 net::{AddrParseError, SocketAddr},
4 os::unix::{
5 io::{FromRawFd, IntoRawFd, RawFd},
6 net::UnixStream as StdUnixStream,
7 },
8};
9
10use mio::net::TcpListener;
11use nix::{cmsg_space, sys::socket};
12use prost::{DecodeError, Message};
13
14use crate::proto::command::ListenersCount;
15
/// Capacity, in file descriptors, of the array and ancillary-data buffer that
/// `ScmSocket::receive_listeners` uses for a single message.
pub const MAX_FDS_OUT: usize = 200;
/// Size, in bytes, of the buffer that `ScmSocket::receive_listeners` reads the
/// encoded listener message into.
pub const MAX_BYTES_OUT: usize = 4096;
18
19#[derive(thiserror::Error, Debug)]
20pub enum ScmSocketError {
21 #[error("could not set the blocking status of the unix stream to {blocking}: {error}")]
22 SetBlocking {
23 blocking: bool,
24 error: std::io::Error,
25 },
26 #[error("could not send message for SCM socket: {0}")]
27 Send(String),
28 #[error("could not receive message for SCM socket: {0}")]
29 Receive(String),
30 #[error("invalid char set: {0}")]
31 InvalidCharSet(String),
32 #[error("Could not deserialize utf8 string into listeners: {0}")]
33 ListenerParse(String),
34 #[error("Wrong socket address {address}: {error}")]
35 WrongSocketAddress {
36 address: String,
37 error: AddrParseError,
38 },
39 #[error("error decoding the protobuf format of the listeners: {0}")]
40 DecodeError(DecodeError),
41}
42
43#[derive(Clone, Debug, Serialize, Deserialize)]
45pub struct ScmSocket {
46 pub fd: RawFd,
47 pub blocking: bool,
48}
49
50impl ScmSocket {
51 pub fn new(fd: RawFd) -> Result<Self, ScmSocketError> {
53 unsafe {
54 let stream = StdUnixStream::from_raw_fd(fd);
55 stream
56 .set_nonblocking(false)
57 .map_err(|error| ScmSocketError::SetBlocking {
58 blocking: false,
59 error,
60 })?;
61 let _dropped_fd = stream.into_raw_fd();
62 }
63
64 Ok(ScmSocket { fd, blocking: true })
65 }
66
67 pub fn raw_fd(&self) -> i32 {
69 self.fd
70 }
71
72 pub fn set_blocking(&mut self, blocking: bool) -> Result<(), ScmSocketError> {
74 if self.blocking == blocking {
75 return Ok(());
76 }
77 unsafe {
78 let stream = StdUnixStream::from_raw_fd(self.fd);
79 stream
80 .set_nonblocking(!blocking)
81 .map_err(|error| ScmSocketError::SetBlocking { blocking, error })?;
82 let _dropped_fd = stream.into_raw_fd();
83 }
84 self.blocking = blocking;
85 Ok(())
86 }
87
88 pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
90 let listeners_count = ListenersCount {
91 http: listeners.http.iter().map(|t| t.0.to_string()).collect(),
92 tls: listeners.tls.iter().map(|t| t.0.to_string()).collect(),
93 tcp: listeners.tcp.iter().map(|t| t.0.to_string()).collect(),
94 };
95
96 let message = listeners_count.encode_length_delimited_to_vec();
97
98 let mut file_descriptors: Vec<RawFd> = Vec::new();
99
100 file_descriptors.extend(listeners.http.iter().map(|t| t.1));
101 file_descriptors.extend(listeners.tls.iter().map(|t| t.1));
102 file_descriptors.extend(listeners.tcp.iter().map(|t| t.1));
103
104 self.send_msg_and_fds(&message, &file_descriptors)
105 }
106
107 pub fn receive_listeners(&self) -> Result<Listeners, ScmSocketError> {
109 let mut buf = vec![0; MAX_BYTES_OUT];
110
111 let mut received_fds: [RawFd; MAX_FDS_OUT] = [0; MAX_FDS_OUT];
112
113 let (size, file_descriptor_length) =
114 self.receive_msg_and_fds(&mut buf, &mut received_fds)?;
115
116 debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));
117
118 let listeners_count = ListenersCount::decode_length_delimited(&buf[..size])
119 .map_err(ScmSocketError::DecodeError)?;
120
121 let mut http_addresses = parse_addresses(&listeners_count.http)?;
122 let mut tls_addresses = parse_addresses(&listeners_count.tls)?;
123 let mut tcp_addresses = parse_addresses(&listeners_count.tcp)?;
124
125 let mut index = 0;
126 let len = listeners_count.http.len();
127 let mut http = Vec::new();
128 http.extend(
129 http_addresses
130 .drain(..)
131 .zip(received_fds[index..index + len].iter().cloned()),
132 );
133
134 index += len;
135 let len = listeners_count.tls.len();
136 let mut tls = Vec::new();
137 tls.extend(
138 tls_addresses
139 .drain(..)
140 .zip(received_fds[index..index + len].iter().cloned()),
141 );
142
143 index += len;
144 let mut tcp = Vec::new();
145 tcp.extend(
146 tcp_addresses
147 .drain(..)
148 .zip(received_fds[index..file_descriptor_length].iter().cloned()),
149 );
150
151 Ok(Listeners { http, tls, tcp })
152 }
153
154 fn send_msg_and_fds(&self, message: &[u8], fds: &[RawFd]) -> Result<(), ScmSocketError> {
157 let iov = [IoSlice::new(message)];
158 let flags = if self.blocking {
159 socket::MsgFlags::empty()
160 } else {
161 socket::MsgFlags::MSG_DONTWAIT
162 };
163
164 if fds.is_empty() {
165 debug!("{} send empty", self.fd);
166 socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)
167 .map_err(|error| ScmSocketError::Send(error.to_string()))?;
168 return Ok(());
169 };
170
171 let control_message = [socket::ControlMessage::ScmRights(fds)];
172 debug!("{} send with data", self.fd);
173 socket::sendmsg::<()>(self.fd, &iov, &control_message, flags, None)
174 .map_err(|error| ScmSocketError::Send(error.to_string()))?;
175 Ok(())
176 }
177
178 fn receive_msg_and_fds(
180 &self,
181 message: &mut [u8],
182 fds: &mut [RawFd],
183 ) -> Result<(usize, usize), ScmSocketError> {
184 let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
185 let mut iov = [IoSliceMut::new(message)];
186
187 let flags = if self.blocking {
188 socket::MsgFlags::empty()
189 } else {
190 socket::MsgFlags::MSG_DONTWAIT
191 };
192
193 let msg = socket::recvmsg::<()>(self.fd, &mut iov[..], Some(&mut cmsg), flags)
194 .map_err(|error| ScmSocketError::Receive(error.to_string()))?;
195
196 let mut fd_count = 0;
197 let received_fds = msg
198 .cmsgs()
199 .map_err(|error| ScmSocketError::Receive(error.to_string()))?
200 .filter_map(|cmsg| {
201 if let socket::ControlMessageOwned::ScmRights(s) = cmsg {
202 Some(s)
203 } else {
204 None
205 }
206 })
207 .flatten();
208 for (fd, place) in received_fds.zip(fds.iter_mut()) {
209 fd_count += 1;
210 *place = fd;
211 }
212 Ok((msg.bytes, fd_count))
213 }
214}
215
216#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
218pub struct Listeners {
219 pub http: Vec<(SocketAddr, RawFd)>,
220 pub tls: Vec<(SocketAddr, RawFd)>,
221 pub tcp: Vec<(SocketAddr, RawFd)>,
222}
223
224impl Listeners {
225 pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
226 self.http
227 .iter()
228 .position(|(front, _)| front == addr)
229 .map(|pos| self.http.remove(pos).1)
230 }
231
232 pub fn get_https(&mut self, addr: &SocketAddr) -> Option<RawFd> {
233 self.tls
234 .iter()
235 .position(|(front, _)| front == addr)
236 .map(|pos| self.tls.remove(pos).1)
237 }
238
239 pub fn get_tcp(&mut self, addr: &SocketAddr) -> Option<RawFd> {
240 self.tcp
241 .iter()
242 .position(|(front, _)| front == addr)
243 .map(|pos| self.tcp.remove(pos).1)
244 }
245
246 pub fn close(&self) {
248 for (_, ref fd) in &self.http {
249 unsafe {
250 let _ = TcpListener::from_raw_fd(*fd);
251 }
252 }
253
254 for (_, ref fd) in &self.tls {
255 unsafe {
256 let _ = TcpListener::from_raw_fd(*fd);
257 }
258 }
259
260 for (_, ref fd) in &self.tcp {
261 unsafe {
262 let _ = TcpListener::from_raw_fd(*fd);
263 }
264 }
265 }
266}
267
268fn parse_addresses(addresses: &[String]) -> Result<Vec<SocketAddr>, ScmSocketError> {
269 let mut parsed_addresses = Vec::new();
270 for address in addresses {
271 parsed_addresses.push(address.parse::<SocketAddr>().map_err(|error| {
272 ScmSocketError::WrongSocketAddress {
273 address: address.to_owned(),
274 error,
275 }
276 })?);
277 }
278 Ok(parsed_addresses)
279}
280
#[cfg(test)]
mod tests {

    use super::*;
    use mio::net::UnixStream as MioUnixStream;
    use std::{net::SocketAddr, os::unix::prelude::AsRawFd, str::FromStr};

    #[test]
    fn create_block_unblock_an_scm_socket() {
        let (nonblocking_stream, _) =
            MioUnixStream::pair().expect("Could not create a pair of unix streams");
        let raw_file_descriptor = nonblocking_stream.into_raw_fd();

        let scm_socket = ScmSocket::new(raw_file_descriptor);
        assert!(scm_socket.is_ok());

        let mut scm_socket = scm_socket.unwrap();

        assert!(scm_socket.set_blocking(true).is_ok());
        assert!(scm_socket.set_blocking(false).is_ok());
    }

    fn socket_addr_from_str(str: &str) -> SocketAddr {
        SocketAddr::from_str(str)
            .unwrap_or_else(|_| panic!("failed to create socket address from string slice {str}"))
    }

    /// Addresses of a listener list, in order. The fds are left out because the kernel
    /// renumbers descriptors passed over SCM_RIGHTS.
    fn addresses(listeners: &[(SocketAddr, RawFd)]) -> Vec<SocketAddr> {
        listeners.iter().map(|(address, _)| *address).collect()
    }

    #[test]
    fn send_and_receive_empty_listeners() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.as_raw_fd()).expect("Could not create scm socket");

        let listeners = Listeners::default();

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        assert_eq!(listeners, received_listeners);
    }

    #[test]
    fn send_and_receive_socket_addresses() {
        let (stream_1, stream_2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let sending_scm_socket =
            ScmSocket::new(stream_1.into_raw_fd()).expect("Could not create scm socket");

        let receiving_scm_socket =
            ScmSocket::new(stream_2.into_raw_fd()).expect("Could not create scm socket");

        // Any open descriptor can be passed over SCM_RIGHTS; unix stream pairs stand in
        // for real TCP listeners here.
        let (http_socket1, http_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tcp_socket1, tcp_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");
        let (tls_socket1, tls_socket2) =
            MioUnixStream::pair().expect("Could not create a pair of mio unix streams");

        let listeners = Listeners {
            http: vec![
                (
                    socket_addr_from_str("127.0.1.1:8080"),
                    http_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.1.2:8080"),
                    http_socket2.as_raw_fd(),
                ),
            ],
            tcp: vec![
                (
                    socket_addr_from_str("127.0.2.1:8080"),
                    tcp_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.2.2:8080"),
                    tcp_socket2.as_raw_fd(),
                ),
            ],
            tls: vec![
                (
                    socket_addr_from_str("127.0.3.1:8443"),
                    tls_socket1.as_raw_fd(),
                ),
                (
                    socket_addr_from_str("127.0.3.2:8443"),
                    tls_socket2.as_raw_fd(),
                ),
            ],
        };

        sending_scm_socket
            .send_listeners(&listeners)
            .expect("Could not send listeners");

        let received_listeners = receiving_scm_socket
            .receive_listeners()
            .expect("Could not receive listeners");

        // Every group must round-trip completely and in order, not just its first entry.
        assert_eq!(addresses(&listeners.http), addresses(&received_listeners.http));
        assert_eq!(addresses(&listeners.tls), addresses(&received_listeners.tls));
        assert_eq!(addresses(&listeners.tcp), addresses(&received_listeners.tcp));

        // The received descriptors are fresh duplicates owned by this test; close them.
        received_listeners.close();
    }
}