// tokio_unix_ipc/raw_channel.rs

1use std::io;
2use std::io::{IoSlice, IoSliceMut};
3use std::mem;
4use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
5use std::os::unix::net::UnixStream;
6use std::path::Path;
7use std::slice;
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use nix::errno::Errno;
11use nix::sys::socket::{
12    c_uint, recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, CMSG_SPACE,
13};
14use nix::unistd;
15
16use tokio::io::unix::AsyncFd;
17
/// Flags passed to every `recvmsg` call in this module.
///
/// On these targets `MSG_CMSG_CLOEXEC` asks the kernel to set close-on-exec on file
/// descriptors received via `SCM_RIGHTS`.
#[cfg(any(
    target_os = "android",
    target_os = "dragonfly",
    target_os = "freebsd",
    target_os = "linux",
    target_os = "netbsd",
    target_os = "openbsd"
))]
const MSG_FLAGS: MsgFlags = MsgFlags::MSG_CMSG_CLOEXEC;

/// macOS lacks `MSG_CMSG_CLOEXEC`; the receive path instead marks received
/// descriptors close-on-exec with `ioctl(FIOCLEX)` after `recvmsg` returns.
// NOTE(review): targets matching neither cfg (e.g. iOS) get no `MSG_FLAGS` at all and
// will fail to compile.
#[cfg(target_os = "macos")]
const MSG_FLAGS: MsgFlags = MsgFlags::empty();
30
/// Fixed-size frame header sent ahead of every payload.
///
/// `#[repr(C)]` with two `u32` fields gives a fixed, padding-free layout, so the header
/// can be transmitted as its raw in-memory bytes (native byte order).
#[repr(C)]
#[derive(Default, Debug)]
struct MsgHeader {
    payload_len: u32,
    fd_count: u32,
}

impl MsgHeader {
    /// Borrows the header as raw bytes, e.g. for writing it to a socket.
    pub fn as_buf(&self) -> &[u8] {
        let len = mem::size_of::<Self>();
        let start = (self as *const Self).cast::<u8>();
        // SAFETY: `start` points at a live, fully initialised `MsgHeader` of exactly
        // `len` bytes (two `u32`s, no padding), and the slice borrows `self`.
        unsafe { slice::from_raw_parts(start, len) }
    }

    /// Borrows the header as mutable raw bytes, e.g. for reading it from a socket.
    pub fn as_buf_mut(&mut self) -> &mut [u8] {
        let len = mem::size_of::<Self>();
        let start = (self as *mut Self).cast::<u8>();
        // SAFETY: `start` points at a live `MsgHeader` of exactly `len` bytes that we
        // borrow uniquely; every bit pattern is a valid `u32`, so any bytes written
        // through the slice leave the header valid.
        unsafe { slice::from_raw_parts_mut(start, len) }
    }

