1use std::cell::RefCell;
2use std::io;
3use std::mem;
4use std::os::unix::io::{FromRawFd, IntoRawFd, RawFd};
5use std::sync::Mutex;
6
7use serde_::{de, ser};
8use serde_::{de::DeserializeOwned, Deserialize, Serialize};
9
thread_local! {
    /// Per-thread stack of fd tables, one frame per active `enter_ipc_mode`
    /// call. `register_fd` appends to and `lookup_fd` reads from the innermost
    /// (last) frame; an empty stack means "not in IPC mode".
    static IPC_FDS: RefCell<Vec<Vec<RawFd>>> = RefCell::default();
}
13
/// Wraps a value convertible to and from a raw fd (`FromRawFd + IntoRawFd`,
/// as required by its impls) so it can be carried through this module's
/// IPC-mode `serialize`/`deserialize`.
///
/// The value lives in a `Mutex<Option<F>>` because serialization moves it out
/// through `&self`. Once a handle has been serialized its slot is empty:
/// serializing it again panics, and `into_inner` panics with "handle was moved".
pub struct Handle<F>(Mutex<Option<F>>);
26
/// A bare raw fd that serializes the same way as a [`Handle`]: as an index
/// into the current thread's IPC fd table.
///
/// It has no destructor and does not close the descriptor. It is the building
/// block used by `Handle` and by the `implement_*_handle_serialization!` macros.
/// Being a plain fd number, it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HandleRef(pub RawFd);
32
33impl<F: FromRawFd + IntoRawFd> Handle<F> {
34 pub fn new(f: F) -> Self {
36 f.into()
37 }
38
39 fn extract_raw_fd(&self) -> RawFd {
40 self.0
41 .lock()
42 .unwrap()
43 .take()
44 .map(|x| x.into_raw_fd())
45 .expect("cannot serialize handle twice")
46 }
47
48 pub fn into_inner(self) -> F {
50 self.0.lock().unwrap().take().expect("handle was moved")
51 }
52}
53
54impl<F: FromRawFd + IntoRawFd> From<F> for Handle<F> {
55 fn from(f: F) -> Self {
56 Handle(Mutex::new(Some(f)))
57 }
58}
59
60impl Serialize for HandleRef {
61 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
62 where
63 S: ser::Serializer,
64 {
65 if serde_in_ipc_mode() {
66 let fd = self.0;
67 let idx = register_fd(fd);
68 idx.serialize(serializer)
69 } else {
70 Err(ser::Error::custom("can only serialize in ipc mode"))
71 }
72 }
73}
74
75impl<F: FromRawFd + IntoRawFd> Serialize for Handle<F> {
76 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: ser::Serializer,
79 {
80 HandleRef(self.extract_raw_fd()).serialize(serializer)
81 }
82}
83
84impl<'de, F: FromRawFd + IntoRawFd> Deserialize<'de> for Handle<F> {
85 fn deserialize<D>(deserializer: D) -> Result<Handle<F>, D::Error>
86 where
87 D: de::Deserializer<'de>,
88 {
89 if serde_in_ipc_mode() {
90 let idx = u32::deserialize(deserializer)?;
91 let fd = lookup_fd(idx).ok_or_else(|| de::Error::custom("fd not found in mapping"))?;
92 unsafe { Ok(Handle(Mutex::new(Some(FromRawFd::from_raw_fd(fd))))) }
93 } else {
94 Err(de::Error::custom("can only deserialize in ipc mode"))
95 }
96 }
97}
98
99struct ResetIpcSerde;
100
101impl Drop for ResetIpcSerde {
102 fn drop(&mut self) {
103 IPC_FDS.with(|x| x.borrow_mut().pop());
104 }
105}
106
107fn enter_ipc_mode<F: FnOnce() -> R, R>(f: F, fds: &mut Vec<RawFd>) -> R {
108 IPC_FDS.with(|x| x.borrow_mut().push(fds.clone()));
109 let reset = ResetIpcSerde;
110 let rv = f();
111 *fds = IPC_FDS.with(|x| x.borrow_mut().pop()).unwrap_or_default();
112 mem::forget(reset);
113 rv
114}
115
116fn register_fd(fd: RawFd) -> u32 {
117 IPC_FDS.with(|x| {
118 let mut x = x.borrow_mut();
119 let fds = x.last_mut().unwrap();
120 let rv = fds.len() as u32;
121 fds.push(fd);
122 rv
123 })
124}
125
126fn lookup_fd(idx: u32) -> Option<RawFd> {
127 IPC_FDS.with(|x| x.borrow().last().and_then(|l| l.get(idx as usize).copied()))
128}
129
130pub fn serde_in_ipc_mode() -> bool {
135 IPC_FDS.with(|x| !x.borrow().is_empty())
136}
137
138#[allow(clippy::boxed_local)]
139fn bincode_to_io_error(err: bincode::Error) -> io::Error {
140 match *err {
141 bincode::ErrorKind::Io(err) => err,
142 err => io::Error::new(io::ErrorKind::Other, err.to_string()),
143 }
144}
145
146pub fn serialize<S: Serialize>(s: S) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
152 let mut fds = Vec::new();
153 let mut out = Vec::new();
154 enter_ipc_mode(|| bincode::serialize_into(&mut out, &s), &mut fds)
155 .map_err(bincode_to_io_error)?;
156 Ok((out, fds))
157}
158
159pub fn deserialize<D: DeserializeOwned>(bytes: &[u8], fds: &[RawFd]) -> io::Result<D> {
164 let mut fds = fds.to_owned();
165 let result =
166 enter_ipc_mode(|| bincode::deserialize(bytes), &mut fds).map_err(bincode_to_io_error)?;
167 Ok(result)
168}
169
170macro_rules! implement_handle_serialization {
171 ($ty:ty) => {
172 impl $crate::_serde_ref::Serialize for $ty {
173 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
174 where
175 S: $crate::_serde_ref::ser::Serializer,
176 {
177 $crate::_serde_ref::Serialize::serialize(
178 &$crate::HandleRef(self.extract_raw_fd()),
179 serializer,
180 )
181 }
182 }
183 impl<'de> Deserialize<'de> for $ty {
184 fn deserialize<D>(deserializer: D) -> Result<$ty, D::Error>
185 where
186 D: $crate::_serde_ref::de::Deserializer<'de>,
187 {
188 let handle: $crate::Handle<$ty> =
189 $crate::_serde_ref::Deserialize::deserialize(deserializer)?;
190 Ok(handle.into_inner())
191 }
192 }
193 };
194}
195
196implement_handle_serialization!(crate::RawSender);
197implement_handle_serialization!(crate::RawReceiver);
198
199macro_rules! implement_typed_handle_serialization {
200 ($ty:ty) => {
201 impl<T: Serialize + DeserializeOwned> $crate::_serde_ref::Serialize for $ty {
202 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
203 where
204 S: $crate::_serde_ref::ser::Serializer,
205 {
206 $crate::_serde_ref::Serialize::serialize(
207 &$crate::HandleRef(self.extract_raw_fd()),
208 serializer,
209 )
210 }
211 }
212 impl<'de, T: Serialize + DeserializeOwned> Deserialize<'de> for $ty {
213 fn deserialize<D>(deserializer: D) -> Result<$ty, D::Error>
214 where
215 D: $crate::_serde_ref::de::Deserializer<'de>,
216 {
217 let handle: $crate::Handle<$ty> =
218 $crate::_serde_ref::Deserialize::deserialize(deserializer)?;
219 Ok(handle.into_inner())
220 }
221 }
222 };
223}
224
225implement_typed_handle_serialization!(crate::Sender<T>);
226implement_typed_handle_serialization!(crate::Receiver<T>);
227
#[test]
fn test_basic() {
    use std::io::{Read, Write};
    use std::os::unix::net::UnixStream;

    // Use an in-process socket pair rather than a file on disk, so the test
    // does not depend on the working directory or on this file's path.
    let (sender, mut receiver) = UnixStream::pair().unwrap();
    let handle = Handle::from(sender);
    let (bytes, fds) = serialize(handle).unwrap();
    assert_eq!(fds.len(), 1);

    let restored: Handle<UnixStream> = deserialize(&bytes, &fds).unwrap();
    restored.into_inner().write_all(b"hello").unwrap();

    let mut out = [0u8; 5];
    receiver.read_exact(&mut out).unwrap();
    assert_eq!(&out, b"hello");
}