running_process/broker/server/handoff/
unix.rs1#[cfg(unix)]
9use std::path::Path;
10use std::path::PathBuf;
11
12use super::{
13 HandoffAttemptDecision, HandoffAttemptFailure, HandoffFallbackDecision, HandoffFallbackReason,
14 HandoffToken,
15};
16
/// Whether this build can hand off descriptors with `SCM_RIGHTS`: always on Unix targets.
#[cfg(unix)]
pub const SCM_RIGHTS_TRANSPORT_SUPPORTED: bool = true;
/// Whether this build can hand off descriptors with `SCM_RIGHTS`: never on non-Unix targets.
#[cfg(not(unix))]
pub const SCM_RIGHTS_TRANSPORT_SUPPORTED: bool = false;
19
/// A raw Unix file-descriptor number that is to be passed to a backend over `SCM_RIGHTS`.
///
/// This is a plain `Copy` wrapper around the integer. It does not own the descriptor and
/// therefore never closes it; keeping the descriptor open is the caller's job.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UnixFileDescriptor(i32);

impl UnixFileDescriptor {
    /// Wraps `raw_fd` as-is; no validation is performed.
    pub fn new(raw_fd: i32) -> Self {
        UnixFileDescriptor(raw_fd)
    }

    /// Returns the wrapped descriptor number.
    pub fn raw(self) -> i32 {
        let UnixFileDescriptor(fd) = self;
        fd
    }
}
35
/// Location of a backend's Unix-domain handoff socket.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UnixHandoffSocket {
    /// Filesystem path that `try_send_scm_rights` connects to.
    pub path: PathBuf,
}

