1use std::io;
2use std::mem;
3use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
4use std::os::unix::net::UnixStream;
5use std::path::Path;
6use std::slice;
7use std::sync::atomic::{AtomicBool, Ordering};
8
9use nix::sys::socket::{
10 c_uint, recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, CMSG_SPACE,
11};
12use nix::sys::uio::IoVec;
13use nix::unistd;
14
/// Flags passed to every `recvmsg` call made by `RawReceiver`.
///
/// On Linux, `MSG_CMSG_CLOEXEC` asks the kernel to mark descriptors received
/// via `SCM_RIGHTS` close-on-exec as part of the receive itself.
#[cfg(target_os = "linux")]
const MSG_FLAGS: MsgFlags = MsgFlags::MSG_CMSG_CLOEXEC;

/// Flags passed to every `recvmsg` call made by `RawReceiver`.
///
/// No extra flag is used on macOS; instead `RawReceiver` sets close-on-exec on
/// each received descriptor afterwards with an `FIOCLEX` ioctl.
// NOTE(review): only linux and macos define MSG_FLAGS, so other Unix targets
// (e.g. the BSDs) will fail to compile this module — confirm that is intended.
#[cfg(target_os = "macos")]
const MSG_FLAGS: MsgFlags = MsgFlags::empty();
20
/// Receiving end of a raw channel: wraps a Unix stream socket descriptor and
/// reads length-prefixed messages (plus any passed file descriptors) from it.
///
/// Descriptor cleanup is performed by the `Drop` impl generated by `fd_impl!`.
#[derive(Debug)]
pub struct RawReceiver {
    /// The underlying socket descriptor.
    fd: RawFd,
    /// Set by `extract_raw_fd` once the descriptor has been handed out; a
    /// second extraction panics.
    dead: AtomicBool,
}

