tun_rs/platform/unix/interrupt.rs

1use crate::platform::unix::Fd;
2use std::io;
3use std::io::{IoSlice, IoSliceMut};
4use std::os::fd::AsRawFd;
5use std::sync::Mutex;
6
7impl Fd {
8    pub(crate) fn read_interruptible(
9        &self,
10        buf: &mut [u8],
11        event: &InterruptEvent,
12        timeout: Option<std::time::Duration>,
13    ) -> io::Result<usize> {
14        loop {
15            self.wait_readable_interruptible(event, timeout)?;
16            return match self.read(buf) {
17                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
18                    continue;
19                }
20                rs => rs,
21            };
22        }
23    }
24    pub(crate) fn readv_interruptible(
25        &self,
26        bufs: &mut [IoSliceMut<'_>],
27        event: &InterruptEvent,
28        timeout: Option<std::time::Duration>,
29    ) -> io::Result<usize> {
30        loop {
31            self.wait_readable_interruptible(event, timeout)?;
32            return match self.readv(bufs) {
33                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
34                    continue;
35                }
36
37                rs => rs,
38            };
39        }
40    }
41    pub(crate) fn write_interruptible(
42        &self,
43        buf: &[u8],
44        event: &InterruptEvent,
45    ) -> io::Result<usize> {
46        loop {
47            self.wait_writable_interruptible(event)?;
48            return match self.write(buf) {
49                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
50                    continue;
51                }
52                rs => rs,
53            };
54        }
55    }
56    pub fn writev_interruptible(
57        &self,
58        bufs: &[IoSlice<'_>],
59        event: &InterruptEvent,
60    ) -> io::Result<usize> {
61        loop {
62            self.wait_writable_interruptible(event)?;
63            return match self.writev(bufs) {
64                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
65                    continue;
66                }
67                rs => rs,
68            };
69        }
70    }
71    pub fn wait_readable_interruptible(
72        &self,
73        interrupted_event: &InterruptEvent,
74        timeout: Option<std::time::Duration>,
75    ) -> io::Result<()> {
76        let fd = self.as_raw_fd() as libc::c_int;
77        let event_fd = interrupted_event.as_event_fd();
78
79        let mut fds = [
80            libc::pollfd {
81                fd,
82                events: libc::POLLIN,
83                revents: 0,
84            },
85            libc::pollfd {
86                fd: event_fd,
87                events: libc::POLLIN,
88                revents: 0,
89            },
90        ];
91
92        let result = unsafe {
93            libc::poll(
94                fds.as_mut_ptr(),
95                fds.len() as libc::nfds_t,
96                timeout
97                    .map(|t| t.as_millis().min(i32::MAX as _) as _)
98                    .unwrap_or(-1),
99            )
100        };
101
102        if result == -1 {
103            return Err(io::Error::last_os_error());
104        }
105        if result == 0 {
106            return Err(io::Error::from(io::ErrorKind::TimedOut));
107        }
108        if fds[0].revents & libc::POLLIN != 0 {
109            return Ok(());
110        }
111
112        if fds[1].revents & libc::POLLIN != 0 {
113            return Err(io::Error::new(
114                io::ErrorKind::Interrupted,
115                "trigger interrupt",
116            ));
117        }
118
119        Err(io::Error::other("fd error"))
120    }
121    pub fn wait_writable_interruptible(
122        &self,
123        interrupted_event: &InterruptEvent,
124    ) -> io::Result<()> {
125        let fd = self.as_raw_fd() as libc::c_int;
126        let event_fd = interrupted_event.as_event_fd();
127
128        let mut fds = [
129            libc::pollfd {
130                fd,
131                events: libc::POLLOUT,
132                revents: 0,
133            },
134            libc::pollfd {
135                fd: event_fd,
136                events: libc::POLLIN,
137                revents: 0,
138            },
139        ];
140
141        let result = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, -1) };
142
143        if result == -1 {
144            return Err(io::Error::last_os_error());
145        }
146        if fds[0].revents & libc::POLLOUT != 0 {
147            return Ok(());
148        }
149
150        if fds[1].revents & libc::POLLIN != 0 {
151            return Err(io::Error::new(
152                io::ErrorKind::Interrupted,
153                "trigger interrupt",
154            ));
155        }
156
157        Err(io::Error::other("fd error"))
158    }
159}
160
#[cfg(target_os = "macos")]
impl Fd {
    /// Blocks until this descriptor is writable, `interrupt_event` (if given)
    /// is triggered, or `timeout` (if given) elapses. Implemented with `select(2)`.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if a descriptor cannot be stored in an `fd_set`
    ///   (negative, or `>= FD_SETSIZE`).
    /// - `Interrupted` if `interrupt_event` is triggered, or if `select` fails
    ///   with `EINTR`.
    /// - `TimedOut` if `timeout` elapses first.
    /// - Any other error reported by `select`.
    pub fn wait_writable(
        &self,
        interrupt_event: Option<&InterruptEvent>,
        timeout: Option<std::time::Duration>,
    ) -> io::Result<()> {
        let fd = self.as_raw_fd();
        Self::ensure_selectable(fd)?;
        // SAFETY: `fd_set` is a plain bit array; all-zero is the empty set.
        let readfds: libc::fd_set = unsafe { std::mem::zeroed() };
        let mut writefds: libc::fd_set = unsafe { std::mem::zeroed() };
        // SAFETY: `fd` was range-checked above, so it addresses a bit inside `writefds`.
        unsafe { libc::FD_SET(fd, &mut writefds) };
        self.wait(readfds, Some(writefds), interrupt_event, timeout)
    }

