running_process/broker/server/handoff/
unix.rs1#[cfg(unix)]
9use std::path::Path;
10use std::path::PathBuf;
11
12use super::{
13 HandoffAttemptDecision, HandoffAttemptFailure, HandoffFallbackDecision, HandoffFallbackReason,
14 HandoffToken,
15};
16
/// Whether this build can hand descriptors off over `SCM_RIGHTS`.
///
/// `true` on every Unix target, `false` everywhere else.
#[cfg(unix)]
pub const SCM_RIGHTS_TRANSPORT_SUPPORTED: bool = true;

/// Whether this build can hand descriptors off over `SCM_RIGHTS`.
///
/// `true` on every Unix target, `false` everywhere else.
#[cfg(not(unix))]
pub const SCM_RIGHTS_TRANSPORT_SUPPORTED: bool = false;
19
/// Raw Unix descriptor number that a handoff attempt wants to pass.
///
/// This is a plain `Copy` wrapper around the integer. It neither owns nor
/// closes the descriptor, and it performs no validation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UnixFileDescriptor(i32);

impl UnixFileDescriptor {
    /// Wraps `raw_fd` as-is (negative values are accepted unchanged).
    pub fn new(raw_fd: i32) -> Self {
        UnixFileDescriptor(raw_fd)
    }

    /// Returns the wrapped descriptor number.
    pub fn raw(self) -> i32 {
        let UnixFileDescriptor(descriptor) = self;
        descriptor
    }
}
35
/// Filesystem location of a backend's Unix-domain handoff socket.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UnixHandoffSocket {
    /// Socket path to connect to; also quoted in handoff error values.
    pub path: PathBuf,
}

