1#![cfg_attr(sendfd_docs, feature(doc_cfg))]
2
3extern crate libc;
4#[cfg(feature = "tokio")]
5extern crate tokio;
6
7use std::os::unix::io::{AsRawFd, RawFd};
8use std::os::unix::net;
9use std::{alloc, io, mem, ptr};
10#[cfg(feature = "tokio")]
11use tokio::io::Interest;
12
13pub mod changelog;
14
/// Sending bytes together with file descriptors over a Unix domain socket.
pub trait SendWithFd {
    /// Send `bytes` with the file descriptors in `fds` attached as ancillary data
    /// (a single `SCM_RIGHTS` control message in the implementations in this crate).
    ///
    /// Returns the number of bytes of `bytes` that were sent. The descriptors in `fds`
    /// are not closed by this call; the caller keeps ownership of them.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize>;
}
20
/// Receiving bytes together with file descriptors from a Unix domain socket.
pub trait RecvWithFd {
    /// Receive data into `bytes` and any `SCM_RIGHTS` file descriptors into `fds`.
    ///
    /// Returns `(bytes_received, descriptors_received)`. Received descriptors are written
    /// to the front of `fds`; they are new descriptors and the caller is responsible for
    /// closing them.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)>;
}
28
/// Signed byte distance from `origin` to `this` (i.e. `this - origin`), computed on the raw
/// addresses with wrapping arithmetic.
///
/// For two pointers into the same allocation this equals `this.offset_from(origin)`.
///
/// # Safety
///
/// No pointer is dereferenced; the result is only meaningful when both pointers point into
/// the same allocation.
unsafe fn ptr_offset_from(this: *const u8, origin: *const u8) -> isize {
    let end = this as usize;
    let start = origin as usize;
    // Two's-complement wrap-around gives the same bits as subtracting in `isize`.
    end.wrapping_sub(start) as isize
}
33
34unsafe fn construct_msghdr_for(
45 iov: &mut libc::iovec,
46 fd_count: usize,
47) -> (libc::msghdr, alloc::Layout, usize) {
48 let fd_len = mem::size_of::<RawFd>() * fd_count;
49 let cmsg_buffer_len = libc::CMSG_SPACE(fd_len as u32) as usize;
50 let layout = alloc::Layout::from_size_align(cmsg_buffer_len, mem::align_of::<libc::cmsghdr>());
51 let (cmsg_buffer, cmsg_layout) = if let Ok(layout) = layout {
52 const NULL_MUT_U8: *mut u8 = ptr::null_mut();
53 match alloc::alloc(layout) {
54 NULL_MUT_U8 => alloc::handle_alloc_error(layout),
55 x => (x as *mut _, layout),
56 }
57 } else {
58 alloc::handle_alloc_error(alloc::Layout::from_size_align_unchecked(
62 cmsg_buffer_len,
63 mem::align_of::<libc::cmsghdr>(),
64 ))
65 };
66
67 let mut msghdr = mem::zeroed::<libc::msghdr>();
68 msghdr.msg_name = ptr::null_mut();
69 msghdr.msg_namelen = 0;
70 msghdr.msg_iov = iov as *mut _;
71 msghdr.msg_iovlen = 1;
72 msghdr.msg_control = cmsg_buffer;
73 msghdr.msg_controllen = cmsg_buffer_len as _;
74
75 (msghdr, cmsg_layout, fd_len)
76}
77
78fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {
81 unsafe {
82 let mut iov = libc::iovec {
83 iov_base: bs.as_ptr() as *const _ as *mut _,
86 iov_len: bs.len(),
87 };
88 let (mut msghdr, cmsg_layout, fd_len) = construct_msghdr_for(&mut iov, fds.len());
89 let cmsg_buffer = msghdr.msg_control;
90
91 let cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
93 let mut cmsghdr = mem::zeroed::<libc::cmsghdr>();
94 cmsghdr.cmsg_level = libc::SOL_SOCKET;
95 cmsghdr.cmsg_type = libc::SCM_RIGHTS;
96 cmsghdr.cmsg_len = libc::CMSG_LEN(fd_len as u32) as _;
97
98 ptr::write(cmsg_header, cmsghdr);
99
100 let cmsg_data = libc::CMSG_DATA(cmsg_header) as *mut RawFd;
101 for (i, fd) in fds.iter().enumerate() {
102 ptr::write_unaligned(cmsg_data.add(i), *fd);
103 }
104 let count = libc::sendmsg(socket, &msghdr as *const _, 0);
105 if count < 0 {
106 let error = io::Error::last_os_error();
107 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
108 Err(error)
109 } else {
110 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
111 Ok(count as usize)
112 }
113 }
114}
115
116fn recv_with_fd(socket: RawFd, bs: &mut [u8], mut fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
119 unsafe {
120 let mut iov = libc::iovec {
121 iov_base: bs.as_mut_ptr() as *mut _,
122 iov_len: bs.len(),
123 };
124 let (mut msghdr, cmsg_layout, _) = construct_msghdr_for(&mut iov, fds.len());
125 let cmsg_buffer = msghdr.msg_control;
126 let count = libc::recvmsg(socket, &mut msghdr as *mut _, 0);
127 if count < 0 {
128 let error = io::Error::last_os_error();
129 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
130 return Err(error);
131 }
132
133 let mut descriptor_count = 0;
136 let mut cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
137 while !cmsg_header.is_null() {
138 if (*cmsg_header).cmsg_level == libc::SOL_SOCKET
139 && (*cmsg_header).cmsg_type == libc::SCM_RIGHTS
140 {
141 let data_ptr = libc::CMSG_DATA(cmsg_header);
142 let data_offset = ptr_offset_from(data_ptr, cmsg_header as *const _);
143 debug_assert!(data_offset >= 0);
144 let data_byte_count = (*cmsg_header).cmsg_len as usize - data_offset as usize;
145 debug_assert!((*cmsg_header).cmsg_len as isize >= data_offset);
146 debug_assert!(data_byte_count % mem::size_of::<RawFd>() == 0);
147 let rawfd_count = (data_byte_count / mem::size_of::<RawFd>()) as isize;
148 let fd_ptr = data_ptr as *const RawFd;
149 for i in 0..rawfd_count {
150 if let Some((dst, rest)) = { fds }.split_first_mut() {
151 *dst = ptr::read_unaligned(fd_ptr.offset(i));
152 descriptor_count += 1;
153 fds = rest;
154 } else {
155 unreachable!();
165 }
166 }
167 }
168 cmsg_header = libc::CMSG_NXTHDR(&mut msghdr as *mut _, cmsg_header);
169 }
170
171 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
172 Ok((count as usize, descriptor_count))
173 }
174}
175
176impl SendWithFd for net::UnixStream {
177 fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
182 send_with_fd(self.as_raw_fd(), bytes, fds)
183 }
184}
185
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl SendWithFd for tokio::net::UnixStream {
    /// Non-blocking send of `bytes` with `fds` attached.
    ///
    /// The `sendmsg` call runs inside [`tokio::net::UnixStream::try_io`] with
    /// [`Interest::WRITABLE`], so readiness bookkeeping follows tokio's `try_io` contract.
    /// Callers typically await `writable()` and retry on `io::ErrorKind::WouldBlock`.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        let socket = self.as_raw_fd();
        self.try_io(Interest::WRITABLE, move || send_with_fd(socket, bytes, fds))
    }
}
199
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl SendWithFd for tokio::net::unix::WriteHalf<'_> {
    /// Forwards to the [`SendWithFd`] implementation of the [`tokio::net::UnixStream`] this
    /// half borrows, so it has the same non-blocking behaviour.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        AsRef::<tokio::net::UnixStream>::as_ref(self).send_with_fd(bytes, fds)
    }
}
212
213impl SendWithFd for net::UnixDatagram {
214 fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
220 send_with_fd(self.as_raw_fd(), bytes, fds)
221 }
222}
223
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl SendWithFd for tokio::net::UnixDatagram {
    /// Non-blocking send of one datagram carrying `bytes` and `fds`.
    ///
    /// Runs `sendmsg` inside [`tokio::net::UnixDatagram::try_io`] with
    /// [`Interest::WRITABLE`]. As no destination address is supplied, the socket must
    /// already be connected.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        let socket = self.as_raw_fd();
        self.try_io(Interest::WRITABLE, move || send_with_fd(socket, bytes, fds))
    }
}
238
239impl RecvWithFd for net::UnixStream {
240 fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
246 recv_with_fd(self.as_raw_fd(), bytes, fds)
247 }
248}
249
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl RecvWithFd for tokio::net::UnixStream {
    /// Non-blocking receive of data and descriptors.
    ///
    /// The `recvmsg` call runs inside [`tokio::net::UnixStream::try_io`] with
    /// [`Interest::READABLE`]. Callers typically await `readable()` and retry on
    /// `io::ErrorKind::WouldBlock`. Received descriptors are owned by the caller.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        let socket = self.as_raw_fd();
        self.try_io(Interest::READABLE, move || recv_with_fd(socket, bytes, fds))
    }
}
264
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl RecvWithFd for tokio::net::unix::ReadHalf<'_> {
    /// Forwards to the [`RecvWithFd`] implementation of the [`tokio::net::UnixStream`] this
    /// half borrows, so it has the same non-blocking behaviour.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        AsRef::<tokio::net::UnixStream>::as_ref(self).recv_with_fd(bytes, fds)
    }
}
278
279impl RecvWithFd for net::UnixDatagram {
280 fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
291 recv_with_fd(self.as_raw_fd(), bytes, fds)
292 }
293}
294
#[cfg(feature = "tokio")]
#[cfg_attr(sendfd_docs, doc(cfg(feature = "tokio")))]
impl RecvWithFd for tokio::net::UnixDatagram {
    /// Non-blocking receive of one datagram and its descriptors.
    ///
    /// Runs `recvmsg` inside [`tokio::net::UnixDatagram::try_io`] with
    /// [`Interest::READABLE`]. Received descriptors are owned by the caller.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        let socket = self.as_raw_fd();
        self.try_io(Interest::READABLE, move || recv_with_fd(socket, bytes, fds))
    }
}
314
#[cfg(test)]
mod tests {
    use super::{RecvWithFd, SendWithFd};
    use std::mem::ManuallyDrop;
    use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
    use std::os::unix::net;
    use std::time::Duration;

