1use std::os::unix::io::{AsRawFd, RawFd};
2use std::os::unix::net;
3use std::{alloc, io, mem, ptr};
4
5pub mod changelog;
6
/// An extension trait that enables sending file descriptors along with regular data.
///
/// The implementations in this module transfer the descriptors as `SOL_SOCKET`/`SCM_RIGHTS`
/// ancillary data of the same `sendmsg(2)` call that carries the bytes.
pub trait SendWithFd {
    /// Send `bytes` together with the file descriptors listed in `fds`.
    ///
    /// Returns the number of bytes of `bytes` that were sent. The descriptors in `fds` stay owned
    /// by the caller: this call only reads their numeric values and does not close them.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the operating system, e.g. when `fds` contains a value that
    /// is not an open descriptor.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize>;
}
12
/// An extension trait that enables receiving file descriptors along with regular data.
pub trait RecvWithFd {
    /// Receive data into `bytes` and file descriptors into `fds`.
    ///
    /// Returns `(bytes_received, fds_received)`: the number of bytes written to the start of
    /// `bytes` and the number of descriptors written to the start of `fds`. At most `fds.len()`
    /// descriptors are stored. Each received descriptor is owned by the caller, who is
    /// responsible for closing it.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the operating system.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)>;
}
20
/// Signed byte distance from `origin` to `this` (stand-in for `<*const u8>::offset_from`).
///
/// The distance is computed on the raw addresses with wrapping subtraction. The body performs no
/// unsafe operation; the result is only meaningful when both pointers point into the same
/// allocation.
unsafe fn ptr_offset_from(this: *const u8, origin: *const u8) -> isize {
    let (target, base) = (this as isize, origin as isize);
    target.wrapping_sub(base)
}
25
26unsafe fn construct_msghdr_for(
37 iov: &mut libc::iovec,
38 fd_count: usize,
39) -> (libc::msghdr, alloc::Layout, usize) {
40 let fd_len = mem::size_of::<RawFd>() * fd_count;
41 let cmsg_buffer_len = libc::CMSG_SPACE(fd_len as u32) as usize;
42 let layout = alloc::Layout::from_size_align(cmsg_buffer_len, mem::align_of::<libc::cmsghdr>());
43 let (cmsg_buffer, cmsg_layout) = if let Ok(layout) = layout {
44 const NULL_MUT_U8: *mut u8 = ptr::null_mut();
45 match alloc::alloc(layout) {
46 NULL_MUT_U8 => alloc::handle_alloc_error(layout),
47 x => (x as *mut _, layout),
48 }
49 } else {
50 alloc::handle_alloc_error(alloc::Layout::from_size_align_unchecked(
54 cmsg_buffer_len,
55 mem::align_of::<libc::cmsghdr>(),
56 ))
57 };
58 (
59 libc::msghdr {
60 msg_name: ptr::null_mut(),
61 msg_namelen: 0,
62 msg_iov: iov as *mut _,
63 msg_iovlen: 1,
64 msg_control: cmsg_buffer,
65 msg_controllen: cmsg_buffer_len as _,
66 ..mem::zeroed()
67 },
68 cmsg_layout,
69 fd_len,
70 )
71}
72
73fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {
76 unsafe {
77 let mut iov = libc::iovec {
78 iov_base: bs.as_ptr() as *const _ as *mut _,
81 iov_len: bs.len(),
82 };
83 let (mut msghdr, cmsg_layout, fd_len) = construct_msghdr_for(&mut iov, fds.len());
84 let cmsg_buffer = msghdr.msg_control;
85
86 let cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
88 ptr::write(
89 cmsg_header,
90 libc::cmsghdr {
91 cmsg_level: libc::SOL_SOCKET,
92 cmsg_type: libc::SCM_RIGHTS,
93 cmsg_len: libc::CMSG_LEN(fd_len as u32) as _,
94 },
95 );
96 #[allow(clippy::cast_ptr_alignment)]
97 let cmsg_data = libc::CMSG_DATA(cmsg_header) as *mut RawFd;
98 for (i, fd) in fds.iter().enumerate() {
99 ptr::write_unaligned(cmsg_data.add(i), *fd);
100 }
101 let count = libc::sendmsg(socket, &msghdr as *const _, 0);
102 if count < 0 {
103 let error = io::Error::last_os_error();
104 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
105 Err(error)
106 } else {
107 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
108 Ok(count as usize)
109 }
110 }
111}
112
113fn recv_with_fd(socket: RawFd, bs: &mut [u8], mut fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
116 unsafe {
117 let mut iov = libc::iovec {
118 iov_base: bs.as_mut_ptr() as *mut _,
119 iov_len: bs.len(),
120 };
121 let (mut msghdr, cmsg_layout, _) = construct_msghdr_for(&mut iov, fds.len());
122 let cmsg_buffer = msghdr.msg_control;
123 let count = libc::recvmsg(socket, &mut msghdr as *mut _, 0);
124 if count < 0 {
125 let error = io::Error::last_os_error();
126 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
127 return Err(error);
128 }
129
130 let mut descriptor_count = 0;
133 let mut cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
134 while !cmsg_header.is_null() {
135 if (*cmsg_header).cmsg_level == libc::SOL_SOCKET
136 && (*cmsg_header).cmsg_type == libc::SCM_RIGHTS
137 {
138 let data_ptr = libc::CMSG_DATA(cmsg_header);
139 let data_offset = ptr_offset_from(data_ptr, cmsg_header as *const _);
140 debug_assert!(data_offset >= 0);
141 let data_byte_count = (*cmsg_header).cmsg_len as usize - data_offset as usize;
142 debug_assert!((*cmsg_header).cmsg_len as isize > data_offset);
143 debug_assert!(data_byte_count % mem::size_of::<RawFd>() == 0);
144 let rawfd_count = (data_byte_count / mem::size_of::<RawFd>()) as isize;
145 #[allow(clippy::cast_ptr_alignment)]
146 let fd_ptr = data_ptr as *const RawFd;
147 for i in 0..rawfd_count {
148 if let Some((dst, rest)) = { fds }.split_first_mut() {
149 *dst = ptr::read_unaligned(fd_ptr.offset(i));
150 descriptor_count += 1;
151 fds = rest;
152 } else {
153 unreachable!();
163 }
164 }
165 }
166 cmsg_header = libc::CMSG_NXTHDR(&mut msghdr as *mut _, cmsg_header);
167 }
168
169 alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
170 Ok((count as usize, descriptor_count))
171 }
172}
173
174impl SendWithFd for net::UnixStream {
175 fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
180 send_with_fd(self.as_raw_fd(), bytes, fds)
181 }
182}
183
184impl SendWithFd for net::UnixDatagram {
185 fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
191 send_with_fd(self.as_raw_fd(), bytes, fds)
192 }
193}
194
195impl RecvWithFd for net::UnixStream {
196 fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
202 recv_with_fd(self.as_raw_fd(), bytes, fds)
203 }
204}
205
206impl RecvWithFd for net::UnixDatagram {
207 fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
218 recv_with_fd(self.as_raw_fd(), bytes, fds)
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::{RecvWithFd, SendWithFd};
    use std::os::unix::io::{AsRawFd, FromRawFd};
    use std::os::unix::net;

    #[test]
    fn stream_works() {
        let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            // A read timeout set through the sent descriptor must be visible through the
            // received one, i.e. both refer to the same socket.
            let expected_value = Some(std::time::Duration::from_secs(42));
            unsafe {
                let s = net::UnixStream::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                // `sent` is still owned by `l`/`r`; do not close it here.
                std::mem::forget(s);
                // The received descriptor belongs to the test; dropping the stream closes it.
                assert_eq!(
                    net::UnixStream::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    #[test]
    fn datagram_works() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0, 0, 0, 0, 0, 0, 0];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            // A read timeout set through the sent descriptor must be visible through the
            // received one, i.e. both refer to the same socket.
            let expected_value = Some(std::time::Duration::from_secs(42));
            unsafe {
                let s = net::UnixDatagram::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                // `sent` is still owned by `l`/`r`; do not close it here.
                std::mem::forget(s);
                // The received descriptor belongs to the test; dropping the socket closes it.
                assert_eq!(
                    net::UnixDatagram::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    #[test]
    fn datagram_works_across_processes() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];

        unsafe {
            match libc::fork() {
                -1 => panic!("fork failed!"),
                0 => {
                    // Child process: report the outcome through the exit status instead of
                    // panicking. Unwinding here would run test-harness code in the forked copy,
                    // and `_exit` skips the process-wide cleanup that `std::process::exit` does.
                    let status = match l.send_with_fd(&sent_bytes[..], &sent_fds[..]) {
                        Ok(_) => 0,
                        Err(_) => 1,
                    };
                    libc::_exit(status);
                }
                child => {
                    // Reap the child so it does not linger as a zombie. Confirm it sent the
                    // datagram before blocking in `recv` below, so that a failed send fails the
                    // test instead of hanging it.
                    let mut status = 0;
                    assert_eq!(libc::waitpid(child, &mut status, 0), child, "waitpid failed");
                    assert!(
                        libc::WIFEXITED(status) && libc::WEXITSTATUS(status) == 0,
                        "child failed to send"
                    );
                }
            }
            let mut recv_bytes = [0; 128];
            let mut recv_fds = [0, 0, 0, 0, 0, 0, 0];
            assert_eq!(
                r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                    .expect("recv should be successful"),
                (sent_bytes.len(), sent_fds.len())
            );
            assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
            for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
                let expected_value = Some(std::time::Duration::from_secs(42));
                let s = net::UnixDatagram::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                // `sent` is still owned by `l`/`r`; do not close it here.
                std::mem::forget(s);
                assert_eq!(
                    net::UnixDatagram::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    #[test]
    fn sending_junk_fails() {
        // Keep the peer alive: binding it to `_` would close it immediately, and the send could
        // then fail because of the closed peer rather than the junk descriptor under test.
        let (l, _r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        assert!(
            l.send_with_fd(&sent_bytes[..], &[i32::max_value()][..])
                .is_err(),
            "expected an error when sending a junk file descriptor"
        );
        assert!(
            l.send_with_fd(&sent_bytes[..], &[0xffi32][..]).is_err(),
            "expected an error when sending a junk file descriptor"
        );
    }
}
359}