impl UnixHandoffSocket {
    /// Builds a socket description from anything convertible into a [`PathBuf`].
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        Self { path }
    }
}
49
50#[derive(Clone, Debug, PartialEq, Eq)]
52pub struct ScmRightsAttempt {
53 pub fd: UnixFileDescriptor,
55 pub backend_socket: UnixHandoffSocket,
57 pub handoff_token: HandoffToken,
59}
60
61impl ScmRightsAttempt {
62 pub fn new(
64 fd: UnixFileDescriptor,
65 backend_socket: UnixHandoffSocket,
66 handoff_token: HandoffToken,
67 ) -> Self {
68 Self {
69 fd,
70 backend_socket,
71 handoff_token,
72 }
73 }
74}
75
76#[derive(Clone, Debug, PartialEq, Eq)]
78pub struct ScmRightsSuccess {
79 pub sent_fd: UnixFileDescriptor,
81 pub backend_socket: UnixHandoffSocket,
83 pub handoff_token: HandoffToken,
85}
86
87impl ScmRightsSuccess {
88 pub fn new(
90 sent_fd: UnixFileDescriptor,
91 backend_socket: UnixHandoffSocket,
92 handoff_token: HandoffToken,
93 ) -> Self {
94 Self {
95 sent_fd,
96 backend_socket,
97 handoff_token,
98 }
99 }
100}
101
102pub type ScmRightsResult = Result<ScmRightsSuccess, ScmRightsError>;
104
105pub fn try_send_scm_rights(attempt: &ScmRightsAttempt) -> ScmRightsResult {
111 platform_try_send_scm_rights(attempt)
112}
113
114#[cfg(unix)]
123pub fn try_send_scm_rights_over(
124 socket_fd: std::os::fd::RawFd,
125 attempt: &ScmRightsAttempt,
126) -> ScmRightsResult {
127 send_fd_with_token(
128 socket_fd,
129 attempt.fd.raw(),
130 attempt.handoff_token.as_bytes(),
131 &attempt.backend_socket.path,
132 )?;
133 Ok(ScmRightsSuccess::new(
134 attempt.fd,
135 attempt.backend_socket.clone(),
136 attempt.handoff_token,
137 ))
138}
139
/// Reasons an `SCM_RIGHTS` descriptor handoff attempt can fail.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum ScmRightsError {
    /// This build targets a non-Unix platform, so no `SCM_RIGHTS` transport exists.
    #[error("SCM_RIGHTS handoff transport is unsupported on this platform")]
    UnsupportedPlatform,
    /// The OS reported "permission denied" while setting up the connection or
    /// while sending.
    #[error("permission denied passing fd {fd} to backend handoff socket {socket}")]
    PermissionDenied {
        /// Descriptor being passed, or `-1` when the failure happened while
        /// connecting/configuring the socket, before any send.
        fd: i32,
        /// Backend handoff socket path.
        socket: PathBuf,
    },
    /// The backend socket could not be reached (connect failure) or the
    /// connection was refused, reset, broken, or not connected during the send.
    #[error("backend handoff socket is unavailable: {socket}")]
    BackendSocketUnavailable {
        /// Backend handoff socket path.
        socket: PathBuf,
    },
    /// The OS reported `WouldBlock` (`EAGAIN`/`EWOULDBLOCK`) for the connect or
    /// the non-blocking send.
    #[error("SCM_RIGHTS send would block for backend handoff socket {socket}")]
    WouldBlock {
        /// Backend handoff socket path.
        socket: PathBuf,
    },
    /// Building the control message failed, `sendmsg` failed for a reason not
    /// covered above, or it transmitted fewer payload bytes than the token.
    #[error("SCM_RIGHTS send failed for fd {fd} to backend handoff socket {socket}")]
    SendFailed {
        /// Descriptor that was being passed.
        fd: i32,
        /// Backend handoff socket path.
        socket: PathBuf,
        /// OS error code, or `None` when the failure was detected locally
        /// (e.g. a short send) rather than reported by the OS.
        raw_os_error: Option<i32>,
    },
    /// The backend did not acknowledge the passed descriptor.
    ///
    /// NOTE(review): nothing in this module constructs this variant; presumably
    /// it is produced by a caller that waits for the backend's ack — confirm.
    #[error("backend handoff socket {socket} did not acknowledge passed fd")]
    BackendAckTimeout {
        /// Backend handoff socket path.
        socket: PathBuf,
    },
}
183
184#[cfg(unix)]
185fn platform_try_send_scm_rights(attempt: &ScmRightsAttempt) -> ScmRightsResult {
186 use std::os::fd::AsRawFd;
187 use std::os::unix::net::UnixStream;
188
189 let stream = UnixStream::connect(&attempt.backend_socket.path)
190 .map_err(|err| socket_connect_error(&attempt.backend_socket.path, err))?;
191 stream
192 .set_nonblocking(true)
193 .map_err(|err| socket_connect_error(&attempt.backend_socket.path, err))?;
194
195 send_fd_with_token(
196 stream.as_raw_fd(),
197 attempt.fd.raw(),
198 attempt.handoff_token.as_bytes(),
199 &attempt.backend_socket.path,
200 )?;
201
202 Ok(ScmRightsSuccess::new(
203 attempt.fd,
204 attempt.backend_socket.clone(),
205 attempt.handoff_token,
206 ))
207}
208
209#[cfg(not(unix))]
210fn platform_try_send_scm_rights(_attempt: &ScmRightsAttempt) -> ScmRightsResult {
211 Err(ScmRightsError::UnsupportedPlatform)
212}
213
214#[cfg(unix)]
215fn send_fd_with_token(
216 socket_fd: std::os::fd::RawFd,
217 sent_fd: std::os::fd::RawFd,
218 token: &[u8; 16],
219 socket_path: &Path,
220) -> Result<(), ScmRightsError> {
221 let mut token_payload = *token;
222 let mut iov = libc::iovec {
223 iov_base: token_payload.as_mut_ptr().cast(),
224 iov_len: token_payload.len(),
225 };
226 let mut control = vec![0_u8; cmsg_space::<libc::c_int>()];
227 let mut message = unsafe { std::mem::zeroed::<libc::msghdr>() };
228 message.msg_iov = &mut iov;
229 message.msg_iovlen = 1;
230 message.msg_control = control.as_mut_ptr().cast();
231 message.msg_controllen = control.len() as _;
232
233 unsafe {
234 let header = libc::CMSG_FIRSTHDR(&message);
235 if header.is_null() {
236 return Err(ScmRightsError::SendFailed {
237 fd: sent_fd,
238 socket: socket_path.to_path_buf(),
239 raw_os_error: None,
240 });
241 }
242
243 (*header).cmsg_level = libc::SOL_SOCKET;
244 (*header).cmsg_type = libc::SCM_RIGHTS;
245 (*header).cmsg_len = cmsg_len::<libc::c_int>() as _;
246 *libc::CMSG_DATA(header).cast::<libc::c_int>() = sent_fd;
247 }
248
249 let sent = unsafe { libc::sendmsg(socket_fd, &message, sendmsg_flags()) };
250 if sent < 0 {
251 return Err(sendmsg_error(
252 sent_fd,
253 socket_path,
254 std::io::Error::last_os_error(),
255 ));
256 }
257 if sent as usize != token_payload.len() {
258 return Err(ScmRightsError::SendFailed {
259 fd: sent_fd,
260 socket: socket_path.to_path_buf(),
261 raw_os_error: None,
262 });
263 }
264
265 Ok(())
266}
267
268#[cfg(unix)]
269fn cmsg_space<T>() -> usize {
270 unsafe { libc::CMSG_SPACE(std::mem::size_of::<T>() as _) as usize }
271}
272
273#[cfg(unix)]
274fn cmsg_len<T>() -> usize {
275 unsafe { libc::CMSG_LEN(std::mem::size_of::<T>() as _) as usize }
276}
277
/// Flags passed to every handoff `sendmsg` call.
///
/// On Linux and Android this is `MSG_NOSIGNAL`, so a closed peer is reported
/// as an `EPIPE` error instead of raising `SIGPIPE`.
#[cfg(all(unix, any(target_os = "android", target_os = "linux")))]
fn sendmsg_flags() -> libc::c_int {
    libc::MSG_NOSIGNAL
}

