1use std::fs::File;
12use std::mem;
13use std::mem::size_of;
14use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
15use std::os::unix::net::{UnixDatagram, UnixStream};
16use std::ptr::null_mut;
17
18use crate::errno::{Error, Result};
19use crate::fam::{FamStruct, FamStructWrapper};
20use libc::{
21 c_uint, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, CMSG_LEN, CMSG_SPACE, MSG_NOSIGNAL,
22 SCM_RIGHTS, SOL_SOCKET,
23};
24
25#[cfg(not(target_env = "musl"))]
26fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
27 msghdr {
28 msg_name: null_mut(),
29 msg_namelen: 0,
30 msg_iov: iovecs.as_mut_ptr(),
31 #[cfg(any(target_os = "linux", target_os = "android"))]
32 msg_iovlen: iovecs.len(),
33 #[cfg(not(any(target_os = "linux", target_os = "android")))]
34 msg_iovlen: iovecs
35 .len()
36 .try_into()
37 .expect("iovecs.len() exceeds i32 range"),
38 msg_control: null_mut(),
39 msg_controllen: 0,
40 msg_flags: 0,
41 }
42}
43
/// Builds a `msghdr` with no peer address and no control buffer whose scatter/gather list is
/// `iovecs` (musl variant).
///
/// The header is zero-initialised and then filled in, rather than built with a struct literal.
/// Presumably this is because libc's musl definition carries private padding fields (TODO confirm).
/// The returned header stores a raw pointer into `iovecs`, so the slice must stay alive and
/// unmoved while the header is in use.
///
/// # Panics
///
/// Panics if `iovecs.len()` does not fit into `msg_iovlen` (an `i32` on musl).
#[cfg(target_env = "musl")]
fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
    // SAFETY: `msghdr` consists of integers and raw pointers only, for which all-zero bytes are
    // valid (null pointers, zero lengths, no flags).
    let mut msg: msghdr = unsafe { mem::zeroed() };
    msg.msg_name = null_mut();
    msg.msg_iov = iovecs.as_mut_ptr();
    // Checked conversion instead of a manual `<= i32::MAX` assert followed by an `as` cast;
    // same check and message as the non-musl variant.
    msg.msg_iovlen = iovecs
        .len()
        .try_into()
        .expect("iovecs.len() exceeds i32 range");
    msg.msg_control = null_mut();
    msg
}
54
55#[cfg(all(
56 not(target_env = "musl"),
57 any(target_os = "linux", target_os = "android")
58))]
59fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: c_uint) {
60 msg.msg_controllen = cmsg_capacity as libc::size_t;
61}
62
63#[cfg(any(
64 target_env = "musl",
65 not(any(target_os = "linux", target_os = "android"))
66))]
67fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: c_uint) {
68 msg.msg_controllen = cmsg_capacity;
69}
70
71#[repr(transparent)]
72struct CmsgHdr(cmsghdr);
73
74impl Default for CmsgHdr {
75 fn default() -> Self {
76 Self(unsafe { mem::zeroed() })
78 }
79}
80
81unsafe impl FamStruct for CmsgHdr {
87 type Entry = RawFd;
88
89 #[allow(clippy::unnecessary_cast)]
91 fn len(&self) -> usize {
92 (self.0.cmsg_len as usize - size_of::<cmsghdr>()) / size_of::<RawFd>()
93 }
94
95 unsafe fn set_len(&mut self, len: usize) {
96 self.0.cmsg_len = CMSG_LEN((len * size_of::<RawFd>()) as _) as _;
101 }
102
103 fn max_len() -> usize {
104 (u32::MAX as usize - size_of::<cmsghdr>()) / size_of::<RawFd>()
105 }
106
107 fn as_slice(&self) -> &[RawFd] {
108 unsafe { std::slice::from_raw_parts((&self.0 as *const cmsghdr).add(1).cast(), self.len()) }
111 }
112
113 fn as_mut_slice(&mut self) -> &mut [RawFd] {
114 unsafe {
117 std::slice::from_raw_parts_mut((&mut self.0 as *mut cmsghdr).add(1).cast(), self.len())
118 }
119 }
120}
121
122type CmsgBuffer = FamStructWrapper<CmsgHdr>;
123
124fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
125 let mut cmsg_buffer =
126 CmsgBuffer::from_entries(out_fds).map_err(|_| Error::new(libc::ENOMEM))?;
127
128 let mut iovecs = Vec::with_capacity(out_data.len());
129 for data in out_data {
130 iovecs.push(iovec {
131 iov_base: data.as_ptr() as *mut c_void,
132 iov_len: data.size(),
133 });
134 }
135
136 let mut msg = new_msghdr(&mut iovecs);
137
138 if !out_fds.is_empty() {
139 unsafe {
141 let hdr = cmsg_buffer.as_mut_fam_struct();
142 hdr.0.cmsg_level = SOL_SOCKET;
143 hdr.0.cmsg_type = SCM_RIGHTS;
144 }
145
146 msg.msg_control = cmsg_buffer.as_mut_fam_struct_ptr() as *mut c_void;
147 unsafe {
149 set_msg_controllen(&mut msg, CMSG_SPACE(size_of_val(out_fds) as _));
150 }
151 }
152
153 let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
156
157 if write_count == -1 {
158 Err(Error::last())
159 } else {
160 Ok(write_count as usize)
161 }
162}
163
164#[allow(clippy::unnecessary_cast)]
165unsafe fn raw_recvmsg(
166 fd: RawFd,
167 iovecs: &mut [iovec],
168 in_fds: &mut [RawFd],
169) -> Result<(usize, usize)> {
170 let mut cmsg_buffer = CmsgBuffer::new(in_fds.len()).map_err(|_| Error::new(libc::ENOMEM))?;
171 let mut msg = new_msghdr(iovecs);
172 let cmsg_capacity = cmsg_buffer.as_fam_struct_ref().len();
174
175 if !in_fds.is_empty() {
176 msg.msg_control = cmsg_buffer.as_mut_fam_struct_ptr() as *mut c_void;
178 unsafe {
180 set_msg_controllen(&mut msg, CMSG_SPACE(size_of_val(in_fds) as _));
181 }
182 }
183
184 let total_read = recvmsg(fd, &mut msg, 0);
188 if total_read == -1 {
189 return Err(Error::last());
190 }
191
192 let mut copied_fds_count = 0;
193 let mut teardown_control_data = msg.msg_flags & libc::MSG_CTRUNC != 0;
196
197 if !msg.msg_control.is_null() {
198 let cmsg = cmsg_buffer.as_mut_fam_struct();
199
200 if cmsg.0.cmsg_level == SOL_SOCKET && cmsg.0.cmsg_type == SCM_RIGHTS {
201 unsafe { cmsg.set_len(cmsg.len().min(cmsg_capacity)) }
206 let fds = cmsg.as_slice();
207 teardown_control_data |= fds.len() > in_fds.len();
212 if teardown_control_data {
213 for fd in fds {
214 libc::close(*fd);
215 }
216 } else {
217 in_fds[..fds.len()].copy_from_slice(fds);
218
219 copied_fds_count = fds.len();
220 }
221 }
222 }
223
224 if teardown_control_data {
225 return Err(Error::new(libc::ENOBUFS));
226 }
227
228 Ok((total_read as usize, copied_fds_count))
229}
230
231pub trait ScmSocket {
268 fn socket_fd(&self) -> RawFd;
270
271 fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
280 self.send_with_fds(&[buf], &[fd])
281 }
282
283 fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> {
292 raw_sendmsg(self.socket_fd(), bufs, fds)
293 }
294
295 fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
303 let mut fd = [0];
304 let mut iovecs = [iovec {
305 iov_base: buf.as_mut_ptr() as *mut c_void,
306 iov_len: buf.len(),
307 }];
308
309 let (read_count, fd_count) = unsafe { self.recv_with_fds(&mut iovecs[..], &mut fd)? };
312 let file = if fd_count == 0 {
313 None
314 } else {
315 Some(unsafe { File::from_raw_fd(fd[0]) })
318 };
319 Ok((read_count, file))
320 }
321
322 unsafe fn recv_with_fds(
341 &self,
342 iovecs: &mut [iovec],
343 fds: &mut [RawFd],
344 ) -> Result<(usize, usize)> {
345 raw_recvmsg(self.socket_fd(), iovecs, fds)
346 }
347}
348
349impl ScmSocket for UnixDatagram {
350 fn socket_fd(&self) -> RawFd {
351 self.as_raw_fd()
352 }
353}
354
355impl ScmSocket for UnixStream {
356 fn socket_fd(&self) -> RawFd {
357 self.as_raw_fd()
358 }
359}
360
/// A contiguous byte region that can be described by an `iovec` and sent with
/// `ScmSocket::send_with_fds`.
///
/// # Safety
///
/// Implementors must guarantee that `as_ptr()` points to at least `size()` readable bytes for
/// as long as the value is borrowed. The pointer and length are handed to `sendmsg` as an
/// `iovec`.
pub unsafe trait IntoIovec {
    /// Start address of the region.
    fn as_ptr(&self) -> *const c_void;