    /// Read timeout used as a marker. It is set through a descriptor we sent and must then be
    /// visible through the descriptor we received, showing both name the same socket.
    const MARKER_TIMEOUT: Option<Duration> = Some(Duration::from_secs(42));

    /// Assert that the received stream descriptor `recvd` refers to the same socket as `sent`.
    ///
    /// `sent` stays owned by the test's socket pair; `recvd` is owned here and closed on return.
    fn assert_same_stream(sent: RawFd, recvd: RawFd) {
        // SAFETY: `sent` is a live descriptor borrowed from the pair; `ManuallyDrop` ensures
        // it is never closed here, even if an assertion below panics.
        let sent = ManuallyDrop::new(unsafe { net::UnixStream::from_raw_fd(sent) });
        sent.set_read_timeout(MARKER_TIMEOUT)
            .expect("set read timeout");
        // SAFETY: `recvd` was installed for us by `recvmsg` and nothing else owns it.
        let recvd = unsafe { net::UnixStream::from_raw_fd(recvd) };
        assert_eq!(
            recvd.read_timeout().expect("get read timeout"),
            MARKER_TIMEOUT
        );
    }

    /// Datagram counterpart of [`assert_same_stream`].
    fn assert_same_datagram(sent: RawFd, recvd: RawFd) {
        // SAFETY: as in `assert_same_stream`.
        let sent = ManuallyDrop::new(unsafe { net::UnixDatagram::from_raw_fd(sent) });
        sent.set_read_timeout(MARKER_TIMEOUT)
            .expect("set read timeout");
        // SAFETY: as in `assert_same_stream`.
        let recvd = unsafe { net::UnixDatagram::from_raw_fd(recvd) };
        assert_eq!(
            recvd.read_timeout().expect("get read timeout"),
            MARKER_TIMEOUT
        );
    }

