virtio_driver/util/
sock_ctrl_msg.rs

1// Copyright (C) 2022 Red Hat, Inc. All rights reserved.
2//
3// Copyright 2017 The Chromium OS Authors. All rights reserved.
4// Use of this source code is governed by a BSD-style license that can be
5// found in the LICENSE.crosvm file.
6
7#![allow(dead_code)]
8
9//! Used to send and receive messages with file descriptors on sockets that accept control messages
10//! (e.g. Unix domain sockets).
11
12use std::fs::File;
13use std::io::{IoSlice, IoSliceMut, Result};
14use std::mem::{size_of, size_of_val};
15use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
16use std::os::unix::net::{UnixDatagram, UnixStream};
17use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
18use std::slice;
19
20use libc::{
21    c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
22};
23
24use crate::Error;
25
// Each of the following macros performs the same function as its C counterpart. They are
// macros rather than functions because they are used to size statically allocated arrays.
28
// Rounds `$bytes` up to the next multiple of `size_of::<c_long>()`, like C's `CMSG_ALIGN`.
// The expansion is a plain constant expression, so it can be used in `const` items.
macro_rules! CMSG_ALIGN {
    ($bytes:expr) => {{
        let align = size_of::<c_long>();
        (($bytes) + align - 1) & !(align - 1)
    }};
}
34
// Number of bytes one control message carrying a `$payload`-byte payload occupies in a control
// buffer: the payload padded with `CMSG_ALIGN!`, plus the `cmsghdr` header (C's `CMSG_SPACE`).
macro_rules! CMSG_SPACE {
    ($payload:expr) => {
        CMSG_ALIGN!($payload) + size_of::<cmsghdr>()
    };
}
40
// Value for `cmsg_len` of a control message carrying `$payload` bytes: the `cmsghdr` header
// size plus the unpadded payload length (C's `CMSG_LEN`).
macro_rules! CMSG_LEN {
    ($payload:expr) => {
        ($payload) + size_of::<cmsghdr>()
    };
}
46
47// This function (macro in the C version) is not used in any compile time constant slots, so is just
48// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this
49// module supports.
50#[allow(non_snake_case)]
51#[inline(always)]
52fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
53    // Essentially returns a pointer to just past the header.
54    cmsg_buffer.wrapping_offset(1) as *mut RawFd
55}
56
57// This function is like CMSG_NEXT, but safer because it reads only from references, although it
58// does some pointer arithmetic on cmsg_ptr.
59#[allow(clippy::cast_ptr_alignment)]
60fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
61    let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr;
62    if next_cmsg
63        .wrapping_offset(1)
64        .wrapping_sub(msghdr.msg_control as usize) as usize
65        > msghdr.msg_controllen
66    {
67        null_mut()
68    } else {
69        next_cmsg
70    }
71}
72
// Size in bytes (CMSG_SPACE) of one control message carrying 32 `RawFd`s. `CmsgBuffer` keeps
// requests up to this size in its inline array instead of allocating on the heap.
const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
74
/// Backing storage for the control-message buffer handed to `sendmsg`/`recvmsg`.
///
/// Small buffers live inline; larger ones are heap-allocated (see `CmsgBuffer::with_capacity`).
enum CmsgBuffer {
    // `CMSG_BUFFER_INLINE_CAPACITY` bytes rounded up to whole `u64`s. `u64` elements (rather than
    // bytes) are used so the storage can be viewed as `*mut cmsghdr`.
    // NOTE(review): assumes `align_of::<cmsghdr>() <= align_of::<u64>()` on every target.
    Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
    // Heap storage sized in whole `cmsghdr` units, so it is aligned for headers by construction.
    Heap(Box<[cmsghdr]>),
}
79
80impl CmsgBuffer {
81    fn with_capacity(capacity: usize) -> CmsgBuffer {
82        let cap_in_cmsghdr_units =
83            (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
84        if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
85            CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
86        } else {
87            CmsgBuffer::Heap(
88                vec![
89                    cmsghdr {
90                        cmsg_len: 0,
91                        cmsg_level: 0,
92                        cmsg_type: 0,
93                    };
94                    cap_in_cmsghdr_units
95                ]
96                .into_boxed_slice(),
97            )
98        }
99    }
100
101    fn as_mut_ptr(&mut self) -> *mut cmsghdr {
102        match self {
103            CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
104            CmsgBuffer::Heap(a) => a.as_mut_ptr(),
105        }
106    }
107}
108
109fn raw_sendmsg<D: IntoIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
110    let cmsg_capacity = CMSG_SPACE!(size_of_val(out_fds));
111    let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
112
113    let iovec = IntoIobuf::as_iobufs(out_data);
114
115    let mut msg = msghdr {
116        msg_name: null_mut(),
117        msg_namelen: 0,
118        msg_iov: iovec.as_ptr() as *mut iovec,
119        msg_iovlen: iovec.len(),
120        msg_control: null_mut(),
121        msg_controllen: 0,
122        msg_flags: 0,
123    };
124
125    if !out_fds.is_empty() {
126        let cmsg = cmsghdr {
127            cmsg_len: CMSG_LEN!(size_of_val(out_fds)),
128            cmsg_level: SOL_SOCKET,
129            cmsg_type: SCM_RIGHTS,
130        };
131        unsafe {
132            // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
133            write_unaligned(cmsg_buffer.as_mut_ptr(), cmsg);
134            // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
135            // file descriptors.
136            copy_nonoverlapping(
137                out_fds.as_ptr(),
138                CMSG_DATA(cmsg_buffer.as_mut_ptr()),
139                out_fds.len(),
140            );
141        }
142
143        msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
144        msg.msg_controllen = cmsg_capacity;
145    }
146
147    // Safe because the msghdr was properly constructed from valid (or null) pointers of the
148    // indicated length and we check the return value.
149    let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
150
151    if write_count == -1 {
152        Err(Error::last_os_error())
153    } else {
154        Ok(write_count as usize)
155    }
156}
157
158fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
159    let cmsg_capacity = CMSG_SPACE!(size_of_val(in_fds));
160    let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
161
162    let mut iovec = iovec {
163        iov_base: in_data.as_mut_ptr() as *mut c_void,
164        iov_len: in_data.len(),
165    };
166
167    let mut msg = msghdr {
168        msg_name: null_mut(),
169        msg_namelen: 0,
170        msg_iov: &mut iovec as *mut iovec,
171        msg_iovlen: 1,
172        msg_control: null_mut(),
173        msg_controllen: 0,
174        msg_flags: 0,
175    };
176
177    if !in_fds.is_empty() {
178        msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
179        msg.msg_controllen = cmsg_capacity;
180    }
181
182    // Safe because the msghdr was properly constructed from valid (or null) pointers of the
183    // indicated length and we check the return value.
184    let total_read = unsafe { recvmsg(fd, &mut msg, 0) };
185
186    if total_read == -1 {
187        return Err(Error::last_os_error());
188    }
189
190    if total_read == 0 && msg.msg_controllen < size_of::<cmsghdr>() {
191        return Ok((0, 0));
192    }
193
194    let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
195    let mut in_fds_count = 0;
196    while !cmsg_ptr.is_null() {
197        // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that
198        // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read.
199        let cmsg = unsafe { (cmsg_ptr).read_unaligned() };
200
201        if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
202            let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::<RawFd>();
203            unsafe {
204                copy_nonoverlapping(
205                    CMSG_DATA(cmsg_ptr),
206                    in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
207                    fd_count,
208                );
209            }
210            in_fds_count += fd_count;
211        }
212
213        cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
214    }
215
216    Ok((total_read as usize, in_fds_count))
217}
218
/// The maximum number of FDs that can be sent in a single send.
///
/// NOTE(review): presumably mirrors the kernel's per-message `SCM_RIGHTS` limit (`SCM_MAX_FD` on
/// Linux). Nothing in this module checks it before calling `sendmsg`, so an oversized list is
/// rejected (if at all) by the kernel — confirm.
pub const SCM_SOCKET_MAX_FD_COUNT: usize = 253;
221
222/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
223/// `recvmsg`.
224pub trait ScmSocket {
225    /// Gets the file descriptor of this socket.
226    fn socket_fd(&self) -> RawFd;
227
228    /// Sends the given data and file descriptor over the socket.
229    ///
230    /// On success, returns the number of bytes sent.
231    ///
232    /// # Arguments
233    ///
234    /// * `buf` - A buffer of data to send on the `socket`.
235    /// * `fd` - A file descriptors to be sent.
236    fn send_with_fd<D: IntoIobuf>(&self, buf: &[D], fd: RawFd) -> Result<usize> {
237        self.send_with_fds(buf, &[fd])
238    }
239
240    /// Sends the given data and file descriptors over the socket.
241    ///
242    /// On success, returns the number of bytes sent.
243    ///
244    /// # Arguments
245    ///
246    /// * `buf` - A buffer of data to send on the `socket`.
247    /// * `fds` - A list of file descriptors to be sent.
248    fn send_with_fds<D: IntoIobuf>(&self, buf: &[D], fd: &[RawFd]) -> Result<usize> {
249        raw_sendmsg(self.socket_fd(), buf, fd)
250    }
251
252    /// Receives data and potentially a file descriptor from the socket.
253    ///
254    /// On success, returns the number of bytes and an optional file descriptor.
255    ///
256    /// # Arguments
257    ///
258    /// * `buf` - A buffer to receive data from the socket.vm
259    fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
260        let mut fd = [0];
261        let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?;
262        let file = if fd_count == 0 {
263            None
264        } else {
265            // Safe because the first fd from recv_with_fds is owned by us and valid because this
266            // branch was taken.
267            Some(unsafe { File::from_raw_fd(fd[0]) })
268        };
269        Ok((read_count, file))
270    }
271
272    /// Receives data and file descriptors from the socket.
273    ///
274    /// On success, returns the number of bytes and file descriptors received as a tuple
275    /// `(bytes count, files count)`.
276    ///
277    /// # Arguments
278    ///
279    /// * `buf` - A buffer to receive data from the socket.
280    /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
281    ///           number of valid file descriptors is indicated by the second element of the
282    ///           returned tuple. The caller owns these file descriptors, but they will not be
283    ///           closed on drop like a `File`-like type would be. It is recommended that each valid
284    ///           file descriptor gets wrapped in a drop type that closes it after this returns.
285    fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> {
286        raw_recvmsg(self.socket_fd(), buf, fds)
287    }
288}
289
290impl ScmSocket for UnixDatagram {
291    fn socket_fd(&self) -> RawFd {
292        self.as_raw_fd()
293    }
294}
295
296impl ScmSocket for UnixStream {
297    fn socket_fd(&self) -> RawFd {
298        self.as_raw_fd()
299    }
300}
301
/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for
/// the lifetime of this object.
///
/// # Safety
/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size
/// being accurate. Implementors must guarantee that each `iovec` they return describes memory
/// valid for `iov_len` bytes for as long as the source value (or slice) is borrowed.
pub unsafe trait IntoIobuf: Sized {
    /// Returns an `iovec` that describes a contiguous region of memory.
    fn into_iobuf(self) -> iovec;

