vmm_sys_util/unix/
sock_ctrl_msg.rs

// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD-3-Clause file.
// SPDX-License-Identifier: BSD-3-Clause

/* Copied from the crosvm Project, commit 186eb8b */

//! Wrapper for sending and receiving messages with file descriptors on sockets that accept
//! control messages (e.g. Unix domain sockets).
10
11use std::fs::File;
12use std::mem;
13use std::mem::size_of;
14use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
15use std::os::unix::net::{UnixDatagram, UnixStream};
16use std::ptr::null_mut;
17
18use crate::errno::{Error, Result};
19use crate::fam::{FamStruct, FamStructWrapper};
20use libc::{
21    c_uint, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, CMSG_LEN, CMSG_SPACE, MSG_NOSIGNAL,
22    SCM_RIGHTS, SOL_SOCKET,
23};
24
25#[cfg(not(target_env = "musl"))]
26fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
27    msghdr {
28        msg_name: null_mut(),
29        msg_namelen: 0,
30        msg_iov: iovecs.as_mut_ptr(),
31        #[cfg(any(target_os = "linux", target_os = "android"))]
32        msg_iovlen: iovecs.len(),
33        #[cfg(not(any(target_os = "linux", target_os = "android")))]
34        msg_iovlen: iovecs
35            .len()
36            .try_into()
37            .expect("iovecs.len() exceeds i32 range"),
38        msg_control: null_mut(),
39        msg_controllen: 0,
40        msg_flags: 0,
41    }
42}
43
/// Builds a `msghdr` that scatters/gathers through `iovecs`, with no peer address and no
/// control buffer attached yet (musl variant).
///
/// # Panics
///
/// Panics if `iovecs.len()` exceeds `i32::MAX`, because musl's `msg_iovlen` is an `i32`.
#[cfg(target_env = "musl")]
fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
    assert!(iovecs.len() <= i32::MAX as usize);
    // SAFETY: `msghdr` consists only of pointers and integers, for which the all-zero bit
    // pattern (null pointers, zero lengths, zero flags) is a valid value.
    let mut msg: msghdr = unsafe { mem::zeroed() };
    msg.msg_name = null_mut();
    msg.msg_iov = iovecs.as_mut_ptr();
    // Cannot truncate: bounded by the assertion above.
    msg.msg_iovlen = iovecs.len() as i32;
    msg.msg_control = null_mut();
    msg
}
54
55#[cfg(all(
56    not(target_env = "musl"),
57    any(target_os = "linux", target_os = "android")
58))]
59fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: c_uint) {
60    msg.msg_controllen = cmsg_capacity as libc::size_t;
61}
62
63#[cfg(any(
64    target_env = "musl",
65    not(any(target_os = "linux", target_os = "android"))
66))]
67fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: c_uint) {
68    msg.msg_controllen = cmsg_capacity;
69}
70
71#[repr(transparent)]
72struct CmsgHdr(cmsghdr);
73
74impl Default for CmsgHdr {
75    fn default() -> Self {
76        // SAFETY: all-zero is valid for cmsghdr on all architectures/targets.
77        Self(unsafe { mem::zeroed() })
78    }
79}
80
81// SAFETY: `CmsgHdr` is a POD with a FAM struct at the end that is also a POD.
82// We do not have the FAM represented by a zero-width field as would be customary if
83// generated by bindgen, but this makes no difference. The length implementation correctly
84// converts between the size in bytes of the cmsghdr structure + the payload, to the number
85// of fds in the payload.
86unsafe impl FamStruct for CmsgHdr {
87    type Entry = RawFd;
88
89    // on gnu, size_t is usize, while on musl it is u32
90    #[allow(clippy::unnecessary_cast)]
91    fn len(&self) -> usize {
92        (self.0.cmsg_len as usize - size_of::<cmsghdr>()) / size_of::<RawFd>()
93    }
94
95    unsafe fn set_len(&mut self, len: usize) {
96        // Different targets define cmsg_len of different types, but its always
97        // at least 32 bit wide. The safety invariant on this function ensures us
98        // that we're not being fed a value that exceeds Self::max_len(), so we
99        // are good to cast indiscriminately.
100        self.0.cmsg_len = CMSG_LEN((len * size_of::<RawFd>()) as _) as _;
101    }
102
103    fn max_len() -> usize {
104        (u32::MAX as usize - size_of::<cmsghdr>()) / size_of::<RawFd>()
105    }
106
107    fn as_slice(&self) -> &[RawFd] {
108        // SAFETY: By the invariants of the trait impl and set_len, we have a payload of self.len()
109        // fds after the header itself
110        unsafe { std::slice::from_raw_parts((&self.0 as *const cmsghdr).add(1).cast(), self.len()) }
111    }
112
113    fn as_mut_slice(&mut self) -> &mut [RawFd] {
114        // SAFETY: By the invariants of the trait impl and set_len, we have a payload of self.len()
115        // fds after the header itself
116        unsafe {
117            std::slice::from_raw_parts_mut((&mut self.0 as *mut cmsghdr).add(1).cast(), self.len())
118        }
119    }
120}
121
122type CmsgBuffer = FamStructWrapper<CmsgHdr>;
123
124fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
125    let mut cmsg_buffer =
126        CmsgBuffer::from_entries(out_fds).map_err(|_| Error::new(libc::ENOMEM))?;
127
128    let mut iovecs = Vec::with_capacity(out_data.len());
129    for data in out_data {
130        iovecs.push(iovec {
131            iov_base: data.as_ptr() as *mut c_void,
132            iov_len: data.size(),
133        });
134    }
135
136    let mut msg = new_msghdr(&mut iovecs);
137
138    if !out_fds.is_empty() {
139        // SAFETY: We do not touch the cmsg_len field.
140        unsafe {
141            let hdr = cmsg_buffer.as_mut_fam_struct();
142            hdr.0.cmsg_level = SOL_SOCKET;
143            hdr.0.cmsg_type = SCM_RIGHTS;
144        }
145
146        msg.msg_control = cmsg_buffer.as_mut_fam_struct_ptr() as *mut c_void;
147        // SAFETY: CMSG_SPACE has no invariants to uphold
148        unsafe {
149            set_msg_controllen(&mut msg, CMSG_SPACE(size_of_val(out_fds) as _));
150        }
151    }
152
153    // SAFETY: Safe because the msghdr was properly constructed from valid (or null) pointers of
154    // the indicated length and we check the return value.
155    let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
156
157    if write_count == -1 {
158        Err(Error::last())
159    } else {
160        Ok(write_count as usize)
161    }
162}
163
164#[allow(clippy::unnecessary_cast)]
165unsafe fn raw_recvmsg(
166    fd: RawFd,
167    iovecs: &mut [iovec],
168    in_fds: &mut [RawFd],
169) -> Result<(usize, usize)> {
170    let mut cmsg_buffer = CmsgBuffer::new(in_fds.len()).map_err(|_| Error::new(libc::ENOMEM))?;
171    let mut msg = new_msghdr(iovecs);
172    // Due to alignment constraints, this might be > in_fds.len()!
173    let cmsg_capacity = cmsg_buffer.as_fam_struct_ref().len();
174
175    if !in_fds.is_empty() {
176        // MSG control len is size_of(cmsghdr) + size_of(RawFd) * in_fds.len().
177        msg.msg_control = cmsg_buffer.as_mut_fam_struct_ptr() as *mut c_void;
178        // SAFETY: CMSG_SPACE has no invariants to uphold
179        unsafe {
180            set_msg_controllen(&mut msg, CMSG_SPACE(size_of_val(in_fds) as _));
181        }
182    }
183
184    // Safe because the msghdr was properly constructed from valid (or null) pointers of the
185    // indicated length and we check the return value.
186    // TODO: Should we handle MSG_TRUNC in a specific way?
187    let total_read = recvmsg(fd, &mut msg, 0);
188    if total_read == -1 {
189        return Err(Error::last());
190    }
191
192    let mut copied_fds_count = 0;
193    // If the control data was truncated, then this might be a sign of incorrect communication
194    // protocol. If MSG_CTRUNC was set we must close the fds from the control data.
195    let mut teardown_control_data = msg.msg_flags & libc::MSG_CTRUNC != 0;
196
197    if !msg.msg_control.is_null() {
198        let cmsg = cmsg_buffer.as_mut_fam_struct();
199
200        if cmsg.0.cmsg_level == SOL_SOCKET && cmsg.0.cmsg_type == SCM_RIGHTS {
201            // SAFETY: On some OSes (MacOS), when CMSG_TRUNC is set, the kernel sets
202            // the cmsg->len field to the untruncated length of the message sent,
203            // even if the buffer is actually smaller. Compensate for this by reducing
204            // the length to be at most the size of the memory we have allocated.
205            unsafe { cmsg.set_len(cmsg.len().min(cmsg_capacity)) }
206            let fds = cmsg.as_slice();
207            // It could be that while constructing the cmsg structures, alignment constraints made
208            // our allocation big enough that it fits more than the in_fds.len() file descriptors
209            // we intended to receive. Treat this the same way we would treat a truncated message,
210            // because there is no way for us to communicate these extra FDs back to the caller.
211            teardown_control_data |= fds.len() > in_fds.len();
212            if teardown_control_data {
213                for fd in fds {
214                    libc::close(*fd);
215                }
216            } else {
217                in_fds[..fds.len()].copy_from_slice(fds);
218
219                copied_fds_count = fds.len();
220            }
221        }
222    }
223
224    if teardown_control_data {
225        return Err(Error::new(libc::ENOBUFS));
226    }
227
228    Ok((total_read as usize, copied_fds_count))
229}
230
231/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
232/// `recvmsg`.
233///
234/// # Examples
235///
236/// ```
237/// use std::os::fd::{AsRawFd, FromRawFd};
238/// use std::os::unix::net::UnixDatagram;
239///
240/// use libc::{c_void, iovec};
241/// use vmm_sys_util::event::{new_event_consumer_and_notifier, EventFlag, EventNotifier};
242/// use vmm_sys_util::sock_ctrl_msg::ScmSocket;
243///
244/// let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
245/// let (consumer, fd_to_send) = new_event_consumer_and_notifier(EventFlag::empty())
246///     .expect("Failed to create notifier and consumer");
247///
248/// let write_count = s1
249///     .send_with_fds(&[[237].as_ref()], &[fd_to_send.as_raw_fd()])
250///     .expect("failed to send fd");
251///
252/// let mut files = [0; 2];
253/// let mut buf = [0u8];
254/// let mut iovecs = [iovec {
255///     iov_base: buf.as_mut_ptr() as *mut c_void,
256///     iov_len: buf.len(),
257/// }];
258/// let (read_count, file_count) = unsafe {
259///     s2.recv_with_fds(&mut iovecs[..], &mut files)
260///         .expect("failed to recv fd")
261/// };
262///
263/// let mut notifier = unsafe { EventNotifier::from_raw_fd(files[0]) };
264/// notifier.notify().unwrap();
265/// assert!(consumer.consume().is_ok());
266/// ```
267pub trait ScmSocket {
268    /// Gets the file descriptor of this socket.
269    fn socket_fd(&self) -> RawFd;
270
271    /// Sends the given data and file descriptor over the socket.
272    ///
273    /// On success, returns the number of bytes sent.
274    ///
275    /// # Arguments
276    ///
277    /// * `buf` - A buffer of data to send on the `socket`.
278    /// * `fd` - A file descriptors to be sent.
279    fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
280        self.send_with_fds(&[buf], &[fd])
281    }
282
283    /// Sends the given data and file descriptors over the socket.
284    ///
285    /// On success, returns the number of bytes sent.
286    ///
287    /// # Arguments
288    ///
289    /// * `bufs` - A list of data buffer to send on the `socket`.
290    /// * `fds` - A list of file descriptors to be sent.
291    fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> {
292        raw_sendmsg(self.socket_fd(), bufs, fds)
293    }
294
295    /// Receives data and potentially a file descriptor from the socket.
296    ///
297    /// On success, returns the number of bytes and an optional file descriptor.
298    ///
299    /// # Arguments
300    ///
301    /// * `buf` - A buffer to receive data from the socket.
302    fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
303        let mut fd = [0];
304        let mut iovecs = [iovec {
305            iov_base: buf.as_mut_ptr() as *mut c_void,
306            iov_len: buf.len(),
307        }];
308
309        // SAFETY: Safe because we have mutably borrowed buf and it's safe to write arbitrary data
310        // to a slice.
311        let (read_count, fd_count) = unsafe { self.recv_with_fds(&mut iovecs[..], &mut fd)? };
312        let file = if fd_count == 0 {
313            None
314        } else {
315            // SAFETY: Safe because the first fd from recv_with_fds is owned by us and valid
316            // because this branch was taken.
317            Some(unsafe { File::from_raw_fd(fd[0]) })
318        };
319        Ok((read_count, file))
320    }
321
322    /// Receives data and file descriptors from the socket.
323    ///
324    /// On success, returns the number of bytes and file descriptors received as a tuple
325    /// `(bytes count, files count)`.
326    ///
327    /// # Arguments
328    ///
329    /// * `iovecs` - A list of iovec to receive data from the socket.
330    /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
331    ///   number of valid file descriptors is indicated by the second element of the
332    ///   returned tuple. The caller owns these file descriptors, but they will not be
333    ///   closed on drop like a `File`-like type would be. It is recommended that each valid
334    ///   file descriptor gets wrapped in a drop type that closes it after this returns.
335    ///
336    /// # Safety
337    ///
338    /// It is the callers responsibility to ensure it is safe for arbitrary data to be
339    /// written to the iovec pointers.
340    unsafe fn recv_with_fds(
341        &self,
342        iovecs: &mut [iovec],
343        fds: &mut [RawFd],
344    ) -> Result<(usize, usize)> {
345        raw_recvmsg(self.socket_fd(), iovecs, fds)
346    }
347}
348
349impl ScmSocket for UnixDatagram {
350    fn socket_fd(&self) -> RawFd {
351        self.as_raw_fd()
352    }
353}
354
355impl ScmSocket for UnixStream {
356    fn socket_fd(&self) -> RawFd {
357        self.as_raw_fd()
358    }
359}
360
/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for
/// the lifetime of this object.
///
/// # Safety
///
/// Implementors must guarantee that the pointer returned by [`IntoIovec::as_ptr`] is valid for
/// [`IntoIovec::size`] bytes, and that it stays valid for at least as long as the implementing
/// value itself.
pub unsafe trait IntoIovec {
    /// Returns the base address of the buffer described by this `iovec`.
    fn as_ptr(&self) -> *const c_void;