/// Flags passed to every handoff `sendmsg` call: none on these Unix targets.
///
/// NOTE(review): without `MSG_NOSIGNAL` (or `SO_NOSIGPIPE` on the socket), a
/// send to a closed peer may raise `SIGPIPE` here unless the process ignores
/// that signal — confirm the embedding process handles it.
#[cfg(all(unix, not(any(target_os = "android", target_os = "linux"))))]
fn sendmsg_flags() -> libc::c_int {
    0
}
287
288#[cfg(unix)]
289fn socket_connect_error(socket: &Path, error: std::io::Error) -> ScmRightsError {
290 match error.kind() {
291 std::io::ErrorKind::PermissionDenied => ScmRightsError::PermissionDenied {
292 fd: -1,
293 socket: socket.to_path_buf(),
294 },
295 std::io::ErrorKind::WouldBlock => ScmRightsError::WouldBlock {
296 socket: socket.to_path_buf(),
297 },
298 _ => ScmRightsError::BackendSocketUnavailable {
299 socket: socket.to_path_buf(),
300 },
301 }
302}
303
304#[cfg(unix)]
305fn sendmsg_error(fd: std::os::fd::RawFd, socket: &Path, error: std::io::Error) -> ScmRightsError {
306 match error.kind() {
307 std::io::ErrorKind::PermissionDenied => ScmRightsError::PermissionDenied {
308 fd,
309 socket: socket.to_path_buf(),
310 },
311 std::io::ErrorKind::WouldBlock => ScmRightsError::WouldBlock {
312 socket: socket.to_path_buf(),
313 },
314 std::io::ErrorKind::ConnectionRefused
315 | std::io::ErrorKind::ConnectionReset
316 | std::io::ErrorKind::BrokenPipe
317 | std::io::ErrorKind::NotConnected => ScmRightsError::BackendSocketUnavailable {
318 socket: socket.to_path_buf(),
319 },
320 _ => ScmRightsError::SendFailed {
321 fd,
322 socket: socket.to_path_buf(),
323 raw_os_error: error.raw_os_error(),
324 },
325 }
326}
327
328impl ScmRightsError {
329 pub fn attempt_failure(&self) -> Option<HandoffAttemptFailure> {
331 match self {
332 Self::UnsupportedPlatform => None,
333 Self::PermissionDenied { .. } => Some(HandoffAttemptFailure::PermissionDenied),
334 Self::BackendSocketUnavailable { .. }
335 | Self::WouldBlock { .. }
336 | Self::SendFailed { .. }
337 | Self::BackendAckTimeout { .. } => Some(HandoffAttemptFailure::BackendAckTimeout),
338 }
339 }
340
341 pub fn fallback_reason(&self) -> HandoffFallbackReason {
343 match self.attempt_failure() {
344 Some(failure) => failure.into(),
345 None => HandoffFallbackReason::ServicePolicyDisabled,
346 }
347 }
348
349 pub fn fallback_decision(&self) -> HandoffFallbackDecision {
351 HandoffFallbackDecision::new(self.fallback_reason())
352 }
353
354 pub fn fallback_attempt_decision(&self) -> HandoffAttemptDecision {
356 HandoffAttemptDecision::FallbackToReconnect(self.fallback_decision())
357 }
358
359 pub fn is_fallback_safe(&self) -> bool {
361 let fallback = self.fallback_decision();
362 fallback.uses_backend_reconnect() && !fallback.sends_client_error()
363 }
364}
365
#[cfg(all(test, unix))]
mod tests {
    use std::fs::File;
    use std::os::fd::{AsRawFd, RawFd};
    use std::os::unix::net::{UnixListener, UnixStream};
    use std::thread;

    use super::*;

