rustbus/wire/wrapper_types/
unixfd.rs

1use crate::wire::errors::MarshalError;
2use crate::wire::errors::UnmarshalError;
3use crate::wire::marshal::traits::SignatureBuffer;
4use crate::wire::marshal::MarshalContext;
5use crate::wire::unmarshal::UnmarshalContext;
6use crate::{Marshal, Signature, Unmarshal};
7
8use std::os::unix::io::RawFd;
9use std::sync::atomic::AtomicI32;
10use std::sync::Arc;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13pub enum DupError {
14    Nix(nix::Error),
15    AlreadyTaken,
16}
17
/// Shared state behind a [`UnixFd`].
///
/// The raw fd lives in an atomic so it can be taken out from any thread; once taken, the
/// atomic holds `-1`.
#[derive(Debug)]
struct UnixFdInner {
    /// The owned fd, or `-1` after it has been taken.
    inner: AtomicI32,
}
22impl Drop for UnixFdInner {
23    fn drop(&mut self) {
24        if let Some(fd) = self.take() {
25            nix::unistd::close(fd).ok();
26        }
27    }
28}
29
30impl UnixFdInner {
31    /// -1 seems like a good 'invalid' state for the atomici32
32    /// -1 is a common return value for operations that return FDs to signal an error occurance.
33    const FD_INVALID: RawFd = -1;
34
35    /// This is kinda like Cell::take it takes the FD and resets the atomic int to FD_INVALID which represents the invalid / taken state here.
36    fn take(&self) -> Option<RawFd> {
37        // load fd and see if it is already been taken
38        let loaded_fd: RawFd = self.inner.load(std::sync::atomic::Ordering::SeqCst);
39        if loaded_fd == Self::FD_INVALID {
40            None
41        } else {
42            //try to swap with FD_INVALID
43            let swapped_fd = self.inner.compare_exchange(
44                loaded_fd,
45                Self::FD_INVALID,
46                std::sync::atomic::Ordering::SeqCst,
47                std::sync::atomic::Ordering::SeqCst,
48            );
49            //  If swapped_fd == fd then we did a sucessful swap and we actually took the value
50            if let Ok(taken_fd) = swapped_fd {
51                Some(taken_fd)
52            } else {
53                None
54            }
55        }
56    }
57
58    /// This is kinda like Cell::get it returns the FD, FD_INVALID represents the invalid / taken state here.
59    fn get(&self) -> Option<RawFd> {
60        let loaded = self.inner.load(std::sync::atomic::Ordering::SeqCst);
61        if loaded == Self::FD_INVALID {
62            None
63        } else {
64            Some(loaded as RawFd)
65        }
66    }
67
68    /// Dup the underlying FD
69    fn dup(&self) -> Result<Self, DupError> {
70        let fd = match self.get() {
71            Some(fd) => fd,
72            None => return Err(DupError::AlreadyTaken),
73        };
74        match nix::unistd::dup(fd) {
75            Ok(new_fd) => Ok(Self {
76                inner: AtomicI32::new(new_fd),
77            }),
78            Err(e) => Err(DupError::Nix(e)),
79        }
80    }
81}
82
83/// UnixFd is a wrapper around RawFd, to ensure that opened FDs are closed again, while still having the possibility of having multiple references to it.
84///
85/// "Ownership" as in responsibility of closing the FD works as follows:
86/// 1. You can call take_raw_fd(). At this point UnixFd releases ownership. You are now responsible of closing the FD.
87/// 1. You can call get_raw_fd(). This will not release ownership, UnixFd will still close it if no more references to it exist.
88///
89/// ## UnixFds and messages
90/// 1. When a UnixFd is **marshalled** rustbus will dup() the FD so that the message and the original UnixFd do not depend on each others lifetime. You are free to use
91/// or close the original one.
92/// 1. When a UnixFd is **unmarshalled** rustbus will **NOT** dup() the FD. This means if you call take_raw_fd(), it is gone from the message too! If you do not want this,
93/// you have to call dup() and then get_raw_fd() or take_raw_fd()
94#[derive(Clone, Debug)]
95pub struct UnixFd(Arc<UnixFdInner>);
96impl UnixFd {
97    pub fn new(fd: RawFd) -> Self {
98        UnixFd(Arc::new(UnixFdInner {
99            inner: AtomicI32::new(fd),
100        }))
101    }
102    /// Gets a non-owning `RawFd`. If `None` is returned.
103    /// then this UnixFd has already been taken by somebody else
104    /// and is no longer valid.
105    pub fn get_raw_fd(&self) -> Option<RawFd> {
106        self.0.get()
107    }
108
109    /// Gets a owning `RawFd` from the UnixFd.
110    /// Subsequent attempt to get the `RawFd` from
111    /// other `UnixFd` referencing the same file descriptor will
112    /// fail.
113    pub fn take_raw_fd(self) -> Option<RawFd> {
114        self.0.take()
115    }
116
117    /// Duplicate the underlying FD so you can use it as you will. This is different from just calling
118    /// clone(). Clone only makes a new ref to the same underlying FD.
119    pub fn dup(&self) -> Result<Self, DupError> {
120        self.0.dup().map(|new_inner| Self(Arc::new(new_inner)))
121    }
122}
123/// Allow for the comparison of `UnixFd` even after the `RawFd`
124/// has been taken, to see if they originally referred to the same thing.
125impl PartialEq<UnixFd> for UnixFd {
126    fn eq(&self, other: &UnixFd) -> bool {
127        Arc::ptr_eq(&self.0, &other.0) || self.get_raw_fd() == other.get_raw_fd()
128    }
129}
130
131// These two impls are just there so that params::Base can derive Eq and Hash so they can be used as Keys
132// in dicts. This does not really make sense for unixfds (why would you use them as keys...) but the
133// contracts for Eq and Hash should be fulfilled by these impls.
134impl Eq for UnixFd {}
135impl std::hash::Hash for UnixFd {
136    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
137        state.write_i32(self.get_raw_fd().unwrap_or(0));
138    }
139}
140
141impl Signature for UnixFd {
142    fn signature() -> crate::signature::Type {
143        crate::signature::Type::Base(crate::signature::Base::UnixFd)
144    }
145    fn alignment() -> usize {
146        Self::signature().get_alignment()
147    }
148    #[inline]
149    fn sig_str(s_buf: &mut SignatureBuffer) {
150        s_buf.push_static("h");
151    }
152    fn has_sig(sig: &str) -> bool {
153        sig.starts_with('h')
154    }
155}
156impl Marshal for UnixFd {
157    fn marshal(&self, ctx: &mut MarshalContext) -> Result<(), MarshalError> {
158        crate::wire::util::marshal_unixfd(self, ctx)
159    }
160}
161impl Signature for &dyn std::os::unix::io::AsRawFd {
162    fn signature() -> crate::signature::Type {
163        UnixFd::signature()
164    }
165    fn alignment() -> usize {
166        UnixFd::alignment()
167    }
168    #[inline]
169    fn sig_str(s_buf: &mut SignatureBuffer) {
170        UnixFd::sig_str(s_buf)
171    }
172    fn has_sig(sig: &str) -> bool {
173        UnixFd::has_sig(sig)
174    }
175}
176impl Marshal for &dyn std::os::unix::io::AsRawFd {
177    fn marshal(&self, ctx: &mut MarshalContext) -> Result<(), MarshalError> {
178        let fd = self.as_raw_fd();
179        let new_fd = nix::unistd::dup(fd).map_err(MarshalError::DupUnixFd)?;
180        ctx.fds.push(UnixFd::new(new_fd));
181
182        let idx = ctx.fds.len() - 1;
183        ctx.align_to(Self::alignment());
184        crate::wire::util::write_u32(idx as u32, ctx.byteorder, ctx.buf);
185        Ok(())
186    }
187}
188
189impl<'buf, 'fds> Unmarshal<'buf, 'fds> for UnixFd {
190    fn unmarshal(
191        ctx: &mut UnmarshalContext<'fds, 'buf>,
192    ) -> crate::wire::unmarshal::UnmarshalResult<Self> {
193        let (bytes, idx) = u32::unmarshal(ctx)?;
194
195        if ctx.fds.len() <= idx as usize {
196            Err(UnmarshalError::BadFdIndex(idx as usize))
197        } else {
198            let val = &ctx.fds[idx as usize];
199            Ok((bytes, val.clone()))
200        }
201    }
202}
203
#[test]
fn test_fd_send() {
    // `thread::spawn` requires its closure to be `Send`, so this only compiles if both a
    // `UnixFd` and a `params::Base` holding one can be moved to another thread.
    let wrapped = UnixFd::new(nix::unistd::dup(1).unwrap());
    std::thread::spawn(move || {
        let _raw = wrapped.get_raw_fd();
    });

    let base = crate::params::Base::UnixFd(UnixFd::new(nix::unistd::dup(1).unwrap()));
    std::thread::spawn(move || {
        let _moved = base;
    });
}
217
#[test]
fn test_unix_fd() {
    let fd = UnixFd::new(nix::unistd::dup(1).unwrap());

    // get_raw_fd does not give up ownership and keeps returning the same fd.
    let first = fd.get_raw_fd().unwrap();
    assert_eq!(fd.get_raw_fd(), Some(first));

    // Taking through a clone empties the shared slot for every handle.
    let taken = fd.clone().take_raw_fd().unwrap();
    assert_eq!(taken, first);
    assert!(fd.get_raw_fd().is_none());
    assert!(fd.take_raw_fd().is_none());

    // The test owns the taken fd now; close it instead of leaking it.
    nix::unistd::close(taken).unwrap();
}
227
#[test]
fn test_races_in_unixfd() {
    const NUM_THREADS: usize = 20;
    const NUM_RUNS: usize = 100;

    for _ in 0..NUM_RUNS {
        // Use a fresh fd each run. Reusing one fd (and one result vector) across runs meant
        // only the first run really raced; later runs passed on the stale result.
        let fd = UnixFd::new(nix::unistd::dup(1).unwrap());
        let raw_fd = fd.get_raw_fd().unwrap();

        // Release all workers at once so their take_raw_fd() calls overlap as much as possible.
        let barrier = std::sync::Arc::new(std::sync::Barrier::new(NUM_THREADS));

        let handles: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let local_fd = fd.clone();
                let local_barrier = barrier.clone();
                std::thread::spawn(move || {
                    local_barrier.wait();
                    local_fd.take_raw_fd()
                })
            })
            .collect();

        // join() waits for every worker and propagates any panic from it.
        let winners: Vec<RawFd> = handles
            .into_iter()
            .filter_map(|handle| handle.join().unwrap())
            .collect();

        // Exactly one thread must have won, and it must have received the original fd.
        assert_eq!(winners, vec![raw_fd]);

        // The winner took ownership; close the fd so the test does not leak descriptors.
        nix::unistd::close(raw_fd).unwrap();
    }
}
265
#[test]
fn test_unixfd_dup() {
    let fd = UnixFd::new(nix::unistd::dup(1).unwrap());
    let fd2 = fd.dup().unwrap();
    // A dup is a different descriptor, not another reference to the same one.
    assert_ne!(fd.get_raw_fd().unwrap(), fd2.get_raw_fd().unwrap());

    // Once the fd has been taken there is nothing left to duplicate.
    let raw = fd.clone().take_raw_fd().unwrap();
    assert_eq!(fd.dup(), Err(DupError::AlreadyTaken));

    // The earlier duplicate is independent and still valid.
    assert!(fd2.get_raw_fd().is_some());

    // The test owns the taken fd; close it instead of leaking it.
    nix::unistd::close(raw).unwrap();
}