1#![allow(dead_code)]
8
9use std::fs::File;
13use std::io::{IoSlice, IoSliceMut, Result};
14use std::mem::{size_of, size_of_val};
15use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
16use std::os::unix::net::{UnixDatagram, UnixStream};
17use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
18use std::slice;
19
20use libc::{
21 c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
22};
23
24use crate::Error;
25
/// Rounds `$len` up to the next multiple of `size_of::<c_long>()`, the alignment unit used by
/// the C `CMSG_ALIGN` macro. Usable in const contexts.
macro_rules! CMSG_ALIGN {
    ($len:expr) => {
        ((($len) + size_of::<c_long>() - 1) / size_of::<c_long>() * size_of::<c_long>())
    };
}

/// Total bytes a control message with a `$len`-byte payload occupies in a control buffer:
/// the header followed by the payload padded to `CMSG_ALIGN!`.
macro_rules! CMSG_SPACE {
    ($len:expr) => {
        CMSG_ALIGN!($len) + size_of::<cmsghdr>()
    };
}

/// Value to store in `cmsg_len` for a control message with a `$len`-byte payload: the header
/// size plus the unpadded payload length.
macro_rules! CMSG_LEN {
    ($len:expr) => {
        ($len) + size_of::<cmsghdr>()
    };
}
46
47#[allow(non_snake_case)]
51#[inline(always)]
52fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
53 cmsg_buffer.wrapping_offset(1) as *mut RawFd
55}
56
57#[allow(clippy::cast_ptr_alignment)]
60fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
61 let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr;
62 if next_cmsg
63 .wrapping_offset(1)
64 .wrapping_sub(msghdr.msg_control as usize) as usize
65 > msghdr.msg_controllen
66 {
67 null_mut()
68 } else {
69 next_cmsg
70 }
71}
72
73const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
74
75enum CmsgBuffer {
76 Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
77 Heap(Box<[cmsghdr]>),
78}
79
80impl CmsgBuffer {
81 fn with_capacity(capacity: usize) -> CmsgBuffer {
82 let cap_in_cmsghdr_units =
83 (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
84 if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
85 CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
86 } else {
87 CmsgBuffer::Heap(
88 vec![
89 cmsghdr {
90 cmsg_len: 0,
91 cmsg_level: 0,
92 cmsg_type: 0,
93 };
94 cap_in_cmsghdr_units
95 ]
96 .into_boxed_slice(),
97 )
98 }
99 }
100
101 fn as_mut_ptr(&mut self) -> *mut cmsghdr {
102 match self {
103 CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
104 CmsgBuffer::Heap(a) => a.as_mut_ptr(),
105 }
106 }
107}
108
109fn raw_sendmsg<D: IntoIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
110 let cmsg_capacity = CMSG_SPACE!(size_of_val(out_fds));
111 let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
112
113 let iovec = IntoIobuf::as_iobufs(out_data);
114
115 let mut msg = msghdr {
116 msg_name: null_mut(),
117 msg_namelen: 0,
118 msg_iov: iovec.as_ptr() as *mut iovec,
119 msg_iovlen: iovec.len(),
120 msg_control: null_mut(),
121 msg_controllen: 0,
122 msg_flags: 0,
123 };
124
125 if !out_fds.is_empty() {
126 let cmsg = cmsghdr {
127 cmsg_len: CMSG_LEN!(size_of_val(out_fds)),
128 cmsg_level: SOL_SOCKET,
129 cmsg_type: SCM_RIGHTS,
130 };
131 unsafe {
132 write_unaligned(cmsg_buffer.as_mut_ptr(), cmsg);
134 copy_nonoverlapping(
137 out_fds.as_ptr(),
138 CMSG_DATA(cmsg_buffer.as_mut_ptr()),
139 out_fds.len(),
140 );
141 }
142
143 msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
144 msg.msg_controllen = cmsg_capacity;
145 }
146
147 let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
150
151 if write_count == -1 {
152 Err(Error::last_os_error())
153 } else {
154 Ok(write_count as usize)
155 }
156}
157
158fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
159 let cmsg_capacity = CMSG_SPACE!(size_of_val(in_fds));
160 let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
161
162 let mut iovec = iovec {
163 iov_base: in_data.as_mut_ptr() as *mut c_void,
164 iov_len: in_data.len(),
165 };
166
167 let mut msg = msghdr {
168 msg_name: null_mut(),
169 msg_namelen: 0,
170 msg_iov: &mut iovec as *mut iovec,
171 msg_iovlen: 1,
172 msg_control: null_mut(),
173 msg_controllen: 0,
174 msg_flags: 0,
175 };
176
177 if !in_fds.is_empty() {
178 msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
179 msg.msg_controllen = cmsg_capacity;
180 }
181
182 let total_read = unsafe { recvmsg(fd, &mut msg, 0) };
185
186 if total_read == -1 {
187 return Err(Error::last_os_error());
188 }
189
190 if total_read == 0 && msg.msg_controllen < size_of::<cmsghdr>() {
191 return Ok((0, 0));
192 }
193
194 let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
195 let mut in_fds_count = 0;
196 while !cmsg_ptr.is_null() {
197 let cmsg = unsafe { (cmsg_ptr).read_unaligned() };
200
201 if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
202 let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::<RawFd>();
203 unsafe {
204 copy_nonoverlapping(
205 CMSG_DATA(cmsg_ptr),
206 in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
207 fd_count,
208 );
209 }
210 in_fds_count += fd_count;
211 }
212
213 cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
214 }
215
216 Ok((total_read as usize, in_fds_count))
217}
218
219pub const SCM_SOCKET_MAX_FD_COUNT: usize = 253;
221
222pub trait ScmSocket {
225 fn socket_fd(&self) -> RawFd;
227
228 fn send_with_fd<D: IntoIobuf>(&self, buf: &[D], fd: RawFd) -> Result<usize> {
237 self.send_with_fds(buf, &[fd])
238 }
239
240 fn send_with_fds<D: IntoIobuf>(&self, buf: &[D], fd: &[RawFd]) -> Result<usize> {
249 raw_sendmsg(self.socket_fd(), buf, fd)
250 }
251
252 fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
260 let mut fd = [0];
261 let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?;
262 let file = if fd_count == 0 {
263 None
264 } else {
265 Some(unsafe { File::from_raw_fd(fd[0]) })
268 };
269 Ok((read_count, file))
270 }
271
272 fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> {
286 raw_recvmsg(self.socket_fd(), buf, fds)
287 }
288}
289
290impl ScmSocket for UnixDatagram {
291 fn socket_fd(&self) -> RawFd {
292 self.as_raw_fd()
293 }
294}
295
296impl ScmSocket for UnixStream {
297 fn socket_fd(&self) -> RawFd {
298 self.as_raw_fd()
299 }
300}
301
302pub unsafe trait IntoIobuf: Sized {
309 fn into_iobuf(self) -> iovec;
311
312 fn as_iobufs(bufs: &[Self]) -> &[iovec];
314}
315
316unsafe impl<'a> IntoIobuf for IoSlice<'a> {
319 fn into_iobuf(self) -> iovec {
320 iovec {
321 iov_base: self.as_ptr() as *mut c_void,
322 iov_len: self.len(),
323 }
324 }
325
326 fn as_iobufs(bufs: &[Self]) -> &[iovec] {
327 unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
329 }
330}
331
332unsafe impl<'a> IntoIobuf for IoSliceMut<'a> {
335 fn into_iobuf(self) -> iovec {
336 iovec {
337 iov_base: self.as_ptr() as *mut c_void,
338 iov_len: self.len(),
339 }
340 }
341
342 fn as_iobufs(bufs: &[Self]) -> &[iovec] {
343 unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::Write;
    use std::mem::size_of;
    use std::os::raw::c_long;
    use std::os::unix::net::UnixDatagram;

    use libc::cmsghdr;

    use crate::{EventFd, EventfdFlags};

    /// `CMSG_SPACE!` must pad descriptor payloads up to `c_long` alignment.
    #[test]
    fn buffer_len() {
        assert_eq!(CMSG_SPACE!(0), size_of::<cmsghdr>());
        assert_eq!(
            CMSG_SPACE!(size_of::<RawFd>()),
            size_of::<cmsghdr>() + size_of::<c_long>()
        );
        if size_of::<RawFd>() == 4 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>()
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
        } else if size_of::<RawFd>() == 8 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 3
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 4
            );
        }
    }

    /// Plain data round-trips and no descriptors are reported.
    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let ioslice = IoSlice::new([1u8, 1, 2, 21, 34, 55].as_ref());
        let write_count = s1
            .send_with_fds(&[ioslice], &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0; 6];
        let mut files = [0; 1];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut buf[..], &mut files)
            .expect("failed to recv data");

        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    /// A descriptor can travel with an empty data payload, and the received copy refers to the
    /// same eventfd.
    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(EventfdFlags::empty()).expect("failed to create eventfd");
        let ioslice = IoSlice::new([].as_ref());
        let write_count = s1
            .send_with_fd(&[ioslice], evt.as_raw_fd())
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), evt.as_raw_fd());

        // An eventfd write takes a native-endian u64; it must be visible through the original.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    /// Data and a descriptor arrive together in one message.
    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(EventfdFlags::empty()).expect("failed to create eventfd");
        let ioslice = IoSlice::new([237].as_ref());
        let write_count = s1
            .send_with_fds(&[ioslice], &[evt.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut buf, &mut files)
            .expect("failed to recv fd");

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], evt.as_raw_fd());

        // SAFETY: `files[0]` was just received and is owned by nothing else.
        let mut file = unsafe { File::from_raw_fd(files[0]) };

        // An eventfd write takes a native-endian u64; it must be visible through the original.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }
}