    #[test]
    fn send_scm_rights_to_backend_socket_transfers_fd_and_token() {
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join("handoff.sock");
        let listener = UnixListener::bind(&socket_path).unwrap();
        let expected_token = HandoffToken::from_bytes([0x41; 16]);
        let receiver = thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            recv_fd_and_token(stream)
        });
        let file = File::open("/dev/null").unwrap();
        let attempt = ScmRightsAttempt::new(
            UnixFileDescriptor::new(file.as_raw_fd()),
            UnixHandoffSocket::new(socket_path),
            expected_token,
        );

        let success = try_send_scm_rights(&attempt).unwrap();
        let (received_fd, received_token) = receiver.join().unwrap();

        assert_eq!(success.sent_fd, attempt.fd);
        assert_eq!(success.handoff_token, expected_token);
        assert_eq!(received_token, expected_token);
        assert_ne!(received_fd, file.as_raw_fd());

        unsafe {
            libc::close(received_fd);
        }
    }

    #[test]
    fn send_scm_rights_over_connected_socket_transfers_fd_and_token() {
        let (sender, receiver) = UnixStream::pair().unwrap();
        let expected_token = HandoffToken::from_bytes([0x43; 16]);
        let file = File::open("/dev/null").unwrap();
        let attempt = ScmRightsAttempt::new(
            UnixFileDescriptor::new(file.as_raw_fd()),
            // Only quoted in error values by `try_send_scm_rights_over`.
            UnixHandoffSocket::new("unused-handoff.sock"),
            expected_token,
        );

        let success = try_send_scm_rights_over(sender.as_raw_fd(), &attempt).unwrap();
        let (received_fd, received_token) = recv_fd_and_token(receiver);

        assert_eq!(success.sent_fd, attempt.fd);
        assert_eq!(success.handoff_token, expected_token);
        assert_eq!(received_token, expected_token);
        assert_ne!(received_fd, file.as_raw_fd());

        unsafe {
            libc::close(received_fd);
        }
    }

    #[test]
    fn missing_backend_socket_maps_to_fallback_safe_error() {
        let dir = tempfile::tempdir().unwrap();
        let socket = UnixHandoffSocket::new(dir.path().join("missing.sock"));
        let file = File::open("/dev/null").unwrap();
        let attempt = ScmRightsAttempt::new(
            UnixFileDescriptor::new(file.as_raw_fd()),
            socket.clone(),
            HandoffToken::from_bytes([0x42; 16]),
        );

        let err = try_send_scm_rights(&attempt).unwrap_err();

        assert!(matches!(
            err,
            ScmRightsError::BackendSocketUnavailable { socket: ref path }
                if path == &socket.path
        ));
        assert!(err.is_fallback_safe());
    }

    /// Receives one 16-byte token plus a single `SCM_RIGHTS` descriptor from
    /// `stream`, asserting that the message arrived complete.
    fn recv_fd_and_token(stream: UnixStream) -> (RawFd, HandoffToken) {
        let mut token_payload = [0_u8; 16];
        let mut iov = libc::iovec {
            iov_base: token_payload.as_mut_ptr().cast(),
            iov_len: token_payload.len(),
        };
        // `u64`-backed so the `cmsghdr` read back below is suitably aligned;
        // a `Vec<u8>` only guarantees 1-byte alignment.
        let control_len = cmsg_space::<libc::c_int>();
        let mut control = vec![0_u64; control_len.div_ceil(std::mem::size_of::<u64>())];
        let mut message = unsafe { std::mem::zeroed::<libc::msghdr>() };
        message.msg_iov = &mut iov;
        message.msg_iovlen = 1;
        message.msg_control = control.as_mut_ptr().cast();
        message.msg_controllen = control_len as _;

        let received = unsafe { libc::recvmsg(stream.as_raw_fd(), &mut message, 0) };
        assert!(
            received >= 0,
            "recvmsg failed: {}",
            std::io::Error::last_os_error()
        );
        assert_eq!(received as usize, token_payload.len());
        assert_eq!(
            message.msg_flags & libc::MSG_CTRUNC,
            0,
            "control data was truncated"
        );

        let header = unsafe { libc::CMSG_FIRSTHDR(&message) };
        assert!(!header.is_null());
        unsafe {
            assert_eq!((*header).cmsg_level, libc::SOL_SOCKET);
            assert_eq!((*header).cmsg_type, libc::SCM_RIGHTS);
            let received_fd = libc::CMSG_DATA(header)
                .cast::<libc::c_int>()
                .read_unaligned();
            (received_fd, HandoffToken::from_bytes(token_payload))
        }
    }
}