    #[test]
    fn stream_works() {
        let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 10];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            assert_same_stream(sent, recvd);
        }
    }

    #[test]
    fn datagram_works() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 7];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            assert_same_datagram(sent, recvd);
        }
    }

    #[test]
    fn datagram_works_across_processes() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];

        // SAFETY: the child only sends and then leaves through `_exit`, so it never returns
        // into the (multi-threaded) test harness or runs its exit handlers.
        // NOTE(review): the send allocates after `fork`, which is not async-signal-safe in a
        // multi-threaded parent; kept from the original test.
        match unsafe { libc::fork() } {
            -1 => panic!("fork failed!"),
            0 => {
                let code = match l.send_with_fd(&sent_bytes[..], &sent_fds[..]) {
                    Ok(_) => 0,
                    Err(_) => 1,
                };
                unsafe { libc::_exit(code) }
            }
            child => {
                // Reap the child before receiving. The datagram is already queued once the
                // child has exited, and a failed send is reported here instead of leaving
                // `recv` blocked forever.
                let mut status: libc::c_int = 0;
                let waited = unsafe { libc::waitpid(child, &mut status, 0) };
                assert_eq!(waited, child, "waitpid failed");
                // `WIFEXITED`/`WEXITSTATUS` are `unsafe fn`s in some `libc` releases and safe
                // in others, hence the `allow`.
                #[allow(unused_unsafe)]
                let exited_cleanly =
                    unsafe { libc::WIFEXITED(status) && libc::WEXITSTATUS(status) == 0 };
                assert!(
                    exited_cleanly,
                    "child failed to send (wait status {})",
                    status
                );
            }
        }

        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 7];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            assert_same_datagram(sent, recvd);
        }
    }

    #[test]
    fn sending_junk_fails() {
        // Keep the peer alive: with it dropped, every send fails because the peer is gone,
        // and the test would pass even if junk descriptors were accepted.
        let (l, _r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        for &junk in &[i32::max_value(), 0xffi32] {
            assert!(
                l.send_with_fd(&sent_bytes[..], &[junk][..]).is_err(),
                "expected an error when sending a junk file descriptor"
            );
        }
    }

    #[test]
    fn sending_empty_fds_works() {
        let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
        let sent_bytes = b"hello world!";
        let sent_fds: [RawFd; 0] = [];

        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );

        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 10];

        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
    }
}
474}