rustbus/wire/wrapper_types/
unixfd.rs1use crate::wire::errors::MarshalError;
2use crate::wire::errors::UnmarshalError;
3use crate::wire::marshal::traits::SignatureBuffer;
4use crate::wire::marshal::MarshalContext;
5use crate::wire::unmarshal::UnmarshalContext;
6use crate::{Marshal, Signature, Unmarshal};
7
8use std::os::unix::io::RawFd;
9use std::sync::atomic::AtomicI32;
10use std::sync::Arc;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13pub enum DupError {
14 Nix(nix::Error),
15 AlreadyTaken,
16}
17
18#[derive(Debug)]
19struct UnixFdInner {
20 inner: AtomicI32,
21}
22impl Drop for UnixFdInner {
23 fn drop(&mut self) {
24 if let Some(fd) = self.take() {
25 nix::unistd::close(fd).ok();
26 }
27 }
28}
29
30impl UnixFdInner {
31 const FD_INVALID: RawFd = -1;
34
35 fn take(&self) -> Option<RawFd> {
37 let loaded_fd: RawFd = self.inner.load(std::sync::atomic::Ordering::SeqCst);
39 if loaded_fd == Self::FD_INVALID {
40 None
41 } else {
42 let swapped_fd = self.inner.compare_exchange(
44 loaded_fd,
45 Self::FD_INVALID,
46 std::sync::atomic::Ordering::SeqCst,
47 std::sync::atomic::Ordering::SeqCst,
48 );
49 if let Ok(taken_fd) = swapped_fd {
51 Some(taken_fd)
52 } else {
53 None
54 }
55 }
56 }
57
58 fn get(&self) -> Option<RawFd> {
60 let loaded = self.inner.load(std::sync::atomic::Ordering::SeqCst);
61 if loaded == Self::FD_INVALID {
62 None
63 } else {
64 Some(loaded as RawFd)
65 }
66 }
67
68 fn dup(&self) -> Result<Self, DupError> {
70 let fd = match self.get() {
71 Some(fd) => fd,
72 None => return Err(DupError::AlreadyTaken),
73 };
74 match nix::unistd::dup(fd) {
75 Ok(new_fd) => Ok(Self {
76 inner: AtomicI32::new(new_fd),
77 }),
78 Err(e) => Err(DupError::Nix(e)),
79 }
80 }
81}
82
83#[derive(Clone, Debug)]
95pub struct UnixFd(Arc<UnixFdInner>);
96impl UnixFd {
97 pub fn new(fd: RawFd) -> Self {
98 UnixFd(Arc::new(UnixFdInner {
99 inner: AtomicI32::new(fd),
100 }))
101 }
102 pub fn get_raw_fd(&self) -> Option<RawFd> {
106 self.0.get()
107 }
108
109 pub fn take_raw_fd(self) -> Option<RawFd> {
114 self.0.take()
115 }
116
117 pub fn dup(&self) -> Result<Self, DupError> {
120 self.0.dup().map(|new_inner| Self(Arc::new(new_inner)))
121 }
122}
123impl PartialEq<UnixFd> for UnixFd {
126 fn eq(&self, other: &UnixFd) -> bool {
127 Arc::ptr_eq(&self.0, &other.0) || self.get_raw_fd() == other.get_raw_fd()
128 }
129}
130
131impl Eq for UnixFd {}
135impl std::hash::Hash for UnixFd {
136 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
137 state.write_i32(self.get_raw_fd().unwrap_or(0));
138 }
139}
140
141impl Signature for UnixFd {
142 fn signature() -> crate::signature::Type {
143 crate::signature::Type::Base(crate::signature::Base::UnixFd)
144 }
145 fn alignment() -> usize {
146 Self::signature().get_alignment()
147 }
148 #[inline]
149 fn sig_str(s_buf: &mut SignatureBuffer) {
150 s_buf.push_static("h");
151 }
152 fn has_sig(sig: &str) -> bool {
153 sig.starts_with('h')
154 }
155}
156impl Marshal for UnixFd {
157 fn marshal(&self, ctx: &mut MarshalContext) -> Result<(), MarshalError> {
158 crate::wire::util::marshal_unixfd(self, ctx)
159 }
160}
161impl Signature for &dyn std::os::unix::io::AsRawFd {
162 fn signature() -> crate::signature::Type {
163 UnixFd::signature()
164 }
165 fn alignment() -> usize {
166 UnixFd::alignment()
167 }
168 #[inline]
169 fn sig_str(s_buf: &mut SignatureBuffer) {
170 UnixFd::sig_str(s_buf)
171 }
172 fn has_sig(sig: &str) -> bool {
173 UnixFd::has_sig(sig)
174 }
175}
176impl Marshal for &dyn std::os::unix::io::AsRawFd {
177 fn marshal(&self, ctx: &mut MarshalContext) -> Result<(), MarshalError> {
178 let fd = self.as_raw_fd();
179 let new_fd = nix::unistd::dup(fd).map_err(MarshalError::DupUnixFd)?;
180 ctx.fds.push(UnixFd::new(new_fd));
181
182 let idx = ctx.fds.len() - 1;
183 ctx.align_to(Self::alignment());
184 crate::wire::util::write_u32(idx as u32, ctx.byteorder, ctx.buf);
185 Ok(())
186 }
187}
188
189impl<'buf, 'fds> Unmarshal<'buf, 'fds> for UnixFd {
190 fn unmarshal(
191 ctx: &mut UnmarshalContext<'fds, 'buf>,
192 ) -> crate::wire::unmarshal::UnmarshalResult<Self> {
193 let (bytes, idx) = u32::unmarshal(ctx)?;
194
195 if ctx.fds.len() <= idx as usize {
196 Err(UnmarshalError::BadFdIndex(idx as usize))
197 } else {
198 let val = &ctx.fds[idx as usize];
199 Ok((bytes, val.clone()))
200 }
201 }
202}
203
/// Checks that `UnixFd`, bare and wrapped in `params::Base`, can be moved to another thread.
/// The threads are joined so they actually run, and any panic inside them fails the test.
#[test]
fn test_fd_send() {
    let x = UnixFd::new(nix::unistd::dup(1).unwrap());
    let handle = std::thread::spawn(move || {
        let _x = x.get_raw_fd();
    });
    handle.join().unwrap();

    let x = UnixFd::new(nix::unistd::dup(1).unwrap());
    let fd = crate::params::Base::UnixFd(x);
    let handle = std::thread::spawn(move || {
        let _x = fd;
    });
    handle.join().unwrap();
}
217
/// Checks get/take semantics: `get_raw_fd` is non-consuming, and taking through a clone
/// invalidates the descriptor for every clone.
#[test]
fn test_unix_fd() {
    let fd = UnixFd::new(nix::unistd::dup(1).unwrap());
    let raw = fd.get_raw_fd().unwrap();
    // `get_raw_fd` does not consume the descriptor: repeated calls return the same value.
    assert_eq!(fd.get_raw_fd(), Some(raw));

    // Taking through a clone takes it for every clone sharing the same state.
    let taken = fd.clone().take_raw_fd().unwrap();
    assert_eq!(taken, raw);
    assert!(fd.get_raw_fd().is_none());
    assert!(fd.take_raw_fd().is_none());

    // Ownership moved to us with `take_raw_fd`, so close it instead of leaking it.
    nix::unistd::close(taken).unwrap();
}
227
/// Races `NUM_THREADS` threads calling `take_raw_fd` on clones of one `UnixFd` and checks
/// that exactly one of them wins. This is repeated `NUM_RUNS` times.
#[test]
fn test_races_in_unixfd() {
    const NUM_THREADS: usize = 20;
    const NUM_RUNS: usize = 100;

    for _ in 0..NUM_RUNS {
        // Use a fresh descriptor (and fresh results) per run. A taken `UnixFd` stays taken,
        // so reusing one across runs would leave nothing to race for after the first run.
        let fd = UnixFd::new(nix::unistd::dup(1).unwrap());
        let raw_fd = fd.get_raw_fd().unwrap();
        let barrier = std::sync::Arc::new(std::sync::Barrier::new(NUM_THREADS));

        // Collect eagerly so every thread is spawned before any is joined; otherwise the
        // barrier would never fill up.
        let handles: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let local_fd = fd.clone();
                let local_barrier = barrier.clone();
                std::thread::spawn(move || {
                    local_barrier.wait();
                    local_fd.take_raw_fd()
                })
            })
            .collect();

        // Joining (rather than a closing barrier) propagates panics from the racing threads
        // instead of deadlocking on a barrier that a panicked thread never reaches.
        let winners: Vec<_> = handles
            .into_iter()
            .filter_map(|handle| handle.join().expect("racing thread panicked"))
            .collect();
        assert_eq!(winners, vec![raw_fd]);

        // The winner took ownership; close it so each run does not leak a descriptor.
        nix::unistd::close(raw_fd).unwrap();
    }
}
265
/// Checks that `dup` yields an independent descriptor and fails once the original is taken.
#[test]
fn test_unixfd_dup() {
    let fd = UnixFd::new(nix::unistd::dup(1).unwrap());
    let fd2 = fd.dup().unwrap();
    assert_ne!(fd.get_raw_fd().unwrap(), fd2.get_raw_fd().unwrap());

    let raw = fd.clone().take_raw_fd().unwrap();
    assert_eq!(fd.dup(), Err(DupError::AlreadyTaken));
    // The duplicate has its own state: taking the original does not affect it.
    assert!(fd2.get_raw_fd().is_some());

    // We own the taken descriptor now; close it rather than leak it.
    nix::unistd::close(raw).unwrap();
}