impl UnixHandoffSocket {
    /// Builds a socket reference from anything convertible into a [`PathBuf`].
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        UnixHandoffSocket { path }
    }
}
49
50#[derive(Clone, Debug, PartialEq, Eq)]
52pub struct ScmRightsAttempt {
53 pub fd: UnixFileDescriptor,
55 pub backend_socket: UnixHandoffSocket,
57 pub handoff_token: HandoffToken,
59}
60
61impl ScmRightsAttempt {
62 pub fn new(
64 fd: UnixFileDescriptor,
65 backend_socket: UnixHandoffSocket,
66 handoff_token: HandoffToken,
67 ) -> Self {
68 Self {
69 fd,
70 backend_socket,
71 handoff_token,
72 }
73 }
74}
75
76#[derive(Clone, Debug, PartialEq, Eq)]
78pub struct ScmRightsSuccess {
79 pub sent_fd: UnixFileDescriptor,
81 pub backend_socket: UnixHandoffSocket,
83 pub handoff_token: HandoffToken,
85}
86
87impl ScmRightsSuccess {
88 pub fn new(
90 sent_fd: UnixFileDescriptor,
91 backend_socket: UnixHandoffSocket,
92 handoff_token: HandoffToken,
93 ) -> Self {
94 Self {
95 sent_fd,
96 backend_socket,
97 handoff_token,
98 }
99 }
100}
101
102pub type ScmRightsResult = Result<ScmRightsSuccess, ScmRightsError>;
104
105pub fn try_send_scm_rights(attempt: &ScmRightsAttempt) -> ScmRightsResult {
111 platform_try_send_scm_rights(attempt)
112}
113
114#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
116pub enum ScmRightsError {
117 #[error("SCM_RIGHTS handoff transport is unsupported on this platform")]
119 UnsupportedPlatform,
120 #[error("permission denied passing fd {fd} to backend handoff socket {socket}")]
122 PermissionDenied {
123 fd: i32,
125 socket: PathBuf,
127 },
128 #[error("backend handoff socket is unavailable: {socket}")]
130 BackendSocketUnavailable {
131 socket: PathBuf,
133 },
134 #[error("SCM_RIGHTS send would block for backend handoff socket {socket}")]
136 WouldBlock {
137 socket: PathBuf,
139 },
140 #[error("SCM_RIGHTS send failed for fd {fd} to backend handoff socket {socket}")]
142 SendFailed {
143 fd: i32,
145 socket: PathBuf,
147 raw_os_error: Option<i32>,
149 },
150 #[error("backend handoff socket {socket} did not acknowledge passed fd")]
152 BackendAckTimeout {
153 socket: PathBuf,
155 },
156}
157
158#[cfg(unix)]
159fn platform_try_send_scm_rights(attempt: &ScmRightsAttempt) -> ScmRightsResult {
160 use std::os::fd::AsRawFd;
161 use std::os::unix::net::UnixStream;
162
163 let stream = UnixStream::connect(&attempt.backend_socket.path)
164 .map_err(|err| socket_connect_error(&attempt.backend_socket.path, err))?;
165 stream
166 .set_nonblocking(true)
167 .map_err(|err| socket_connect_error(&attempt.backend_socket.path, err))?;
168
169 send_fd_with_token(
170 stream.as_raw_fd(),
171 attempt.fd.raw(),
172 attempt.handoff_token.as_bytes(),
173 &attempt.backend_socket.path,
174 )?;
175
176 Ok(ScmRightsSuccess::new(
177 attempt.fd,
178 attempt.backend_socket.clone(),
179 attempt.handoff_token,
180 ))
181}
182
/// Non-Unix implementation of `try_send_scm_rights`: there is no `SCM_RIGHTS` transport,
/// so every attempt fails with [`ScmRightsError::UnsupportedPlatform`].
#[cfg(not(unix))]
fn platform_try_send_scm_rights(_attempt: &ScmRightsAttempt) -> ScmRightsResult {
    let unsupported = ScmRightsError::UnsupportedPlatform;
    Err(unsupported)
}
187
188#[cfg(unix)]
189fn send_fd_with_token(
190 socket_fd: std::os::fd::RawFd,
191 sent_fd: std::os::fd::RawFd,
192 token: &[u8; 16],
193 socket_path: &Path,
194) -> Result<(), ScmRightsError> {
195 let mut token_payload = *token;
196 let mut iov = libc::iovec {
197 iov_base: token_payload.as_mut_ptr().cast(),
198 iov_len: token_payload.len(),
199 };
200 let mut control = vec![0_u8; cmsg_space::<libc::c_int>()];
201 let mut message = unsafe { std::mem::zeroed::<libc::msghdr>() };
202 message.msg_iov = &mut iov;
203 message.msg_iovlen = 1;
204 message.msg_control = control.as_mut_ptr().cast();
205 message.msg_controllen = control.len() as _;
206
207 unsafe {
208 let header = libc::CMSG_FIRSTHDR(&message);
209 if header.is_null() {
210 return Err(ScmRightsError::SendFailed {
211 fd: sent_fd,
212 socket: socket_path.to_path_buf(),
213 raw_os_error: None,
214 });
215 }
216
217 (*header).cmsg_level = libc::SOL_SOCKET;
218 (*header).cmsg_type = libc::SCM_RIGHTS;
219 (*header).cmsg_len = cmsg_len::<libc::c_int>() as _;
220 *libc::CMSG_DATA(header).cast::<libc::c_int>() = sent_fd;
221 }
222
223 let sent = unsafe { libc::sendmsg(socket_fd, &message, sendmsg_flags()) };
224 if sent < 0 {
225 return Err(sendmsg_error(
226 sent_fd,
227 socket_path,
228 std::io::Error::last_os_error(),
229 ));
230 }
231 if sent as usize != token_payload.len() {
232 return Err(ScmRightsError::SendFailed {
233 fd: sent_fd,
234 socket: socket_path.to_path_buf(),
235 raw_os_error: None,
236 });
237 }
238
239 Ok(())
240}
241
242#[cfg(unix)]
243fn cmsg_space<T>() -> usize {
244 unsafe { libc::CMSG_SPACE(std::mem::size_of::<T>() as _) as usize }
245}
246
247#[cfg(unix)]
248fn cmsg_len<T>() -> usize {
249 unsafe { libc::CMSG_LEN(std::mem::size_of::<T>() as _) as usize }
250}
251
252#[cfg(all(unix, any(target_os = "android", target_os = "linux")))]
253fn sendmsg_flags() -> libc::c_int {
254 libc::MSG_NOSIGNAL
255}
256
257#[cfg(all(unix, not(any(target_os = "android", target_os = "linux"))))]
258fn sendmsg_flags() -> libc::c_int {
259 0
260}
261
262#[cfg(unix)]
263fn socket_connect_error(socket: &Path, error: std::io::Error) -> ScmRightsError {
264 match error.kind() {
265 std::io::ErrorKind::PermissionDenied => ScmRightsError::PermissionDenied {
266 fd: -1,
267 socket: socket.to_path_buf(),
268 },
269 std::io::ErrorKind::WouldBlock => ScmRightsError::WouldBlock {
270 socket: socket.to_path_buf(),
271 },
272 _ => ScmRightsError::BackendSocketUnavailable {
273 socket: socket.to_path_buf(),
274 },
275 }
276}
277
278#[cfg(unix)]
279fn sendmsg_error(fd: std::os::fd::RawFd, socket: &Path, error: std::io::Error) -> ScmRightsError {
280 match error.kind() {
281 std::io::ErrorKind::PermissionDenied => ScmRightsError::PermissionDenied {
282 fd,
283 socket: socket.to_path_buf(),
284 },
285 std::io::ErrorKind::WouldBlock => ScmRightsError::WouldBlock {
286 socket: socket.to_path_buf(),
287 },
288 std::io::ErrorKind::ConnectionRefused
289 | std::io::ErrorKind::ConnectionReset
290 | std::io::ErrorKind::BrokenPipe
291 | std::io::ErrorKind::NotConnected => ScmRightsError::BackendSocketUnavailable {
292 socket: socket.to_path_buf(),
293 },
294 _ => ScmRightsError::SendFailed {
295 fd,
296 socket: socket.to_path_buf(),
297 raw_os_error: error.raw_os_error(),
298 },
299 }
300}
301
302impl ScmRightsError {
303 pub fn attempt_failure(&self) -> Option<HandoffAttemptFailure> {
305 match self {
306 Self::UnsupportedPlatform => None,
307 Self::PermissionDenied { .. } => Some(HandoffAttemptFailure::PermissionDenied),
308 Self::BackendSocketUnavailable { .. }
309 | Self::WouldBlock { .. }
310 | Self::SendFailed { .. }
311 | Self::BackendAckTimeout { .. } => Some(HandoffAttemptFailure::BackendAckTimeout),
312 }
313 }
314
315 pub fn fallback_reason(&self) -> HandoffFallbackReason {
317 match self.attempt_failure() {
318 Some(failure) => failure.into(),
319 None => HandoffFallbackReason::ServicePolicyDisabled,
320 }
321 }
322
323 pub fn fallback_decision(&self) -> HandoffFallbackDecision {
325 HandoffFallbackDecision::new(self.fallback_reason())
326 }
327
328 pub fn fallback_attempt_decision(&self) -> HandoffAttemptDecision {
330 HandoffAttemptDecision::FallbackToReconnect(self.fallback_decision())
331 }
332
333 pub fn is_fallback_safe(&self) -> bool {
335 let fallback = self.fallback_decision();
336 fallback.uses_backend_reconnect() && !fallback.sends_client_error()
337 }
338}
339
#[cfg(all(test, unix))]
mod tests {
    use std::fs::File;
    use std::os::fd::{AsRawFd, RawFd};
    use std::os::unix::net::{UnixListener, UnixStream};
    use std::thread;