    /// Blocks until this descriptor is readable, `interrupt_event` (if given)
    /// is triggered, or `timeout` (if given) elapses. Implemented with `select(2)`.
    ///
    /// # Errors
    ///
    /// Same as [`wait_writable`](Self::wait_writable).
    pub fn wait_readable(
        &self,
        interrupt_event: Option<&InterruptEvent>,
        timeout: Option<std::time::Duration>,
    ) -> io::Result<()> {
        let fd = self.as_raw_fd();
        Self::ensure_selectable(fd)?;
        // SAFETY: `fd_set` is a plain bit array; all-zero is the empty set.
        let mut readfds: libc::fd_set = unsafe { std::mem::zeroed() };
        // SAFETY: `fd` was range-checked above, so it addresses a bit inside `readfds`.
        unsafe { libc::FD_SET(fd, &mut readfds) };
        self.wait(readfds, None, interrupt_event, timeout)
    }

    /// Shared `select(2)` driver for [`wait_readable`](Self::wait_readable) and
    /// [`wait_writable`](Self::wait_writable).
    ///
    /// The caller must already have placed `self` into `readfds`/`writefds`.
    /// This function adds the interrupt event's descriptor to `readfds`. Unlike
    /// the `poll`-based waiters, a triggered interrupt takes priority here even
    /// if `self` is also ready.
    fn wait(
        &self,
        mut readfds: libc::fd_set,
        mut writefds: Option<libc::fd_set>,
        interrupt_event: Option<&InterruptEvent>,
        timeout: Option<std::time::Duration>,
    ) -> io::Result<()> {
        let mut nfds = self.as_raw_fd();

        let event_fd = match interrupt_event {
            Some(event) => {
                let event_fd = event.as_event_fd();
                Self::ensure_selectable(event_fd)?;
                // SAFETY: `event_fd` was range-checked above.
                unsafe { libc::FD_SET(event_fd, &mut readfds) };
                nfds = nfds.max(event_fd);
                Some(event_fd)
            }
            None => None,
        };

        // SAFETY: `fd_set` is a plain bit array; all-zero is the empty set.
        let mut errorfds: libc::fd_set = unsafe { std::mem::zeroed() };

        let mut tv = libc::timeval {
            tv_sec: 0,
            tv_usec: 0,
        };
        // A null timeval pointer makes `select` wait indefinitely.
        let tv_ptr = match timeout {
            Some(timeout) => {
                tv.tv_sec = timeout.as_secs().min(libc::time_t::MAX as u64) as libc::time_t;
                tv.tv_usec = timeout.subsec_micros() as libc::suseconds_t;
                &mut tv as *mut libc::timeval
            }
            None => std::ptr::null_mut(),
        };
        let writefds_ptr = writefds
            .as_mut()
            .map_or(std::ptr::null_mut(), |set| set as *mut libc::fd_set);

        // SAFETY: every pointer is null or refers to a local that outlives the
        // call, and `nfds + 1` covers the highest descriptor placed in any set.
        let result = unsafe {
            libc::select(
                nfds + 1,
                &mut readfds,
                writefds_ptr,
                &mut errorfds,
                tv_ptr,
            )
        };
        if result < 0 {
            return Err(io::Error::last_os_error());
        }
        if result == 0 {
            return Err(io::Error::from(io::ErrorKind::TimedOut));
        }
        if let Some(event_fd) = event_fd {
            // SAFETY: `event_fd` was range-checked before it was added to `readfds`.
            if unsafe { libc::FD_ISSET(event_fd, &readfds) } {
                return Err(io::Error::new(
                    io::ErrorKind::Interrupted,
                    "trigger interrupt",
                ));
            }
        }
        Ok(())
    }