    /// Returns the length, in bytes, of the buffer described by this `iovec`.
    fn size(&self) -> usize;
}

// SAFETY: a shared byte slice cannot be mutated while it is borrowed, and `<[u8]>::as_ptr`
// together with `<[u8]>::len` describe exactly the memory the slice covers.
unsafe impl IntoIovec for &[u8] {
    fn as_ptr(&self) -> *const c_void {
        // Fully qualified so the call cannot resolve back to this trait method.
        <[u8]>::as_ptr(self).cast()
    }

    fn size(&self) -> usize {
        <[u8]>::len(self)
    }
}
389
#[cfg(test)]
mod tests {
    #![allow(clippy::undocumented_unsafe_blocks)]
    use super::*;
    use std::io::{pipe, Read, Write};

    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let write_count = s1
            .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0u8; 6];
        let mut files = [0; 1];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv data")
        };

        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let (mut pipe_reader, pipe_writer) = pipe().expect("failed to create pipe");
        let write_count = s1
            .send_with_fd([].as_ref(), pipe_writer.as_raw_fd())
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), pipe_writer.as_raw_fd());

        // Write through the received descriptor and read it back from the original pipe.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        let mut buf = [0u8; std::mem::size_of::<u64>()];
        pipe_reader
            .read_exact(buf.as_mut_slice())
            .expect("Failed to read from PipeReader");
        assert_eq!(u64::from_ne_bytes(buf), 1203);
    }

    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let (mut pipe_reader, pipe_writer) = pipe().expect("failed to create pipe");
        let write_count = s1
            .send_with_fds(&[[237].as_ref()], &[pipe_writer.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv fd")
        };

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], pipe_writer.as_raw_fd());

        let mut file = unsafe { File::from_raw_fd(files[0]) };

        // Write through the received descriptor and read it back from the original pipe.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        let mut buf = [0u8; std::mem::size_of::<u64>()];
        pipe_reader
            .read_exact(buf.as_mut_slice())
            .expect("Failed to read from PipeReader");
        assert_eq!(u64::from_ne_bytes(buf), 1203);
    }

    #[test]
    // Exercise the code paths where the peer sends more ancillary data (fds) than the
    // receiver provides buffer space for.
    fn send_more_recv_less() {
        // macos does not set MSG_CTRUNC if we pass a zero-size buffer for control data, even
        // if the sender does provide control data, while linux does
        #[cfg(any(target_os = "linux", target_os = "android"))]
        let start = 0;
        #[cfg(not(any(target_os = "linux", target_os = "android")))]
        let start = 1;

        for too_small in start..3 {
            let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

            let (_, pipe_writer1) = pipe().expect("failed to create pipe");
            let (_, pipe_writer2) = pipe().expect("failed to create pipe");
            let (_, pipe_writer3) = pipe().expect("failed to create pipe");
            let (_, pipe_writer4) = pipe().expect("failed to create pipe");
            let write_count = s1
                .send_with_fds(
                    &[[237].as_ref()],
                    &[
                        pipe_writer1.as_raw_fd(),
                        pipe_writer2.as_raw_fd(),
                        pipe_writer3.as_raw_fd(),
                        pipe_writer4.as_raw_fd(),
                    ],
                )
                .expect("failed to send fd");

            assert_eq!(write_count, 1);

            let mut files = vec![0; too_small];
            let mut buf = [0u8];
            let mut iovecs = [iovec {
                iov_base: buf.as_mut_ptr() as *mut c_void,
                iov_len: buf.len(),
            }];
            unsafe { s2.recv_with_fds(&mut iovecs[..], &mut files).unwrap_err() };
        }
    }
}