    /// Length of the region in bytes.
    fn size(&self) -> usize;
}

// SAFETY: a byte slice's data pointer is valid for reads of exactly `len()` bytes while the
// slice is alive.
unsafe impl IntoIovec for &[u8] {
    fn as_ptr(&self) -> *const c_void {
        // Fully qualified so this resolves to the slice's inherent method, not back to
        // `IntoIovec::as_ptr`.
        <[u8]>::as_ptr(self).cast()
    }

    fn size(&self) -> usize {
        <[u8]>::len(self)
    }
}
389
#[cfg(test)]
mod tests {
    #![allow(clippy::undocumented_unsafe_blocks)]
    use super::*;
    use std::io::{pipe, Read, Write};
    use std::os::unix::net::UnixDatagram;

    /// Value written through a passed pipe descriptor and read back from the pipe's other end,
    /// proving the received descriptor refers to the same pipe.
    const PROBE: u64 = 1203;

    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let write_count = s1
            .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0u8; 6];
        let mut files = [0; 1];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv data")
        };

        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let (mut evt_consumer, evt_notifier) = pipe().expect("failed to create pipe");
        let write_count = s1
            .send_with_fd([].as_ref(), evt_notifier.as_raw_fd())
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), evt_notifier.as_raw_fd());

        // Native-endian bytes, no unsafe reinterpretation of the integer needed.
        file.write_all(&PROBE.to_ne_bytes())
            .expect("failed to write to sent fd");

        let mut buf = [0u8; std::mem::size_of::<u64>()];
        evt_consumer
            .read_exact(buf.as_mut_slice())
            .expect("Failed to read from PipeReader");
        assert_eq!(u64::from_ne_bytes(buf), PROBE);
    }

    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let (mut evt_consumer, evt_notifier) = pipe().expect("failed to create pipe");
        let write_count = s1
            .send_with_fds(&[[237].as_ref()], &[evt_notifier.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv fd")
        };

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], evt_notifier.as_raw_fd());

        let mut file = unsafe { File::from_raw_fd(files[0]) };

        file.write_all(&PROBE.to_ne_bytes())
            .expect("failed to write to sent fd");

        let mut buf = [0u8; std::mem::size_of::<u64>()];
        evt_consumer
            .read_exact(buf.as_mut_slice())
            .expect("Failed to read from PipeReader");
        assert_eq!(u64::from_ne_bytes(buf), PROBE);
    }

    #[test]
    fn send_more_recv_less() {
        // Receiving four descriptors into room for fewer must fail. The zero-slot case is only
        // exercised on Linux/Android. Presumably other kernels do not flag truncation when no
        // control buffer is supplied (TODO confirm).
        #[cfg(any(target_os = "linux", target_os = "android"))]
        let start = 0;
        #[cfg(not(any(target_os = "linux", target_os = "android")))]
        let start = 1;

        for too_small in start..3 {
            let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

            let (_, evt_notifier1) = pipe().expect("failed to create pipe");
            let (_, evt_notifier2) = pipe().expect("failed to create pipe");
            let (_, evt_notifier3) = pipe().expect("failed to create pipe");
            let (_, evt_notifier4) = pipe().expect("failed to create pipe");
            let write_count = s1
                .send_with_fds(
                    &[[237].as_ref()],
                    &[
                        evt_notifier1.as_raw_fd(),
                        evt_notifier2.as_raw_fd(),
                        evt_notifier3.as_raw_fd(),
                        evt_notifier4.as_raw_fd(),
                    ],
                )
                .expect("failed to send fd");

            assert_eq!(write_count, 1);

            let mut files = vec![0; too_small];
            let mut buf = [0u8];
            let mut iovecs = [iovec {
                iov_base: buf.as_mut_ptr() as *mut c_void,
                iov_len: buf.len(),
            }];
            unsafe { s2.recv_with_fds(&mut iovecs[..], &mut files).unwrap_err() };
        }
    }
}