1use std::io::{IoSlice, IoSliceMut};
2use std::marker::PhantomData;
3use std::os::fd::AsRawFd;
4use std::os::unix::prelude::RawFd;
5
6use nix::sys::socket::{self, UnixAddr};
7use nix::unistd::{self};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, thiserror::Error)]
11pub enum ChannelError {
12 #[error("failed unix syscalls")]
13 Nix(#[from] nix::Error),
14 #[error("failed serde serialization")]
15 Serde(#[from] serde_json::Error),
16 #[error("channel connection broken")]
17 BrokenChannel,
18}
/// Receiving endpoint of a typed channel carrying values of type `T`.
pub struct Receiver<T> {
    /// Raw file descriptor of the socket that messages are read from.
    receiver: RawFd,
    /// Ties the endpoint to the message type without storing a `T`.
    phantom: PhantomData<T>,
}
23
/// Sending endpoint of a typed channel carrying values of type `T`.
pub struct Sender<T> {
    /// Raw file descriptor of the socket that messages are written to.
    sender: RawFd,
    /// Ties the endpoint to the message type without storing a `T`.
    phantom: PhantomData<T>,
}
28
29impl<T> Sender<T>
30where
31 T: Serialize,
32{
33 fn send_iovec(
34 &mut self,
35 iov: &[IoSlice],
36 fds: Option<&[RawFd]>,
37 ) -> Result<usize, ChannelError> {
38 let cmsgs = if let Some(fds) = fds {
39 vec![socket::ControlMessage::ScmRights(fds)]
40 } else {
41 vec![]
42 };
43 socket::sendmsg::<UnixAddr>(self.sender, iov, &cmsgs, socket::MsgFlags::empty(), None)
44 .map_err(|e| e.into())
45 }
46
47 fn send_slice_with_len(
48 &mut self,
49 data: &[u8],
50 fds: Option<&[RawFd]>,
51 ) -> Result<usize, ChannelError> {
52 let len = data.len() as u64;
53 let iov = [
55 IoSlice::new(unsafe {
56 std::slice::from_raw_parts(
57 (&len as *const u64) as *const u8,
58 std::mem::size_of::<u64>(),
59 )
60 }),
61 IoSlice::new(data),
62 ];
63 self.send_iovec(&iov[..], fds)
64 }
65
66 pub fn send(&mut self, object: T) -> Result<(), ChannelError> {
67 let payload = serde_json::to_vec(&object)?;
68 self.send_slice_with_len(&payload, None)?;
69
70 Ok(())
71 }
72
73 pub fn send_fds(&mut self, object: T, fds: &[RawFd]) -> Result<(), ChannelError> {
74 let payload = serde_json::to_vec(&object)?;
75 self.send_slice_with_len(&payload, Some(fds))?;
76
77 Ok(())
78 }
79
80 pub fn close(&self) -> Result<(), ChannelError> {
81 Ok(unistd::close(self.sender)?)
82 }
83}
84
85impl<T> Receiver<T>
86where
87 T: serde::de::DeserializeOwned,
88{
89 fn peek_size_iovec(&mut self) -> Result<u64, ChannelError> {
90 let mut len: u64 = 0;
91 let mut iov = [IoSliceMut::new(unsafe {
92 std::slice::from_raw_parts_mut(
93 (&mut len as *mut u64) as *mut u8,
94 std::mem::size_of::<u64>(),
95 )
96 })];
97 let _ =
98 socket::recvmsg::<UnixAddr>(self.receiver, &mut iov, None, socket::MsgFlags::MSG_PEEK)?;
99 match len {
100 0 => Err(ChannelError::BrokenChannel),
101 _ => Ok(len),
102 }
103 }
104
105 fn recv_into_iovec<F>(
106 &mut self,
107 iov: &mut [IoSliceMut],
108 ) -> Result<(usize, Option<F>), ChannelError>
109 where
110 F: Default + AsMut<[RawFd]>,
111 {
112 let mut cmsgspace = nix::cmsg_space!(F);
113 let msg = socket::recvmsg::<UnixAddr>(
114 self.receiver,
115 iov,
116 Some(&mut cmsgspace),
117 socket::MsgFlags::MSG_CMSG_CLOEXEC,
118 )?;
119
120 let fds: Option<F> = msg
126 .cmsgs()?
127 .find_map(|cmsg| {
128 if let socket::ControlMessageOwned::ScmRights(fds) = cmsg {
129 Some(fds)
130 } else {
131 None
132 }
133 })
134 .map(|fds| {
135 let mut fds_array: F = Default::default();
136 <F as AsMut<[RawFd]>>::as_mut(&mut fds_array).clone_from_slice(&fds);
137 fds_array
138 });
139
140 Ok((msg.bytes, fds))
141 }
142
143 fn recv_into_buf_with_len<F>(&mut self) -> Result<(Vec<u8>, Option<F>), ChannelError>
144 where
145 F: Default + AsMut<[RawFd]>,
146 {
147 let msg_len = self.peek_size_iovec()?;
148 let mut len: u64 = 0;
149 let mut buf = vec![0u8; msg_len as usize];
150 let (bytes, fds) = {
151 let mut iov = [
152 IoSliceMut::new(unsafe {
153 std::slice::from_raw_parts_mut(
154 (&mut len as *mut u64) as *mut u8,
155 std::mem::size_of::<u64>(),
156 )
157 }),
158 IoSliceMut::new(&mut buf),
159 ];
160 self.recv_into_iovec(&mut iov)?
161 };
162
163 match bytes {
164 0 => Err(ChannelError::BrokenChannel),
165 _ => Ok((buf, fds)),
166 }
167 }
168
169 pub fn recv(&mut self) -> Result<T, ChannelError> {
171 let (buf, _) = self.recv_into_buf_with_len::<[RawFd; 0]>()?;
172 Ok(serde_json::from_slice(&buf[..])?)
173 }
174
175 pub fn recv_with_fds<F>(&mut self) -> Result<(T, Option<F>), ChannelError>
179 where
180 F: Default + AsMut<[RawFd]>,
181 {
182 let (buf, fds) = self.recv_into_buf_with_len::<F>()?;
183 Ok((serde_json::from_slice(&buf[..])?, fds))
184 }
185
186 pub fn close(&self) -> Result<(), ChannelError> {
187 Ok(unistd::close(self.receiver)?)
188 }
189}
190
191pub fn channel<T>() -> Result<(Sender<T>, Receiver<T>), ChannelError>
192where
193 T: for<'de> Deserialize<'de> + Serialize,
194{
195 let (os_sender, os_receiver) = unix_channel()?;
196 let receiver = Receiver {
197 receiver: os_receiver,
198 phantom: PhantomData,
199 };
200 let sender = Sender {
201 sender: os_sender,
202 phantom: PhantomData,
203 };
204 Ok((sender, receiver))
205}
206
207fn unix_channel() -> Result<(RawFd, RawFd), ChannelError> {
209 let (f1, f2) = socket::socketpair(
210 socket::AddressFamily::Unix,
211 socket::SockType::SeqPacket,
212 None,
213 socket::SockFlag::SOCK_CLOEXEC,
214 )?;
215 let f1 = std::mem::ManuallyDrop::new(f1);
218 let f2 = std::mem::ManuallyDrop::new(f2);
219
220 Ok((f1.as_raw_fd(), f2.as_raw_fd()))
221}