unix_ipc/
typed_channel.rs

1use std::fmt;
2use std::io;
3use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
4use std::path::Path;
5
6use serde_::de::DeserializeOwned;
7use serde_::Serialize;
8
9use crate::raw_channel::{raw_channel, RawReceiver, RawSender};
10use crate::serde::{deserialize, serialize};
11
/// The receiving half of a typed channel.
///
/// Wraps a [`RawReceiver`] and decodes every incoming message as a `T`
/// (see [`Receiver::recv`] and [`Receiver::try_recv`]). Obtain one from
/// [`channel`] or [`Receiver::connect`].
pub struct Receiver<T> {
    /// Untyped endpoint that yields the raw bytes and descriptors of each message.
    raw_receiver: RawReceiver,
    /// Records the message type `T` without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
17
/// The sending half of a typed channel.
///
/// Serializes each `T` handed to [`Sender::send`] and writes the payload,
/// together with any collected file descriptors, through a [`RawSender`].
/// Obtain one from [`channel`].
pub struct Sender<T> {
    /// Untyped endpoint that transmits the serialized bytes and descriptors.
    raw_sender: RawSender,
    /// Records the message type `T` without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
23
24impl<T> fmt::Debug for Receiver<T> {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        f.debug_struct("Receiver")
27            .field("fd", &self.as_raw_fd())
28            .finish()
29    }
30}
31
32impl<T> fmt::Debug for Sender<T> {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        f.debug_struct("Sender")
35            .field("fd", &self.as_raw_fd())
36            .finish()
37    }
38}
39
/// Implements file-descriptor plumbing for a typed endpoint.
///
/// Invoked as `fd_impl!(field, RawType, TypedType<T>)`, where `field` is the
/// struct field that holds the `RawType` endpoint. Every generated item simply
/// forwards to that field:
///
/// * the crate-internal `extract_raw_fd` method,
/// * `AsRawFd` and `IntoRawFd`,
/// * `From<RawType>` and `FromRawFd`, which exist only when
///   `T: Serialize + DeserializeOwned`.
macro_rules! fd_impl {
    ($inner:ident, $raw:ty, $typed:ty) => {
        // NOTE(review): presumably silences the lint for whichever endpoint never
        // calls `extract_raw_fd`; confirm against the rest of the crate.
        #[allow(dead_code)]
        impl<T> $typed {
            pub(crate) fn extract_raw_fd(&self) -> RawFd {
                self.$inner.extract_raw_fd()
            }
        }

        impl<T> AsRawFd for $typed {
            fn as_raw_fd(&self) -> RawFd {
                self.$inner.as_raw_fd()
            }
        }

        impl<T> IntoRawFd for $typed {
            fn into_raw_fd(self) -> RawFd {
                self.$inner.into_raw_fd()
            }
        }

        impl<T: Serialize + DeserializeOwned> From<$raw> for $typed {
            fn from(raw: $raw) -> Self {
                Self {
                    _marker: std::marker::PhantomData,
                    $inner: raw,
                }
            }
        }

        impl<T: Serialize + DeserializeOwned> FromRawFd for $typed {
            unsafe fn from_raw_fd(fd: RawFd) -> Self {
                // Ownership of `fd` passes straight to the raw endpoint; the caller
                // is responsible for the `FromRawFd` contract.
                let raw = <$raw as FromRawFd>::from_raw_fd(fd);
                Self::from(raw)
            }
        }
    };
}
80
81fd_impl!(raw_receiver, RawReceiver, Receiver<T>);
82fd_impl!(raw_sender, RawSender, Sender<T>);
83
84/// Creates a typed connected channel.
85pub fn channel<T: Serialize + DeserializeOwned>() -> io::Result<(Sender<T>, Receiver<T>)> {
86    let (sender, receiver) = raw_channel()?;
87    Ok((sender.into(), receiver.into()))
88}
89
90impl<T: Serialize + DeserializeOwned> Receiver<T> {
91    /// Connects a receiver to a named unix socket.
92    pub fn connect<P: AsRef<Path>>(p: P) -> io::Result<Receiver<T>> {
93        RawReceiver::connect(p).map(Into::into)
94    }
95
96    /// Converts the typed receiver into a raw one.
97    pub fn into_raw_receiver(self) -> RawReceiver {
98        self.raw_receiver
99    }
100
101    /// Receives a structured message from the socket if there is a message available.
102    pub fn try_recv(&self) -> io::Result<Option<T>> {
103        let res = self.raw_receiver.try_recv()?;
104        if let Some((buf, fds)) = res {
105            Ok(Some(
106                deserialize::<(T, bool)>(&buf, fds.as_deref().unwrap_or_default()).map(|x| x.0)?,
107            ))
108        } else {
109            Ok(None)
110        }
111    }
112
113    /// Receives a structured message from the socket.
114    pub fn recv(&self) -> io::Result<T> {
115        let (buf, fds) = self.raw_receiver.recv()?;
116        deserialize::<(T, bool)>(&buf, fds.as_deref().unwrap_or_default()).map(|x| x.0)
117    }
118}
119
120impl<T: Serialize + DeserializeOwned> Sender<T> {
121    /// Converts the typed sender into a raw one.
122    pub fn into_raw_sender(self) -> RawSender {
123        self.raw_sender
124    }
125
126    /// Receives a structured message from the socket.
127    pub fn send(&self, s: T) -> io::Result<()> {
128        // we always serialize a dummy bool at the end so that the message
129        // will not be empty because of zero sized types.
130        let (payload, fds) = serialize((&s, true))?;
131        self.raw_sender.send(&payload, &fds)?;
132        Ok(())
133    }
134}
135
/// Sends a file handle through a typed channel and reads the file on the
/// receiving side.
///
/// The sending thread is joined before the receiving thread starts. As a result,
/// the message and its descriptor are already buffered, and the sender has been
/// dropped, by the time `recv` runs.
#[test]
fn test_basic() {
    use crate::serde::Handle;
    use std::io::Read;

    let f = Handle::from(std::fs::File::open("src/serde.rs").unwrap());

    let (tx, rx) = channel().unwrap();

    let server = std::thread::spawn(move || {
        tx.send(f).unwrap();
    });
    // Deterministically wait for the send (and the drop of `tx`) instead of sleeping.
    server.join().unwrap();

    let client = std::thread::spawn(move || {
        let f = rx.recv().unwrap();

        let mut out = Vec::new();
        f.into_inner().read_to_end(&mut out).unwrap();
        assert!(out.len() > 100);
    });

    client.join().unwrap();
}
162
/// Sends a `Sender` over one channel. The receiving thread then uses it to
/// pass a file handle back to the original owner of the matching `Receiver`.
#[test]
fn test_send_channel() {
    use crate::serde::Handle;
    use std::fs::File;
    use std::io::Read;

    let (tx, rx) = channel().unwrap();
    let (sender, receiver) = channel::<Handle<File>>().unwrap();

    let server = std::thread::spawn(move || {
        tx.send(sender).unwrap();
        let handle = receiver.recv().unwrap();
        let mut file = handle.into_inner();
        let mut out = Vec::new();
        file.read_to_end(&mut out).unwrap();
        assert!(out.len() > 100);
    });

    // No sleep needed: `rx.recv()` below waits until the server has sent.
    let client = std::thread::spawn(move || {
        let sender = rx.recv().unwrap();
        sender
            .send(Handle::from(File::open("src/serde.rs").unwrap()))
            .unwrap();
    });

    server.join().unwrap();
    client.join().unwrap();
}
193
/// Sends several endpoints (and so several descriptors) in one message. It then
/// checks that every received pair is still connected to its own counterpart.
#[test]
fn test_multiple_fds() {
    let (tx1, rx1) = channel().unwrap();
    let (tx2, rx2) = channel::<()>().unwrap();
    let (tx3, rx3) = channel::<()>().unwrap();

    let a = std::thread::spawn(move || {
        tx1.send((tx2, rx2, tx3, rx3)).unwrap();
    });

    let b = std::thread::spawn(move || {
        let (tx2, rx2, tx3, rx3) = rx1.recv().unwrap();
        // A descriptor mix-up would leave the wrong receiver empty; use
        // `try_recv` so that shows up as a failure rather than a hang.
        tx2.send(()).unwrap();
        assert!(rx2.try_recv().unwrap().is_some());
        tx3.send(()).unwrap();
        assert!(rx3.try_recv().unwrap().is_some());
    });

    a.join().unwrap();
    b.join().unwrap();
}
211
/// Unwraps a typed channel into raw endpoints and re-wraps them with a
/// different message type.
#[test]
fn test_conversion() {
    let (tx, rx) = channel::<i32>().unwrap();
    let raw_tx = tx.into_raw_sender();
    let raw_rx = rx.into_raw_receiver();
    let tx = Sender::<bool>::from(raw_tx);
    let rx = Receiver::<bool>::from(raw_rx);

    let a = std::thread::spawn(move || {
        tx.send(true).unwrap();
    });

    let b = std::thread::spawn(move || {
        assert!(rx.recv().unwrap());
    });

    a.join().unwrap();
    b.join().unwrap();
}
231
/// Round-trips the unit type. Its serialized form relies on the trailing
/// marker that `Sender::send` appends to avoid an empty payload.
#[test]
fn test_zero_sized_type() {
    let (unit_tx, unit_rx) = channel::<()>().unwrap();

    let producer = std::thread::spawn(move || unit_tx.send(()).unwrap());
    let consumer = std::thread::spawn(move || unit_rx.recv().unwrap());

    producer.join().unwrap();
    consumer.join().unwrap();
}
247
/// Repeatedly sends a sender through another channel and uses it.
/// Presumably meant to surface descriptor leaks over many iterations.
#[test]
fn test_many_nested() {
    for _ in 0..2000 {
        let (tx, rx) = channel().unwrap();
        let (tx2, rx2) = channel().unwrap();

        tx.send(tx2).unwrap();

        let recv = rx.recv().unwrap();

        recv.send(1).unwrap();

        // The value must arrive intact through the transferred sender.
        assert_eq!(rx2.recv().unwrap(), 1);
    }
}
263
/// Checks that `try_recv` reports `None` on an empty channel and returns a
/// pending message exactly once.
#[test]
fn test_try_recv() {
    let (tx, rx) = channel().unwrap();

    assert!(rx.try_recv().unwrap().is_none());

    tx.send(1_f32).unwrap();

    // The message comes back intact, and the channel is empty again afterwards.
    assert_eq!(rx.try_recv().unwrap(), Some(1_f32));
    assert!(rx.try_recv().unwrap().is_none());
}