unix_ipc/
typed_channel.rs1use std::fmt;
2use std::io;
3use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
4use std::path::Path;
5
6use serde_::de::DeserializeOwned;
7use serde_::Serialize;
8
9use crate::raw_channel::{raw_channel, RawReceiver, RawSender};
10use crate::serde::{deserialize, serialize};
11
12pub struct Receiver<T> {
14 raw_receiver: RawReceiver,
15 _marker: std::marker::PhantomData<T>,
16}
17
18pub struct Sender<T> {
20 raw_sender: RawSender,
21 _marker: std::marker::PhantomData<T>,
22}
23
24impl<T> fmt::Debug for Receiver<T> {
25 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26 f.debug_struct("Receiver")
27 .field("fd", &self.as_raw_fd())
28 .finish()
29 }
30}
31
32impl<T> fmt::Debug for Sender<T> {
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 f.debug_struct("Sender")
35 .field("fd", &self.as_raw_fd())
36 .finish()
37 }
38}
39
/// Generates the file-descriptor plumbing shared by the typed endpoints.
///
/// * `$field`  – name of the struct field that holds the raw endpoint.
/// * `$raw_ty` – type of that field (e.g. `RawReceiver`).
/// * `$ty`     – the typed wrapper the impls are generated for.
macro_rules! fd_impl {
    ($field:ident, $raw_ty:ty, $ty:ty) => {
        // `extract_raw_fd` is crate-private and not every expansion is
        // necessarily called elsewhere in the crate, so the dead-code lint
        // is silenced for this impl.
        #[allow(dead_code)]
        impl<T> $ty {
            /// Delegates to the wrapped raw endpoint's `extract_raw_fd`.
            pub(crate) fn extract_raw_fd(&self) -> RawFd {
                self.$field.extract_raw_fd()
            }
        }

        impl<T: Serialize + DeserializeOwned> From<$raw_ty> for $ty {
            /// Wraps an untyped endpoint; the descriptor is not touched.
            fn from(value: $raw_ty) -> Self {
                Self {
                    $field: value,
                    _marker: std::marker::PhantomData,
                }
            }
        }

        impl<T: Serialize + DeserializeOwned> FromRawFd for $ty {
            /// Builds the typed endpoint directly from a raw descriptor.
            ///
            /// # Safety
            ///
            /// The contract of [`FromRawFd::from_raw_fd`] applies unchanged:
            /// `fd` must be an open descriptor whose ownership is transferred
            /// to the returned value.
            unsafe fn from_raw_fd(fd: RawFd) -> Self {
                Self {
                    // SAFETY: the caller upholds the `FromRawFd` contract for
                    // `fd`, which is forwarded to the raw endpoint unchanged.
                    $field: unsafe { FromRawFd::from_raw_fd(fd) },
                    _marker: std::marker::PhantomData,
                }
            }
        }

        impl<T> IntoRawFd for $ty {
            /// Consumes the wrapper and returns the raw endpoint's descriptor.
            fn into_raw_fd(self) -> RawFd {
                self.$field.into_raw_fd()
            }
        }

        impl<T> AsRawFd for $ty {
            /// Returns the raw endpoint's descriptor without giving up ownership.
            fn as_raw_fd(&self) -> RawFd {
                self.$field.as_raw_fd()
            }
        }
    };
}
80
81fd_impl!(raw_receiver, RawReceiver, Receiver<T>);
82fd_impl!(raw_sender, RawSender, Sender<T>);
83
84pub fn channel<T: Serialize + DeserializeOwned>() -> io::Result<(Sender<T>, Receiver<T>)> {
86 let (sender, receiver) = raw_channel()?;
87 Ok((sender.into(), receiver.into()))
88}
89
90impl<T: Serialize + DeserializeOwned> Receiver<T> {
91 pub fn connect<P: AsRef<Path>>(p: P) -> io::Result<Receiver<T>> {
93 RawReceiver::connect(p).map(Into::into)
94 }
95
96 pub fn into_raw_receiver(self) -> RawReceiver {
98 self.raw_receiver
99 }
100
101 pub fn try_recv(&self) -> io::Result<Option<T>> {
103 let res = self.raw_receiver.try_recv()?;
104 if let Some((buf, fds)) = res {
105 Ok(Some(
106 deserialize::<(T, bool)>(&buf, fds.as_deref().unwrap_or_default()).map(|x| x.0)?,
107 ))
108 } else {
109 Ok(None)
110 }
111 }
112
113 pub fn recv(&self) -> io::Result<T> {
115 let (buf, fds) = self.raw_receiver.recv()?;
116 deserialize::<(T, bool)>(&buf, fds.as_deref().unwrap_or_default()).map(|x| x.0)
117 }
118}
119
120impl<T: Serialize + DeserializeOwned> Sender<T> {
121 pub fn into_raw_sender(self) -> RawSender {
123 self.raw_sender
124 }
125
126 pub fn send(&self, s: T) -> io::Result<()> {
128 let (payload, fds) = serialize((&s, true))?;
131 self.raw_sender.send(&payload, &fds)?;
132 Ok(())
133 }
134}
135
/// Sends a file handle from one thread and reads the file through the
/// received handle in another.
#[test]
fn test_basic() {
    use crate::serde::Handle;
    use std::io::Read;

    let f = Handle::from(std::fs::File::open("src/serde.rs").unwrap());

    let (tx, rx) = channel().unwrap();

    let server = std::thread::spawn(move || {
        tx.send(f).unwrap();
    });

    // No sleep needed: both ends were created by `channel()` before either
    // thread started, and `recv` (unlike `try_recv`) waits for the message.
    let client = std::thread::spawn(move || {
        let f = rx.recv().unwrap();

        let mut out = Vec::new();
        f.into_inner().read_to_end(&mut out).unwrap();
        assert!(out.len() > 100);
    });

    server.join().unwrap();
    client.join().unwrap();
}
162
/// Sends a typed `Sender` across a channel, then uses the transferred sender
/// to ship a file handle back to the original thread.
#[test]
fn test_send_channel() {
    use crate::serde::Handle;
    use std::fs::File;
    use std::io::Read;

    let (tx, rx) = channel().unwrap();
    let (sender, receiver) = channel::<Handle<File>>().unwrap();

    let server = std::thread::spawn(move || {
        tx.send(sender).unwrap();
        let handle = receiver.recv().unwrap();
        let mut file = handle.into_inner();
        let mut out = Vec::new();
        file.read_to_end(&mut out).unwrap();
        assert!(out.len() > 100);
    });

    // No sleep needed: both channels already exist, and `recv` waits for
    // the transferred sender to arrive.
    let client = std::thread::spawn(move || {
        let sender = rx.recv().unwrap();
        sender
            .send(Handle::from(File::open("src/serde.rs").unwrap()))
            .unwrap();
    });

    server.join().unwrap();
    client.join().unwrap();
}
193
/// Transfers four endpoints (two channel pairs) in a single message and
/// checks that each received pair is still wired to its own partner.
#[test]
fn test_multiple_fds() {
    let (tx1, rx1) = channel().unwrap();
    let (tx2, rx2) = channel::<()>().unwrap();
    let (tx3, rx3) = channel::<()>().unwrap();

    let a = std::thread::spawn(move || {
        tx1.send((tx2, rx2, tx3, rx3)).unwrap();
    });

    let b = std::thread::spawn(move || {
        let (tx2, rx2, tx3, rx3) = rx1.recv().unwrap();

        // A message on the first pair must arrive only on the first pair...
        tx2.send(()).unwrap();
        assert!(rx2.try_recv().unwrap().is_some());
        assert!(rx3.try_recv().unwrap().is_none());

        // ...and likewise for the second pair, so swapped fds are detected.
        tx3.send(()).unwrap();
        assert!(rx3.try_recv().unwrap().is_some());
        assert!(rx2.try_recv().unwrap().is_none());
    });

    a.join().unwrap();
    b.join().unwrap();
}
211
/// Unwraps typed endpoints into raw ones and re-wraps them with a different
/// message type.
#[test]
fn test_conversion() {
    let (tx, rx) = channel::<i32>().unwrap();
    let raw_tx = tx.into_raw_sender();
    let raw_rx = rx.into_raw_receiver();
    let tx = Sender::<bool>::from(raw_tx);
    let rx = Receiver::<bool>::from(raw_rx);

    let a = std::thread::spawn(move || {
        tx.send(true).unwrap();
    });

    let b = std::thread::spawn(move || {
        assert!(rx.recv().unwrap(), "expected `true` after re-typing");
    });

    a.join().unwrap();
    b.join().unwrap();
}
231
/// A zero-sized payload (`()`) must still make it through the channel.
#[test]
fn test_zero_sized_type() {
    let (tx, rx) = channel::<()>().unwrap();

    let sending = std::thread::spawn(move || tx.send(()).unwrap());
    let receiving = std::thread::spawn(move || {
        let () = rx.recv().unwrap();
    });

    for worker in [sending, receiving] {
        worker.join().unwrap();
    }
}
247
/// Repeatedly sends a `Sender` through another channel and uses the
/// forwarded copy. The 2000 iterations presumably surface fd leaks or
/// exhaustion.
#[test]
fn test_many_nested() {
    for _ in 0..2000 {
        let (outer_tx, outer_rx) = channel().unwrap();
        let (inner_tx, inner_rx) = channel::<i32>().unwrap();

        outer_tx.send(inner_tx).unwrap();

        let forwarded = outer_rx.recv().unwrap();

        forwarded.send(1).unwrap();

        // The value sent through the forwarded sender must arrive intact.
        assert_eq!(inner_rx.recv().unwrap(), 1);
    }
}
263
/// `try_recv` reports `None` on an empty channel, yields the sent value once
/// it is available, and reports `None` again after it has been consumed.
#[test]
fn test_try_recv() {
    let (tx, rx) = channel::<f32>().unwrap();

    assert!(rx.try_recv().unwrap().is_none());

    tx.send(1_f32).unwrap();

    assert_eq!(rx.try_recv().unwrap(), Some(1_f32));
    assert!(rx.try_recv().unwrap().is_none());
}