unix_ipc/
raw_channel.rs

1use std::io;
2use std::mem;
3use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
4use std::os::unix::net::UnixStream;
5use std::path::Path;
6use std::slice;
7use std::sync::atomic::{AtomicBool, Ordering};
8
9use nix::sys::socket::{
10    c_uint, recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, CMSG_SPACE,
11};
12use nix::sys::uio::IoVec;
13use nix::unistd;
14
15#[cfg(target_os = "linux")]
16const MSG_FLAGS: MsgFlags = MsgFlags::MSG_CMSG_CLOEXEC;
17
18#[cfg(target_os = "macos")]
19const MSG_FLAGS: MsgFlags = MsgFlags::empty();
20
/// The receiving half of a raw channel.
///
/// Wraps a unix stream socket descriptor from which framed messages
/// (payload bytes plus optional file descriptors) are read.
#[derive(Debug)]
pub struct RawReceiver {
    /// The socket descriptor wrapped by this handle.
    fd: RawFd,
    /// Becomes `true` once the descriptor has been extracted (handed off).
    dead: AtomicBool,
}

/// The sending half of a raw channel.
///
/// Wraps a unix stream socket descriptor to which framed messages
/// (payload bytes plus optional file descriptors) are written.
#[derive(Debug)]
pub struct RawSender {
    /// The socket descriptor wrapped by this handle.
    fd: RawFd,
    /// Becomes `true` once the descriptor has been extracted (handed off).
    dead: AtomicBool,
}
34
35/// Creates a raw connected channel.
36pub fn raw_channel() -> io::Result<(RawSender, RawReceiver)> {
37    let (sender, receiver) = UnixStream::pair()?;
38    unsafe {
39        Ok((
40            RawSender::from_raw_fd(sender.into_raw_fd()),
41            RawReceiver::from_raw_fd(receiver.into_raw_fd()),
42        ))
43    }
44}
45
/// Fixed-size frame header written ahead of every payload.
///
/// The struct is sent as its raw in-memory bytes. `repr(C)` pins the field
/// order so both ends agree on the layout.
#[repr(C)]
#[derive(Debug, Default)]
struct MsgHeader {
    /// Number of payload bytes that follow the header.
    payload_len: u32,
    /// Number of file descriptors attached to the payload.
    fd_count: u32,
}
52
53macro_rules! fd_impl {
54    ($ty:ty) => {
55        #[allow(dead_code)]
56        impl $ty {
57            pub(crate) fn extract_raw_fd(&self) -> RawFd {
58                if self.dead.swap(true, Ordering::SeqCst) {
59                    panic!("handle was moved previously");
60                } else {
61                    self.fd
62                }
63            }
64        }
65
66        impl FromRawFd for $ty {
67            unsafe fn from_raw_fd(fd: RawFd) -> Self {
68                Self {
69                    fd,
70                    dead: AtomicBool::new(false),
71                }
72            }
73        }
74
75        impl IntoRawFd for $ty {
76            fn into_raw_fd(self) -> RawFd {
77                let fd = self.fd;
78                mem::forget(self);
79                fd
80            }
81        }
82
83        impl AsRawFd for $ty {
84            fn as_raw_fd(&self) -> RawFd {
85                self.fd
86            }
87        }
88
89        impl Drop for $ty {
90            fn drop(&mut self) {
91                unistd::close(self.fd).ok();
92            }
93        }
94    };
95}
96
97fd_impl!(RawReceiver);
98fd_impl!(RawSender);
99
100impl RawReceiver {
101    /// Connects a receiver to a named unix socket.
102    pub fn connect<P: AsRef<Path>>(p: P) -> io::Result<RawReceiver> {
103        let sock = UnixStream::connect(p)?;
104        unsafe { Ok(RawReceiver::from_raw_fd(sock.into_raw_fd())) }
105    }
106
107    pub fn recv(&self) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>)> {
108        self.recv_impl(true)
109    }
110
111    pub fn try_recv(&self) -> io::Result<Option<(Vec<u8>, Option<Vec<RawFd>>)>> {
112        let res = self.recv_impl(false);
113
114        match res {
115            Ok(res) => Ok(Some(res)),
116            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
117            Err(err) => Err(err),
118        }
119    }
120
121    /// Receives raw bytes from the socket.
122    fn recv_impl(&self, blocking: bool) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>)> {
123        let mut header = MsgHeader::default();
124        self.recv_part(
125            unsafe {
126                slice::from_raw_parts_mut(
127                    (&mut header as *mut _) as *mut u8,
128                    mem::size_of_val(&header),
129                )
130            },
131            0,
132            blocking,
133        )?;
134
135        let mut buf = vec![0u8; header.payload_len as usize];
136        // Once the header is received, the body must always follow
137        let (_, fds) = self.recv_part(&mut buf, header.fd_count as usize, true)?;
138        Ok((buf, fds))
139    }
140
141    fn recv_part(
142        &self,
143        buf: &mut [u8],
144        fd_count: usize,
145        blocking: bool,
146    ) -> io::Result<(usize, Option<Vec<RawFd>>)> {
147        let mut pos = 0;
148        let mut fds = None;
149
150        loop {
151            let iov = [IoVec::from_mut_slice(&mut buf[pos..])];
152            let mut new_fds = None;
153            let msgspace_size =
154                unsafe { CMSG_SPACE(mem::size_of::<RawFd>() as c_uint) * fd_count as u32 };
155            let mut cmsgspace = vec![0u8; msgspace_size as usize];
156
157            let flags = if blocking {
158                MSG_FLAGS
159            } else {
160                MSG_FLAGS | MsgFlags::MSG_DONTWAIT
161            };
162
163            let msg = recvmsg(self.fd, &iov, Some(&mut cmsgspace), flags)?;
164
165            for cmsg in msg.cmsgs() {
166                if let ControlMessageOwned::ScmRights(fds) = cmsg {
167                    if !fds.is_empty() {
168                        #[cfg(target_os = "macos")]
169                        unsafe {
170                            for &fd in &fds {
171                                libc::ioctl(fd, libc::FIOCLEX);
172                            }
173                        }
174                        new_fds = Some(fds);
175                    }
176                }
177            }
178
179            fds = match (fds, new_fds) {
180                (None, Some(new)) => Some(new),
181                (Some(mut old), Some(new)) => {
182                    old.extend(new);
183                    Some(old)
184                }
185                (old, None) => old,
186            };
187
188            if msg.bytes == 0 {
189                return Err(io::Error::new(
190                    io::ErrorKind::UnexpectedEof,
191                    "could not read",
192                ));
193            }
194
195            pos += msg.bytes;
196            if pos >= buf.len() {
197                return Ok((pos, fds));
198            }
199        }
200    }
201}
202
203impl RawSender {
204    /// Sends raw bytes and fds.
205    pub fn send(&self, data: &[u8], fds: &[RawFd]) -> io::Result<usize> {
206        let header = MsgHeader {
207            payload_len: data.len() as u32,
208            fd_count: fds.len() as u32,
209        };
210        let header_slice = unsafe {
211            slice::from_raw_parts(
212                (&header as *const _) as *const u8,
213                mem::size_of_val(&header),
214            )
215        };
216
217        self.send_impl(&header_slice, &[][..])?;
218        self.send_impl(&data, fds)
219    }
220
221    fn send_impl(&self, data: &[u8], mut fds: &[RawFd]) -> io::Result<usize> {
222        let mut pos = 0;
223        loop {
224            let iov = [IoVec::from_slice(&data[pos..])];
225            let sent = if !fds.is_empty() {
226                sendmsg(
227                    self.fd,
228                    &iov,
229                    &[ControlMessage::ScmRights(fds)],
230                    MsgFlags::empty(),
231                    None,
232                )?
233            } else {
234                sendmsg(self.fd, &iov, &[], MsgFlags::empty(), None)?
235            };
236            if sent == 0 {
237                return Err(io::Error::new(io::ErrorKind::WriteZero, "could not send"));
238            }
239            pos += sent;
240            fds = &[][..];
241            if pos >= data.len() {
242                return Ok(pos);
243            }
244        }
245    }
246}
247
#[test]
fn test_basic() {
    use std::thread;
    use std::time::Duration;

    let (sender, receiver) = raw_channel().unwrap();

    // Start the writer first; the short pause makes it likely the message is
    // already buffered by the time the reader starts.
    let writer = thread::spawn(move || {
        sender.send(b"Hello World!", &[]).unwrap();
    });
    thread::sleep(Duration::from_millis(10));

    let reader = thread::spawn(move || {
        let (payload, received_fds) = receiver.recv().unwrap();
        assert_eq!(payload, b"Hello World!");
        assert!(received_fds.is_none());
    });

    writer.join().unwrap();
    reader.join().unwrap();
}
267
#[test]
fn test_large_buffer() {
    // The decimal renderings of 0..10000 concatenated into one long payload,
    // large enough to need several socket reads/writes.
    let expected: String = (0..10000).map(|n| n.to_string()).collect();

    let (sender, receiver) = raw_channel().unwrap();

    let payload = expected.clone();
    let writer = std::thread::spawn(move || {
        sender.send(payload.as_bytes(), &[]).unwrap();
    });
    std::thread::sleep(std::time::Duration::from_millis(10));

    let reader = std::thread::spawn(move || {
        let (bytes, fds) = receiver.recv().unwrap();
        assert_eq!(bytes, expected.as_bytes());
        assert!(fds.is_none());
    });

    writer.join().unwrap();
    reader.join().unwrap();
}