    /// Rejects descriptors that cannot be stored in an `fd_set`.
    ///
    /// `FD_SET`/`FD_ISSET` address a fixed-size bit array of `FD_SETSIZE` bits.
    /// A negative descriptor, or one at or above `FD_SETSIZE`, would fall
    /// outside that array.
    fn ensure_selectable(fd: libc::c_int) -> io::Result<()> {
        if fd < 0 || (fd as usize) >= (libc::FD_SETSIZE as usize) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "file descriptor out of range for select",
            ));
        }
        Ok(())
    }
}
250pub struct InterruptEvent {
251    state: Mutex<i32>,
252    read_fd: Fd,
253    write_fd: Fd,
254}
255impl InterruptEvent {
256    pub fn new() -> io::Result<Self> {
257        let mut fds: [libc::c_int; 2] = [0; 2];
258
259        unsafe {
260            if libc::pipe(fds.as_mut_ptr()) == -1 {
261                return Err(io::Error::last_os_error());
262            }
263            let read_fd = Fd::new_unchecked(fds[0]);
264            let write_fd = Fd::new_unchecked(fds[1]);
265            write_fd.set_nonblocking(true)?;
266            read_fd.set_nonblocking(true)?;
267            Ok(Self {
268                state: Mutex::new(0),
269                read_fd,
270                write_fd,
271            })
272        }
273    }
274    pub fn trigger(&self) -> io::Result<()> {
275        self.trigger_value(1)
276    }
277    pub fn trigger_value(&self, val: i32) -> io::Result<()> {
278        if val == 0 {
279            return Err(io::Error::new(
280                io::ErrorKind::InvalidInput,
281                "value cannot be 0",
282            ));
283        }
284        let mut guard = self.state.lock().unwrap();
285        if *guard != 0 {
286            return Ok(());
287        }
288        *guard = val;
289        let buf: [u8; 8] = 1u64.to_ne_bytes();
290        let res = unsafe {
291            libc::write(
292                self.write_fd.as_raw_fd(),
293                buf.as_ptr() as *const _,
294                buf.len(),
295            )
296        };
297        if res == -1 {
298            let e = io::Error::last_os_error();
299            if e.kind() == io::ErrorKind::WouldBlock {
300                return Ok(());
301            }
302            Err(e)
303        } else {
304            Ok(())
305        }
306    }
307    pub fn is_trigger(&self) -> bool {
308        *self.state.lock().unwrap() != 0
309    }
310    pub fn value(&self) -> i32 {
311        *self.state.lock().unwrap()
312    }
313    pub fn reset(&self) -> io::Result<()> {
314        let mut buf = [0; 8];
315        let mut guard = self.state.lock().unwrap();
316        *guard = 0;
317        loop {
318            unsafe {
319                let res = libc::read(
320                    self.read_fd.as_raw_fd(),
321                    buf.as_mut_ptr() as *mut _,
322                    buf.len(),
323                );
324                if res == -1 {
325                    let error = io::Error::last_os_error();
326                    return if error.kind() == io::ErrorKind::WouldBlock {
327                        Ok(())
328                    } else {
329                        Err(error)
330                    };
331                }
332            }
333        }
334    }
335    fn as_event_fd(&self) -> libc::c_int {
336        self.read_fd.as_raw_fd()
337    }
338}