sendfd_new/
lib.rs

1use std::os::unix::io::{AsRawFd, RawFd};
2use std::os::unix::net;
3use std::{alloc, io, mem, ptr};
4
5pub mod changelog;
6
/// Extension trait for sockets that can transmit file descriptors together with a byte payload.
pub trait SendWithFd {
    /// Sends `bytes` along with the file descriptors listed in `fds`.
    ///
    /// On success the implementations in this crate return the number of bytes that were sent.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize>;
}
12
/// Extension trait for sockets that can receive file descriptors together with a byte payload.
pub trait RecvWithFd {
    /// Receives bytes into `bytes` and file descriptors into `fds`.
    ///
    /// On success the implementations in this crate return a
    /// `(byte_count, descriptor_count)` pair describing how much of each buffer was filled.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)>;
}
20
// Stand-in for `<*const u8>::offset_from`: the byte distance from `origin` to `this`, computed
// with wrapping integer arithmetic on the two addresses. Replace with the standard method once
// it can be relied upon.
unsafe fn ptr_offset_from(this: *const u8, origin: *const u8) -> isize {
    let this_addr = this as isize;
    let origin_addr = origin as isize;
    this_addr.wrapping_sub(origin_addr)
}
25
26/// Construct the `libc::msghdr` which is used as an argument to `libc::sendmsg` and
27/// `libc::recvmsg`.
28///
29/// The constructed `msghdr` contains the references to the given `iov` and has sufficient
30/// (dynamically allocated) space to store `fd_count` file descriptors delivered as ancillary data.
31///
32/// # Unsafety
33///
34/// This function provides a "mostly" safe interface, however it is kept unsafe as its only uses
35/// are intended to be in other unsafe code and its implementation itself is also unsafe.
36unsafe fn construct_msghdr_for(
37    iov: &mut libc::iovec,
38    fd_count: usize,
39) -> (libc::msghdr, alloc::Layout, usize) {
40    let fd_len = mem::size_of::<RawFd>() * fd_count;
41    let cmsg_buffer_len = libc::CMSG_SPACE(fd_len as u32) as usize;
42    let layout = alloc::Layout::from_size_align(cmsg_buffer_len, mem::align_of::<libc::cmsghdr>());
43    let (cmsg_buffer, cmsg_layout) = if let Ok(layout) = layout {
44        const NULL_MUT_U8: *mut u8 = ptr::null_mut();
45        match alloc::alloc(layout) {
46            NULL_MUT_U8 => alloc::handle_alloc_error(layout),
47            x => (x as *mut _, layout),
48        }
49    } else {
50        // NB: it is fine to construct such a `Layout` as it is not used for actual allocation,
51        // just for the error reporting. Either way this branch is not reachable at all provided a
52        // well behaved implementation of `CMSG_SPACE` in the host libc.
53        alloc::handle_alloc_error(alloc::Layout::from_size_align_unchecked(
54            cmsg_buffer_len,
55            mem::align_of::<libc::cmsghdr>(),
56        ))
57    };
58    (
59        libc::msghdr {
60            msg_name: ptr::null_mut(),
61            msg_namelen: 0,
62            msg_iov: iov as *mut _,
63            msg_iovlen: 1,
64            msg_control: cmsg_buffer,
65            msg_controllen: cmsg_buffer_len as _,
66            ..mem::zeroed()
67        },
68        cmsg_layout,
69        fd_len,
70    )
71}
72
73/// A common implementation of `sendmsg` that sends provided bytes with ancillary file descriptors
74/// over either a datagram or stream unix socket.
75fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {
76    unsafe {
77        let mut iov = libc::iovec {
78            // NB: this casts *const to *mut, and in doing so we trust the OS to be a good citizen
79            // and not mutate our buffer. This is the API we have to live with.
80            iov_base: bs.as_ptr() as *const _ as *mut _,
81            iov_len: bs.len(),
82        };
83        let (mut msghdr, cmsg_layout, fd_len) = construct_msghdr_for(&mut iov, fds.len());
84        let cmsg_buffer = msghdr.msg_control;
85
86        // Fill cmsg with the file descriptors we are sending.
87        let cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
88        ptr::write(
89            cmsg_header,
90            libc::cmsghdr {
91                cmsg_level: libc::SOL_SOCKET,
92                cmsg_type: libc::SCM_RIGHTS,
93                cmsg_len: libc::CMSG_LEN(fd_len as u32) as _,
94            },
95        );
96        #[allow(clippy::cast_ptr_alignment)]
97        let cmsg_data = libc::CMSG_DATA(cmsg_header) as *mut RawFd;
98        for (i, fd) in fds.iter().enumerate() {
99            ptr::write_unaligned(cmsg_data.add(i), *fd);
100        }
101        let count = libc::sendmsg(socket, &msghdr as *const _, 0);
102        if count < 0 {
103            let error = io::Error::last_os_error();
104            alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
105            Err(error)
106        } else {
107            alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
108            Ok(count as usize)
109        }
110    }
111}
112
113/// A common implementation of `recvmsg` that receives provided bytes and the ancillary file
114/// descriptors over either a datagram or stream unix socket.
115fn recv_with_fd(socket: RawFd, bs: &mut [u8], mut fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
116    unsafe {
117        let mut iov = libc::iovec {
118            iov_base: bs.as_mut_ptr() as *mut _,
119            iov_len: bs.len(),
120        };
121        let (mut msghdr, cmsg_layout, _) = construct_msghdr_for(&mut iov, fds.len());
122        let cmsg_buffer = msghdr.msg_control;
123        let count = libc::recvmsg(socket, &mut msghdr as *mut _, 0);
124        if count < 0 {
125            let error = io::Error::last_os_error();
126            alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
127            return Err(error);
128        }
129
130        // Walk the ancillary data buffer and copy the raw descriptors from it into the output
131        // buffer.
132        let mut descriptor_count = 0;
133        let mut cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
134        while !cmsg_header.is_null() {
135            if (*cmsg_header).cmsg_level == libc::SOL_SOCKET
136                && (*cmsg_header).cmsg_type == libc::SCM_RIGHTS
137            {
138                let data_ptr = libc::CMSG_DATA(cmsg_header);
139                let data_offset = ptr_offset_from(data_ptr, cmsg_header as *const _);
140                debug_assert!(data_offset >= 0);
141                let data_byte_count = (*cmsg_header).cmsg_len as usize - data_offset as usize;
142                debug_assert!((*cmsg_header).cmsg_len as isize > data_offset);
143                debug_assert!(data_byte_count % mem::size_of::<RawFd>() == 0);
144                let rawfd_count = (data_byte_count / mem::size_of::<RawFd>()) as isize;
145                #[allow(clippy::cast_ptr_alignment)]
146                let fd_ptr = data_ptr as *const RawFd;
147                for i in 0..rawfd_count {
148                    if let Some((dst, rest)) = { fds }.split_first_mut() {
149                        *dst = ptr::read_unaligned(fd_ptr.offset(i));
150                        descriptor_count += 1;
151                        fds = rest;
152                    } else {
153                        // This branch is unreachable. We allocate the ancillary data buffer just
154                        // large enough to fit exactly the number of `RawFd`s that are in the `fds`
155                        // buffer. It is not possible for the OS to return more of them.
156                        //
157                        // If this branch ended up being reachable for some reason, it would be
158                        // necessary for this branch to close the file descriptors to avoid leaking
159                        // resources.
160                        //
161                        // TODO: consider using unreachable_unchecked
162                        unreachable!();
163                    }
164                }
165            }
166            cmsg_header = libc::CMSG_NXTHDR(&mut msghdr as *mut _, cmsg_header);
167        }
168
169        alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
170        Ok((count as usize, descriptor_count))
171    }
172}
173
174impl SendWithFd for net::UnixStream {
175    /// Send the bytes and the file descriptors as a stream.
176    ///
177    /// Neither is guaranteed to be received by the other end in a single chunk and
178    /// may arrive entirely independently.
179    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
180        send_with_fd(self.as_raw_fd(), bytes, fds)
181    }
182}
183
184impl SendWithFd for net::UnixDatagram {
185    /// Send the bytes and the file descriptors as a single packet.
186    ///
187    /// It is guaranteed that the bytes and the associated file descriptors will arrive at the same
188    /// time, however the receiver end may not receive the full message if its buffers are too
189    /// small.
190    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
191        send_with_fd(self.as_raw_fd(), bytes, fds)
192    }
193}
194
195impl RecvWithFd for net::UnixStream {
196    /// Receive the bytes and the file descriptors from the stream.
197    ///
198    /// It is not guaranteed that the received information will form a single coherent packet of
199    /// data. In other words, it is not required that this receives the bytes and file descriptors
200    /// that were sent with a single `send_with_fd` call by somebody else.
201    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
202        recv_with_fd(self.as_raw_fd(), bytes, fds)
203    }
204}
205
206impl RecvWithFd for net::UnixDatagram {
207    /// Receive the bytes and the file descriptors as a single packet.
208    ///
209    /// It is guaranteed that the received information will form a single coherent packet, and data
210    /// received will match a corresponding `send_with_fd` call. Note, however, that in case the
211    /// receiving buffer(s) are to small, the message may get silently truncated and the
212    /// undelivered data will be discarded.
213    ///
214    /// For receiving the file descriptors, the internal buffer is sized according to the size of
215    /// the `fds` buffer. If the sender sends `fds.len()` descriptors, but prefaces the descriptors
216    /// with some other ancilliary data, then some file descriptors may be truncated as well.
217    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
218        recv_with_fd(self.as_raw_fd(), bytes, fds)
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::{RecvWithFd, SendWithFd};
    use std::io;
    use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
    use std::os::unix::net;
    use std::time::Duration;

    /// Payload sent alongside the descriptors in every test.
    const SENT_BYTES: &[u8] = b"hello world!";

    /// Asserts that `recvd` refers to the same underlying socket as `sent`.
    ///
    /// A read timeout is set through `sent` and then read back through `recvd`; the two values
    /// only agree if both descriptors share the same socket. `sent` is borrowed and stays open,
    /// whereas `recvd` is taken over and closed when the function returns.
    fn assert_same_socket<S, Set, Get>(sent: RawFd, recvd: RawFd, set_timeout: Set, get_timeout: Get)
    where
        S: FromRawFd,
        Set: Fn(&S, Option<Duration>) -> io::Result<()>,
        Get: Fn(&S) -> io::Result<Option<Duration>>,
    {
        let expected_value = Some(Duration::from_secs(42));
        unsafe {
            let s = S::from_raw_fd(sent);
            let set_result = set_timeout(&s, expected_value);
            // `sent` is still owned by the socket it came from; never close it here, not even
            // when the assertion below fails.
            std::mem::forget(s);
            set_result.expect("set read timeout");
            let received = S::from_raw_fd(recvd);
            assert_eq!(
                get_timeout(&received).expect("get read timeout"),
                expected_value
            );
        }
    }

    #[test]
    fn stream_works() {
        let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(SENT_BYTES, &sent_fds[..])
                .expect("send should be successful"),
            SENT_BYTES.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 10];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (SENT_BYTES.len(), sent_fds.len())
        );
        assert_eq!(&recv_bytes[..SENT_BYTES.len()], SENT_BYTES);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            assert_same_socket::<net::UnixStream, _, _>(
                sent,
                recvd,
                net::UnixStream::set_read_timeout,
                net::UnixStream::read_timeout,
            );
        }
    }

    #[test]
    fn datagram_works() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(SENT_BYTES, &sent_fds[..])
                .expect("send should be successful"),
            SENT_BYTES.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 7];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (SENT_BYTES.len(), sent_fds.len())
        );
        assert_eq!(&recv_bytes[..SENT_BYTES.len()], SENT_BYTES);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            assert_same_socket::<net::UnixDatagram, _, _>(
                sent,
                recvd,
                net::UnixDatagram::set_read_timeout,
                net::UnixDatagram::read_timeout,
            );
        }
    }

    #[test]
    fn datagram_works_across_processes() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];

        let child = unsafe { libc::fork() };
        match child {
            -1 => panic!("fork failed!"),
            0 => {
                // This is the child in which we attempt to send a file descriptor back to
                // parent, emulating the cross-process FD sharing. Report the outcome through the
                // exit code instead of panicking inside the forked copy of the test harness.
                let sent = l.send_with_fd(SENT_BYTES, &sent_fds[..]);
                let exit_code = match sent {
                    Ok(n) if n == SENT_BYTES.len() => 0,
                    _ => 1,
                };
                ::std::process::exit(exit_code);
            }
            _ => {
                // Parent process, receives the file descriptors sent by forked child.
            }
        }

        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0; 7];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (SENT_BYTES.len(), sent_fds.len())
        );
        assert_eq!(&recv_bytes[..SENT_BYTES.len()], SENT_BYTES);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            assert_same_socket::<net::UnixDatagram, _, _>(
                sent,
                recvd,
                net::UnixDatagram::set_read_timeout,
                net::UnixDatagram::read_timeout,
            );
        }

        // Reap the child so that it does not linger as a zombie, and make sure it succeeded.
        let mut status = 0;
        #[allow(unused_unsafe)]
        unsafe {
            assert_eq!(libc::waitpid(child, &mut status, 0), child, "waitpid failed");
            assert!(
                libc::WIFEXITED(status) && libc::WEXITSTATUS(status) == 0,
                "child process did not exit cleanly"
            );
        }
    }

    #[test]
    fn sending_junk_fails() {
        // Keep the peer alive for the whole test: binding it to `_` would drop it right away and
        // the sends below could then fail for a reason unrelated to the junk descriptors.
        let (l, _r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        assert!(
            l.send_with_fd(SENT_BYTES, &[i32::max_value()][..]).is_err(),
            "expected an error when sending a junk file descriptor"
        );
        assert!(
            l.send_with_fd(SENT_BYTES, &[0xffi32][..]).is_err(),
            "expected an error when sending a junk file descriptor"
        );
    }
}