/// Sending end of a raw channel: wraps a Unix stream socket descriptor and
/// writes length-prefixed messages (plus optional file descriptors) to it.
///
/// Descriptor cleanup is performed by the `Drop` impl generated by `fd_impl!`.
#[derive(Debug)]
pub struct RawSender {
    /// The underlying socket descriptor.
    fd: RawFd,
    /// Set by `extract_raw_fd` once the descriptor has been handed out; a
    /// second extraction panics.
    dead: AtomicBool,
}
34
35pub fn raw_channel() -> io::Result<(RawSender, RawReceiver)> {
37 let (sender, receiver) = UnixStream::pair()?;
38 unsafe {
39 Ok((
40 RawSender::from_raw_fd(sender.into_raw_fd()),
41 RawReceiver::from_raw_fd(receiver.into_raw_fd()),
42 ))
43 }
44}
45
/// Fixed-size frame header written ahead of every payload.
///
/// Both `RawSender::send` and `RawReceiver::recv_impl` move this struct over
/// the socket as its raw in-memory bytes, so `#[repr(C)]` pins the layout and
/// the fields are in native byte order on the wire.
#[repr(C)]
#[derive(Default, Debug)]
struct MsgHeader {
    /// Number of payload bytes that follow the header.
    payload_len: u32,
    /// Number of file descriptors sent alongside the payload.
    fd_count: u32,
}
52
/// Generates the descriptor-ownership impls shared by `RawReceiver` and
/// `RawSender`.
///
/// The target type must have an `fd: RawFd` field and a `dead: AtomicBool`
/// field. Generated items:
///
/// * `extract_raw_fd` — hands the descriptor out once and marks the handle
///   dead, so that dropping the handle no longer closes it;
/// * `FromRawFd`, `IntoRawFd`, `AsRawFd`;
/// * `Drop` — closes the descriptor unless it was extracted.
macro_rules! fd_impl {
    ($ty:ty) => {
        // `extract_raw_fd` is crate-internal and may be unused in some builds
        // of the crate; silence the warning rather than cfg-gate it.
        #[allow(dead_code)]
        impl $ty {
            /// Transfers ownership of the descriptor to the caller.
            ///
            /// After this call the handle no longer closes the descriptor when
            /// it is dropped; the caller is responsible for it.
            ///
            /// # Panics
            ///
            /// Panics if the descriptor was already extracted.
            pub(crate) fn extract_raw_fd(&self) -> RawFd {
                if self.dead.swap(true, Ordering::SeqCst) {
                    panic!("handle was moved previously");
                }
                self.fd
            }
        }

        impl FromRawFd for $ty {
            /// Takes ownership of `fd`; it will be closed when the value is
            /// dropped (unless extracted first).
            unsafe fn from_raw_fd(fd: RawFd) -> Self {
                Self {
                    fd,
                    dead: AtomicBool::new(false),
                }
            }
        }

        impl IntoRawFd for $ty {
            fn into_raw_fd(self) -> RawFd {
                let fd = self.fd;
                // Skip `Drop` so the descriptor stays open for the caller.
                mem::forget(self);
                fd
            }
        }

        impl AsRawFd for $ty {
            fn as_raw_fd(&self) -> RawFd {
                self.fd
            }
        }

        impl Drop for $ty {
            fn drop(&mut self) {
                // An extracted descriptor now belongs to someone else; closing
                // it here would close (or, after reuse, clobber) a descriptor
                // this handle no longer owns.
                if !self.dead.load(Ordering::SeqCst) {
                    // `drop` cannot report errors, so a failed close is ignored.
                    unistd::close(self.fd).ok();
                }
            }
        }
    };
}
96
97fd_impl!(RawReceiver);
98fd_impl!(RawSender);
99
100impl RawReceiver {
101 pub fn connect<P: AsRef<Path>>(p: P) -> io::Result<RawReceiver> {
103 let sock = UnixStream::connect(p)?;
104 unsafe { Ok(RawReceiver::from_raw_fd(sock.into_raw_fd())) }
105 }
106
107 pub fn recv(&self) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>)> {
108 self.recv_impl(true)
109 }
110
111 pub fn try_recv(&self) -> io::Result<Option<(Vec<u8>, Option<Vec<RawFd>>)>> {
112 let res = self.recv_impl(false);
113
114 match res {
115 Ok(res) => Ok(Some(res)),
116 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
117 Err(err) => Err(err),
118 }
119 }
120
121 fn recv_impl(&self, blocking: bool) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>)> {
123 let mut header = MsgHeader::default();
124 self.recv_part(
125 unsafe {
126 slice::from_raw_parts_mut(
127 (&mut header as *mut _) as *mut u8,
128 mem::size_of_val(&header),
129 )
130 },
131 0,
132 blocking,
133 )?;
134
135 let mut buf = vec![0u8; header.payload_len as usize];
136 let (_, fds) = self.recv_part(&mut buf, header.fd_count as usize, true)?;
138 Ok((buf, fds))
139 }
140
141 fn recv_part(
142 &self,
143 buf: &mut [u8],
144 fd_count: usize,
145 blocking: bool,
146 ) -> io::Result<(usize, Option<Vec<RawFd>>)> {
147 let mut pos = 0;
148 let mut fds = None;
149
150 loop {
151 let iov = [IoVec::from_mut_slice(&mut buf[pos..])];
152 let mut new_fds = None;
153 let msgspace_size =
154 unsafe { CMSG_SPACE(mem::size_of::<RawFd>() as c_uint) * fd_count as u32 };
155 let mut cmsgspace = vec![0u8; msgspace_size as usize];
156
157 let flags = if blocking {
158 MSG_FLAGS
159 } else {
160 MSG_FLAGS | MsgFlags::MSG_DONTWAIT
161 };
162
163 let msg = recvmsg(self.fd, &iov, Some(&mut cmsgspace), flags)?;
164
165 for cmsg in msg.cmsgs() {
166 if let ControlMessageOwned::ScmRights(fds) = cmsg {
167 if !fds.is_empty() {
168 #[cfg(target_os = "macos")]
169 unsafe {
170 for &fd in &fds {
171 libc::ioctl(fd, libc::FIOCLEX);
172 }
173 }
174 new_fds = Some(fds);
175 }
176 }
177 }
178
179 fds = match (fds, new_fds) {
180 (None, Some(new)) => Some(new),
181 (Some(mut old), Some(new)) => {
182 old.extend(new);
183 Some(old)
184 }
185 (old, None) => old,
186 };
187
188 if msg.bytes == 0 {
189 return Err(io::Error::new(
190 io::ErrorKind::UnexpectedEof,
191 "could not read",
192 ));
193 }
194
195 pos += msg.bytes;
196 if pos >= buf.len() {
197 return Ok((pos, fds));
198 }
199 }
200 }
201}
202
203impl RawSender {
204 pub fn send(&self, data: &[u8], fds: &[RawFd]) -> io::Result<usize> {
206 let header = MsgHeader {
207 payload_len: data.len() as u32,
208 fd_count: fds.len() as u32,
209 };
210 let header_slice = unsafe {
211 slice::from_raw_parts(
212 (&header as *const _) as *const u8,
213 mem::size_of_val(&header),
214 )
215 };
216
217 self.send_impl(&header_slice, &[][..])?;
218 self.send_impl(&data, fds)
219 }
220
221 fn send_impl(&self, data: &[u8], mut fds: &[RawFd]) -> io::Result<usize> {
222 let mut pos = 0;
223 loop {
224 let iov = [IoVec::from_slice(&data[pos..])];
225 let sent = if !fds.is_empty() {
226 sendmsg(
227 self.fd,
228 &iov,
229 &[ControlMessage::ScmRights(fds)],
230 MsgFlags::empty(),
231 None,
232 )?
233 } else {
234 sendmsg(self.fd, &iov, &[], MsgFlags::empty(), None)?
235 };
236 if sent == 0 {
237 return Err(io::Error::new(io::ErrorKind::WriteZero, "could not send"));
238 }
239 pos += sent;
240 fds = &[][..];
241 if pos >= data.len() {
242 return Ok(pos);
243 }
244 }
245 }
246}
247
/// Round-trips a short payload with no file descriptors.
#[test]
fn test_basic() {
    let (tx, rx) = raw_channel().unwrap();

    // `recv` blocks until the message arrives, so no sleep/ordering is needed.
    let sender = std::thread::spawn(move || {
        tx.send(b"Hello World!", &[]).unwrap();
    });

    let (bytes, fds) = rx.recv().unwrap();
    assert_eq!(bytes, b"Hello World!");
    assert_eq!(fds, None);

    sender.join().unwrap();
}
266}
267
/// Round-trips a payload large enough that it may be split across several
/// `sendmsg`/`recvmsg` calls, exercising the partial read/write loops.
#[test]
fn test_large_buffer() {
    // Decimal numbers 0..10000 concatenated (~39 KB).
    let buf: String = (0..10000).map(|x| x.to_string()).collect();

    let (tx, rx) = raw_channel().unwrap();

    let payload = buf.clone();
    // `recv` blocks until the whole message is in, so no sleep is needed.
    let sender = std::thread::spawn(move || {
        tx.send(payload.as_bytes(), &[]).unwrap();
    });

    let (bytes, fds) = rx.recv().unwrap();
    assert_eq!(bytes, buf.as_bytes());
    assert_eq!(fds, None);

    sender.join().unwrap();
}
294}