    /// Allocates a zero-filled buffer sized to the announced payload length.
    pub fn make_buffer(&self) -> Vec<u8> {
        let len = self.payload_len as usize;
        vec![0u8; len]
    }
}
51
52/// Data received via `SCM_CREDENTIALS` from a remote process.
53#[derive(Debug, Clone)]
54pub struct Credentials {
55    pid: libc::pid_t,
56    uid: libc::uid_t,
57    gid: libc::gid_t,
58}
59
60impl Credentials {
61    /// The remote process identifier.
62    pub fn pid(&self) -> libc::pid_t {
63        self.pid
64    }
65
66    /// The remote process user ID.
67    pub fn uid(&self) -> libc::uid_t {
68        self.uid
69    }
70
71    /// The remote process group ID.
72    pub fn gid(&self) -> libc::gid_t {
73        self.gid
74    }
75}
76
77#[cfg(any(target_os = "android", target_os = "linux"))]
78impl From<nix::sys::socket::UnixCredentials> for Credentials {
79    fn from(c: nix::sys::socket::UnixCredentials) -> Self {
80        Self {
81            pid: c.pid(),
82            uid: c.uid(),
83            gid: c.gid(),
84        }
85    }
86}
87
88macro_rules! fd_impl {
89    ($ty:ty) => {
90        #[allow(dead_code)]
91        impl $ty {
92            pub(crate) unsafe fn from_raw_fd(fd: RawFd) -> io::Result<Self> {
93                Ok(Self {
94                    inner: AsyncFd::new(fd)?,
95                    dead: AtomicBool::new(false),
96                })
97            }
98
99            /// Convert from a standard stream.  This is a fallible
100            /// operation because registering the file descriptor with
101            /// the async runtime may fail.
102            ///
103            /// # Panics
104            ///
105            /// This function panics if it is not called from within a runtime with
106            /// IO enabled.
107            pub fn from_std(stream: UnixStream) -> io::Result<Self> {
108                unsafe { Self::from_raw_fd(stream.into_raw_fd()) }
109            }
110
111            pub(crate) fn extract_raw_fd(&self) -> RawFd {
112                if self.dead.swap(true, Ordering::SeqCst) {
113                    panic!("handle was moved previously");
114                } else {
115                    self.inner.as_raw_fd()
116                }
117            }
118        }
119
120        impl FromRawFd for $ty {
121            unsafe fn from_raw_fd(fd: RawFd) -> Self {
122                Self::from_raw_fd(fd)
123                    .expect("conversion from RawFd requires an active tokio runtime")
124            }
125        }
126
127        impl IntoRawFd for $ty {
128            fn into_raw_fd(self) -> RawFd {
129                self.extract_raw_fd()
130            }
131        }
132
133        impl AsRawFd for $ty {
134            fn as_raw_fd(&self) -> RawFd {
135                self.inner.as_raw_fd()
136            }
137        }
138
139        impl Drop for $ty {
140            fn drop(&mut self) {
141                if !self.dead.load(Ordering::SeqCst) {
142                    unistd::close(self.as_raw_fd()).ok();
143                }
144            }
145        }
146    };
147}
148
149fd_impl!(RawReceiver);
150fd_impl!(RawSender);
151
/// Evaluates a `nix` call repeatedly until it yields anything other than
/// `Err(Errno::EINTR)`, then evaluates to that result.
///
/// `$expr` is re-evaluated on every attempt, so pass the call itself.
macro_rules! nix_eintr {
    ($expr:expr) => {
        loop {
            let attempt = $expr;
            if let Err(Errno::EINTR) = attempt {
                // Interrupted by a signal: issue the call again.
                continue;
            }
            break attempt;
        }
    };
}
162
163fn recv_impl(
164    fd: RawFd,
165    buf: &mut [u8],
166    fds: Option<Vec<i32>>,
167    fd_count: usize,
168    _want_creds: bool,
169) -> io::Result<(usize, Option<Vec<RawFd>>, Option<Credentials>)> {
170    let mut iov = [IoSliceMut::new(buf)];
171    let mut new_fds = None;
172
173    #[allow(unused_mut)]
174    let mut creds = None;
175
176    // Compute the size of ancillary data, combining expected number of file descriptors
177    // with any space needed for credentials.
178    let msgspace_size = {
179        let fd_size = unsafe { CMSG_SPACE(mem::size_of::<RawFd>() as c_uint) * fd_count as u32 };
180        #[cfg(any(target_os = "android", target_os = "linux"))]
181        {
182            let cred_size: u32 = _want_creds
183                .then(|| unsafe {
184                    CMSG_SPACE(mem::size_of::<nix::sys::socket::UnixCredentials>() as c_uint)
185                })
186                .unwrap_or_default();
187            fd_size + cred_size
188        }
189        #[cfg(not(any(target_os = "android", target_os = "linux")))]
190        {
191            fd_size
192        }
193    };
194    let mut cmsgspace = vec![0u8; msgspace_size as usize];
195
196    let msg = nix_eintr!(recvmsg::<()>(fd, &mut iov, Some(&mut cmsgspace), MSG_FLAGS))?;
197
198    for cmsg in msg.cmsgs() {
199        match cmsg {
200            ControlMessageOwned::ScmRights(fds) => {
201                if !fds.is_empty() {
202                    #[cfg(target_os = "macos")]
203                    unsafe {
204                        for &fd in &fds {
205                            // as per documentation this does not ever fail
206                            // with EINTR
207                            libc::ioctl(fd, libc::FIOCLEX);
208                        }
209                    }
210                    new_fds = Some(fds);
211                }
212            }
213            #[cfg(any(target_os = "android", target_os = "linux"))]
214            ControlMessageOwned::ScmCredentials(c) => {
215                creds = Some(c.into());
216            }
217            _ => {}
218        }
219    }
220
221    if msg.bytes == 0 {
222        return Err(io::Error::new(
223            io::ErrorKind::UnexpectedEof,
224            "could not read",
225        ));
226    }
227
228    let fds = match (fds, new_fds) {
229        (None, Some(new)) => Some(new),
230        (Some(mut old), Some(new)) => {
231            old.extend(new);
232            Some(old)
233        }
234        (old, None) => old,
235    };
236
237    Ok((msg.bytes, fds, creds))
238}
239
240#[cfg(any(target_os = "android", target_os = "linux"))]
241fn send_impl(fd: RawFd, data: &[u8], fds: &[RawFd], creds: bool) -> io::Result<usize> {
242    let iov = [IoSlice::new(data)];
243    let creds = creds.then(nix::sys::socket::UnixCredentials::new);
244    let sent = match (fds, creds.as_ref()) {
245        ([], None) => nix_eintr!(sendmsg::<()>(fd, &iov, &[], MsgFlags::empty(), None))?,
246        ([], Some(creds)) => nix_eintr!(sendmsg::<()>(
247            fd,
248            &iov,
249            &[ControlMessage::ScmCredentials(creds),],
250            MsgFlags::empty(),
251            None,
252        ))?,
253        (fds, Some(creds)) => {
254            let cmsgs = &[
255                ControlMessage::ScmRights(fds),
256                ControlMessage::ScmCredentials(creds),
257            ];
258            nix_eintr!(sendmsg::<()>(fd, &iov, cmsgs, MsgFlags::empty(), None,))?
259        }
260        (fds, None) => {
261            let cmsgs = &[ControlMessage::ScmRights(fds)];
262            nix_eintr!(sendmsg::<()>(fd, &iov, cmsgs, MsgFlags::empty(), None,))?
263        }
264    };
265    if sent == 0 {
266        return Err(io::Error::new(io::ErrorKind::WriteZero, "could not send"));
267    }
268    Ok(sent)
269}
270
/// Sends `data` in a single `sendmsg` call (retried on `EINTR`), attaching `fds` as
/// `SCM_RIGHTS` when non-empty. Credentials are not supported on these targets, so
/// `_creds` is ignored.
///
/// Returns how many payload bytes were accepted; a zero-byte send is reported as
/// `WriteZero`.
#[cfg(not(any(target_os = "android", target_os = "linux")))]
fn send_impl(fd: RawFd, data: &[u8], fds: &[RawFd], _creds: bool) -> io::Result<usize> {
    let iov = [IoSlice::new(data)];
    let rights = [ControlMessage::ScmRights(fds)];
    let cmsgs: &[ControlMessage] = if fds.is_empty() { &[] } else { &rights };

    match nix_eintr!(sendmsg::<()>(fd, &iov, cmsgs, MsgFlags::empty(), None))? {
        0 => Err(io::Error::new(io::ErrorKind::WriteZero, "could not send")),
        sent => Ok(sent),
    }
}
290
291/// Creates a raw connected channel.
292pub fn raw_channel() -> io::Result<(RawSender, RawReceiver)> {
293    let (sender, receiver) = tokio::net::UnixStream::pair()?;
294    Ok((
295        RawSender::from_std(sender.into_std()?)?,
296        RawReceiver::from_std(receiver.into_std()?)?,
297    ))
298}
299
300/// Creates a raw connected channel from an already extant socket.
301pub fn raw_channel_from_std(sender: UnixStream) -> io::Result<(RawSender, RawReceiver)> {
302    let receiver = sender.try_clone()?;
303    Ok((
304        RawSender::from_std(sender)?,
305        RawReceiver::from_std(receiver)?,
306    ))
307}
308
309/// An async raw receiver.
310#[derive(Debug)]
311pub struct RawReceiver {
312    inner: AsyncFd<RawFd>,
313    dead: AtomicBool,
314}
315
316impl RawReceiver {
317    /// Connects a receiver to a named unix socket.
318    pub async fn connect<P: AsRef<Path>>(p: P) -> io::Result<RawReceiver> {
319        let stream = tokio::net::UnixStream::connect(p).await?;
320        RawReceiver::from_std(stream.into_std()?)
321    }
322
323    /// Receives raw bytes from the socket.
324    pub async fn recv(&self) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>)> {
325        let mut header = MsgHeader::default();
326        self.recv_impl(header.as_buf_mut(), 0, false).await?;
327        let mut buf = header.make_buffer();
328        let (_, fds, _) = self
329            .recv_impl(&mut buf, header.fd_count as usize, false)
330            .await?;
331        Ok((buf, fds))
332    }
333
334    /// Receives raw bytes and credentials from the socket.
335    #[cfg(any(target_os = "android", target_os = "linux"))]
336    pub async fn recv_with_credentials(
337        &self,
338    ) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>, Credentials)> {
339        nix::sys::socket::setsockopt(
340            self.inner.as_raw_fd(),
341            nix::sys::socket::sockopt::PassCred,
342            &true,
343        )?;
344        let mut header = MsgHeader::default();
345        let (_, _, creds) = self.recv_impl(header.as_buf_mut(), 0, true).await?;
346        let creds = creds.ok_or_else(|| {
347            io::Error::new(
348                io::ErrorKind::InvalidData,
349                "Remote did not provide credentials",
350            )
351        })?;
352        let mut buf = header.make_buffer();
353        let (_, fds, _) = self
354            .recv_impl(&mut buf, header.fd_count as usize, false)
355            .await?;
356        Ok((buf, fds, creds))
357    }
358
359    async fn recv_impl(
360        &self,
361        buf: &mut [u8],
362        fd_count: usize,
363        want_creds: bool,
364    ) -> io::Result<(usize, Option<Vec<RawFd>>, Option<Credentials>)> {
365        let mut pos = 0;
366        let mut fds = None;
367
368        loop {
369            let mut guard = self.inner.readable().await?;
370            let (bytes, new_fds, creds) = match guard.try_io(|inner| {
371                recv_impl(
372                    inner.as_raw_fd(),
373                    &mut buf[pos..],
374                    fds.take(),
375                    fd_count,
376                    want_creds,
377                )
378            }) {
379                Ok(result) => result,
380                Err(_would_block) => continue,
381            }?;
382
383            fds = new_fds;
384            pos += bytes;
385            if pos >= buf.len() {
386                return Ok((pos, fds, creds));
387            }
388        }
389    }
390}
391
392unsafe impl Send for RawReceiver {}
393unsafe impl Sync for RawReceiver {}
394
395/// An async raw sender.
396#[derive(Debug)]
397pub struct RawSender {
398    inner: AsyncFd<RawFd>,
399    #[allow(dead_code)]
400    dead: AtomicBool,
401}
402
403impl RawSender {
404    /// Sends raw bytes and fds.
405    pub async fn send(&self, data: &[u8], fds: &[RawFd]) -> io::Result<usize> {
406        let header = MsgHeader {
407            payload_len: data.len() as u32,
408            fd_count: fds.len() as u32,
409        };
410        self.send_impl(header.as_buf(), &[][..], false).await?;
411        self.send_impl(data, fds, false).await
412    }
413
414    /// Sends raw bytes and fds along with current process credentials.
415    #[cfg(any(target_os = "android", target_os = "linux"))]
416    pub async fn send_with_credentials(&self, data: &[u8], fds: &[RawFd]) -> io::Result<usize> {
417        let header = MsgHeader {
418            payload_len: data.len() as u32,
419            fd_count: fds.len() as u32,
420        };
421        self.send_impl(header.as_buf(), &[][..], true).await?;
422        self.send_impl(data, fds, false).await
423    }
424
425    async fn send_impl(&self, data: &[u8], mut fds: &[RawFd], creds: bool) -> io::Result<usize> {
426        let mut pos = 0;
427        loop {
428            let mut guard = self.inner.writable().await?;
429            let sent = match guard
430                .try_io(|inner| send_impl(inner.as_raw_fd(), &data[pos..], fds, creds))
431            {
432                Ok(result) => result,
433                Err(_would_block) => continue,
434            }?;
435            pos += sent;
436            fds = &[][..];
437            if pos >= data.len() {
438                return Ok(pos);
439            }
440        }
441    }
442}
443
444unsafe impl Send for RawSender {}
445unsafe impl Sync for RawSender {}