    use super::*;

    #[test]
    fn send_scm_rights_to_backend_socket_transfers_fd_and_token() {
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join("handoff.sock");
        let listener = UnixListener::bind(&socket_path).unwrap();
        let expected_token = HandoffToken::from_bytes([0x41; 16]);
        let receiver = thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            recv_fd_and_token(stream)
        });
        let file = File::open("/dev/null").unwrap();
        let attempt = ScmRightsAttempt::new(
            UnixFileDescriptor::new(file.as_raw_fd()),
            UnixHandoffSocket::new(socket_path),
            expected_token,
        );

        let success = try_send_scm_rights(&attempt).unwrap();
        let (received_fd, received_token) = receiver.join().unwrap();

        assert_eq!(success.sent_fd, attempt.fd);
        assert_eq!(success.handoff_token, expected_token);
        assert_eq!(received_token, expected_token);
        // The receiver gets a newly installed descriptor; the sender's `file` is still open.
        assert_ne!(received_fd, file.as_raw_fd());

        // SAFETY: `received_fd` was just installed by `recvmsg` and is owned by this test.
        unsafe {
            libc::close(received_fd);
        }
    }

    #[test]
    fn missing_backend_socket_maps_to_fallback_safe_error() {
        let dir = tempfile::tempdir().unwrap();
        let socket = UnixHandoffSocket::new(dir.path().join("missing.sock"));
        let file = File::open("/dev/null").unwrap();
        let attempt = ScmRightsAttempt::new(
            UnixFileDescriptor::new(file.as_raw_fd()),
            socket.clone(),
            HandoffToken::from_bytes([0x42; 16]),
        );

        let err = try_send_scm_rights(&attempt).unwrap_err();

        assert!(matches!(
            err,
            ScmRightsError::BackendSocketUnavailable { socket: ref path }
                if path == &socket.path
        ));
        assert!(err.is_fallback_safe());
    }

    /// Receives one message from `stream` and returns the descriptor passed in its
    /// `SCM_RIGHTS` control data together with its 16-byte token payload.
    ///
    /// Panics (failing the test) if `recvmsg` fails, the payload is not exactly 16 bytes,
    /// the control data was truncated, or no `SOL_SOCKET`/`SCM_RIGHTS` header is present.
    fn recv_fd_and_token(stream: UnixStream) -> (RawFd, HandoffToken) {
        let mut token_payload = [0_u8; 16];
        let mut iov = libc::iovec {
            iov_base: token_payload.as_mut_ptr().cast(),
            iov_len: token_payload.len(),
        };

        // Back the control buffer with `cmsghdr` elements so the header returned by
        // `CMSG_FIRSTHDR` is properly aligned; a `Vec<u8>` only guarantees 1-byte alignment.
        let control_len = cmsg_space::<libc::c_int>();
        // SAFETY: `cmsghdr` contains only integer fields, so all-zero is a valid value.
        let zeroed_header = unsafe { std::mem::zeroed::<libc::cmsghdr>() };
        let mut control =
            vec![zeroed_header; control_len.div_ceil(std::mem::size_of::<libc::cmsghdr>())];

        // SAFETY: `msghdr` is a plain C struct for which all-zero means "no name, no data".
        let mut message = unsafe { std::mem::zeroed::<libc::msghdr>() };
        message.msg_iov = &mut iov;
        message.msg_iovlen = 1;
        message.msg_control = control.as_mut_ptr().cast();
        message.msg_controllen = control_len as _;

        // SAFETY: `message` points at `iov`/`token_payload` and `control`, all of which
        // outlive the call.
        let received = unsafe { libc::recvmsg(stream.as_raw_fd(), &mut message, 0) };
        assert!(
            received >= 0,
            "recvmsg failed: {}",
            std::io::Error::last_os_error()
        );
        assert_eq!(received as usize, token_payload.len());
        assert_eq!(
            message.msg_flags & libc::MSG_CTRUNC,
            0,
            "SCM_RIGHTS control data was truncated"
        );

        // SAFETY: `message.msg_control` still points at the live, aligned `control` buffer.
        let header = unsafe { libc::CMSG_FIRSTHDR(&message) };
        assert!(!header.is_null());
        // SAFETY: `header` is non-null and points into `control`; with the level/type
        // checked, its data area holds the passed `c_int`.
        let received_fd = unsafe {
            assert_eq!((*header).cmsg_level, libc::SOL_SOCKET);
            assert_eq!((*header).cmsg_type, libc::SCM_RIGHTS);
            libc::CMSG_DATA(header)
                .cast::<libc::c_int>()
                .read_unaligned()
        };
        (received_fd, HandoffToken::from_bytes(token_payload))
    }
}