    /// Returns a slice of `iovec`s that each describe a contiguous region of memory, one per
    /// element of `bufs`.
    fn as_iobufs(bufs: &[Self]) -> &[iovec];
}
315
316// Safe because there are no other mutable references to the memory described by `IoSlice` and it is
317// guaranteed to be ABI-compatible with `iovec`.
318unsafe impl<'a> IntoIobuf for IoSlice<'a> {
319    fn into_iobuf(self) -> iovec {
320        iovec {
321            iov_base: self.as_ptr() as *mut c_void,
322            iov_len: self.len(),
323        }
324    }
325
326    fn as_iobufs(bufs: &[Self]) -> &[iovec] {
327        // Safe because `IoSlice` is guaranteed to be ABI-compatible with `iovec`.
328        unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
329    }
330}
331
332// Safe because there are no other references to the memory described by `IoSliceMut` and it is
333// guaranteed to be ABI-compatible with `iovec`.
334unsafe impl<'a> IntoIobuf for IoSliceMut<'a> {
335    fn into_iobuf(self) -> iovec {
336        iovec {
337            iov_base: self.as_ptr() as *mut c_void,
338            iov_len: self.len(),
339        }
340    }
341
342    fn as_iobufs(bufs: &[Self]) -> &[iovec] {
343        // Safe because `IoSliceMut` is guaranteed to be ABI-compatible with `iovec`.
344        unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::Write;
    use std::mem::size_of;
    use std::os::raw::c_long;
    use std::os::unix::net::UnixDatagram;

    use libc::cmsghdr;

    use crate::{EventFd, EventfdFlags};

    #[test]
    fn buffer_len() {
        assert_eq!(CMSG_SPACE!(0), size_of::<cmsghdr>());
        assert_eq!(
            CMSG_SPACE!(size_of::<RawFd>()),
            size_of::<cmsghdr>() + size_of::<c_long>()
        );
        if size_of::<RawFd>() == 4 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>()
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
        } else if size_of::<RawFd>() == 8 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 3
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 4
            );
        }
    }

    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let ioslice = IoSlice::new([1u8, 1, 2, 21, 34, 55].as_ref());
        let write_count = s1
            .send_with_fds(&[ioslice], &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0; 6];
        let mut files = [0; 1];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut buf[..], &mut files)
            .expect("failed to recv data");

        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(EventfdFlags::empty()).expect("failed to create eventfd");
        let ioslice = IoSlice::new([].as_ref());
        let write_count = s1
            .send_with_fd(&[ioslice], evt.as_raw_fd())
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), evt.as_raw_fd());

        // Write the u64 value 1203 in native byte order through the received descriptor; the
        // original eventfd should then observe it.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(EventfdFlags::empty()).expect("failed to create eventfd");
        let ioslice = IoSlice::new([237].as_ref());
        let write_count = s1
            .send_with_fds(&[ioslice], &[evt.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut buf, &mut files)
            .expect("failed to recv fd");

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], evt.as_raw_fd());

        // SAFETY: recv_with_fds reported one descriptor, which this test now owns.
        let mut file = unsafe { File::from_raw_fd(files[0]) };

        // Write the u64 value 1203 in native byte order through the received descriptor; the
        // original eventfd should then observe it.
        file.write_all(&1203u64.to_ne_bytes())
            .expect("failed to write to sent fd");

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }
}