// tun_rs/platform/unix/interrupt.rs

1use crate::platform::unix::Fd;
2use std::io;
3use std::io::{IoSlice, IoSliceMut};
4use std::os::fd::AsRawFd;
5use std::sync::Mutex;
6
7impl Fd {
8    pub(crate) fn read_interruptible(
9        &self,
10        buf: &mut [u8],
11        event: &InterruptEvent,
12        timeout: Option<std::time::Duration>,
13    ) -> io::Result<usize> {
14        loop {
15            self.wait_readable_interruptible(event, timeout)?;
16            return match self.read(buf) {
17                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
18                    continue;
19                }
20                rs => rs,
21            };
22        }
23    }
24    pub(crate) fn readv_interruptible(
25        &self,
26        bufs: &mut [IoSliceMut<'_>],
27        event: &InterruptEvent,
28        timeout: Option<std::time::Duration>,
29    ) -> io::Result<usize> {
30        loop {
31            self.wait_readable_interruptible(event, timeout)?;
32            return match self.readv(bufs) {
33                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
34                    continue;
35                }
36
37                rs => rs,
38            };
39        }
40    }
41    pub(crate) fn write_interruptible(
42        &self,
43        buf: &[u8],
44        event: &InterruptEvent,
45    ) -> io::Result<usize> {
46        loop {
47            self.wait_writable_interruptible(event)?;
48            return match self.write(buf) {
49                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
50                    continue;
51                }
52                rs => rs,
53            };
54        }
55    }
56    pub fn writev_interruptible(
57        &self,
58        bufs: &[IoSlice<'_>],
59        event: &InterruptEvent,
60    ) -> io::Result<usize> {
61        loop {
62            self.wait_writable_interruptible(event)?;
63            return match self.writev(bufs) {
64                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
65                    continue;
66                }
67                rs => rs,
68            };
69        }
70    }
71    pub fn wait_readable_interruptible(
72        &self,
73        interrupted_event: &InterruptEvent,
74        timeout: Option<std::time::Duration>,
75    ) -> io::Result<()> {
76        let fd = self.as_raw_fd() as libc::c_int;
77        let event_fd = interrupted_event.as_event_fd();
78
79        let mut fds = [
80            libc::pollfd {
81                fd,
82                events: libc::POLLIN,
83                revents: 0,
84            },
85            libc::pollfd {
86                fd: event_fd,
87                events: libc::POLLIN,
88                revents: 0,
89            },
90        ];
91
92        let result = unsafe {
93            libc::poll(
94                fds.as_mut_ptr(),
95                fds.len() as libc::nfds_t,
96                timeout
97                    .map(|t| t.as_millis().min(i32::MAX as _) as _)
98                    .unwrap_or(-1),
99            )
100        };
101
102        if result == -1 {
103            return Err(io::Error::last_os_error());
104        }
105        if result == 0 {
106            return Err(io::Error::from(io::ErrorKind::TimedOut));
107        }
108        if fds[0].revents & libc::POLLIN != 0 {
109            return Ok(());
110        }
111
112        if fds[1].revents & libc::POLLIN != 0 {
113            return Err(io::Error::new(
114                io::ErrorKind::Interrupted,
115                "trigger interrupt",
116            ));
117        }
118
119        Err(io::Error::other("fd error"))
120    }
121    pub fn wait_writable_interruptible(
122        &self,
123        interrupted_event: &InterruptEvent,
124    ) -> io::Result<()> {
125        let fd = self.as_raw_fd() as libc::c_int;
126        let event_fd = interrupted_event.as_event_fd();
127
128        let mut fds = [
129            libc::pollfd {
130                fd,
131                events: libc::POLLOUT,
132                revents: 0,
133            },
134            libc::pollfd {
135                fd: event_fd,
136                events: libc::POLLIN,
137                revents: 0,
138            },
139        ];
140
141        let result = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, -1) };
142
143        if result == -1 {
144            return Err(io::Error::last_os_error());
145        }
146        if fds[0].revents & libc::POLLOUT != 0 {
147            return Ok(());
148        }
149
150        if fds[1].revents & libc::POLLIN != 0 {
151            return Err(io::Error::new(
152                io::ErrorKind::Interrupted,
153                "trigger interrupt",
154            ));
155        }
156
157        Err(io::Error::other("fd error"))
158    }
159}
160
#[cfg(target_os = "macos")]
impl Fd {
    /// Waits with `select(2)` until this fd is writable, the optional
    /// `interrupt_event` is triggered, or `timeout` elapses (`None` = forever).
    ///
    /// # Errors
    /// `InvalidInput` if the fd does not fit in an `fd_set`, `Interrupted` when
    /// the event fired, `TimedOut` on timeout, or the OS error from `select`.
    pub fn wait_writable(
        &self,
        interrupt_event: Option<&InterruptEvent>,
        timeout: Option<std::time::Duration>,
    ) -> io::Result<()> {
        let fd = self.as_raw_fd();
        check_select_fd(fd)?;
        // SAFETY: `fd_set` is a plain bitmask; all-zero is the empty set.
        let readfds: libc::fd_set = unsafe { std::mem::zeroed() };
        let mut writefds: libc::fd_set = unsafe { std::mem::zeroed() };
        // SAFETY: `fd` was checked against the fd_set capacity above.
        unsafe { libc::FD_SET(fd, &mut writefds) };
        self.wait(readfds, Some(writefds), interrupt_event, timeout)
    }

    /// Waits with `select(2)` until this fd is readable, the optional
    /// `interrupt_event` is triggered, or `timeout` elapses (`None` = forever).
    ///
    /// # Errors
    /// Same as [`Fd::wait_writable`].
    pub fn wait_readable(
        &self,
        interrupt_event: Option<&InterruptEvent>,
        timeout: Option<std::time::Duration>,
    ) -> io::Result<()> {
        let fd = self.as_raw_fd();
        check_select_fd(fd)?;
        // SAFETY: `fd_set` is a plain bitmask; all-zero is the empty set.
        let mut readfds: libc::fd_set = unsafe { std::mem::zeroed() };
        // SAFETY: `fd` was checked against the fd_set capacity above.
        unsafe { libc::FD_SET(fd, &mut readfds) };
        self.wait(readfds, None, interrupt_event, timeout)
    }

    /// Runs `select(2)` on the prepared sets, also watching the interrupt pipe
    /// for readability when `interrupt_event` is given.
    ///
    /// A triggered interrupt takes priority over fd readiness when both are
    /// reported together. `EINTR` is retried with the remaining timeout so a
    /// signal is not mistaken for a triggered event.
    fn wait(
        &self,
        readfds: libc::fd_set,
        writefds: Option<libc::fd_set>,
        interrupt_event: Option<&InterruptEvent>,
        timeout: Option<std::time::Duration>,
    ) -> io::Result<()> {
        let mut readfds = readfds;
        let mut nfds = self.as_raw_fd();
        let event_fd = interrupt_event.map(InterruptEvent::as_event_fd);

        if let Some(event_fd) = event_fd {
            check_select_fd(event_fd)?;
            // SAFETY: `event_fd` was checked against the fd_set capacity above.
            unsafe { libc::FD_SET(event_fd, &mut readfds) };
            nfds = nfds.max(event_fd);
        }

        let start = std::time::Instant::now();
        let mut ready_read = loop {
            // select(2) rewrites the sets on success and leaves them
            // unspecified on error, so each attempt starts from fresh copies.
            let mut ready_read = readfds;
            let mut ready_write = writefds;
            // SAFETY: `fd_set` is a plain bitmask; all-zero is the empty set.
            let mut errorfds: libc::fd_set = unsafe { std::mem::zeroed() };
            let mut tv = libc::timeval {
                tv_sec: 0,
                tv_usec: 0,
            };
            let tv_ptr = match timeout {
                Some(total) => {
                    let remaining = total.saturating_sub(start.elapsed());
                    tv.tv_sec =
                        remaining.as_secs().min(libc::time_t::MAX as u64) as libc::time_t;
                    tv.tv_usec = remaining.subsec_micros() as libc::suseconds_t;
                    &mut tv as *mut libc::timeval
                }
                None => std::ptr::null_mut(),
            };

            // SAFETY: every pointer is either null or points to a live local
            // that outlives the call.
            let result = unsafe {
                libc::select(
                    nfds + 1,
                    &mut ready_read,
                    ready_write.as_mut().map_or_else(std::ptr::null_mut, |p| p),
                    &mut errorfds,
                    tv_ptr,
                )
            };
            if result < 0 {
                let err = io::Error::last_os_error();
                if err.kind() == io::ErrorKind::Interrupted {
                    continue;
                }
                return Err(err);
            }
            if result == 0 {
                return Err(io::Error::from(io::ErrorKind::TimedOut));
            }
            break ready_read;
        };

        if let Some(event_fd) = event_fd {
            // SAFETY: `event_fd` is within the set's capacity (checked above).
            if unsafe { libc::FD_ISSET(event_fd, &mut ready_read) } {
                return Err(io::Error::new(
                    io::ErrorKind::Interrupted,
                    "trigger interrupt",
                ));
            }
        }
        Ok(())
    }
}

/// Rejects descriptors that cannot be stored in an `fd_set` (negative or at or
/// above its bit capacity, i.e. FD_SETSIZE); `FD_SET`/`FD_ISSET` on such a
/// descriptor is invalid.
#[cfg(target_os = "macos")]
fn check_select_fd(fd: libc::c_int) -> io::Result<()> {
    let capacity = std::mem::size_of::<libc::fd_set>() * 8;
    match usize::try_from(fd) {
        Ok(index) if index < capacity => Ok(()),
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "file descriptor out of range for select()",
        )),
    }
}
247pub struct InterruptEvent {
248    state: Mutex<i32>,
249    read_fd: Fd,
250    write_fd: Fd,
251}
252impl InterruptEvent {
253    pub fn new() -> io::Result<Self> {
254        let mut fds: [libc::c_int; 2] = [0; 2];
255
256        unsafe {
257            if libc::pipe(fds.as_mut_ptr()) == -1 {
258                return Err(io::Error::last_os_error());
259            }
260            let read_fd = Fd::new_unchecked(fds[0]);
261            let write_fd = Fd::new_unchecked(fds[1]);
262            write_fd.set_nonblocking(true)?;
263            read_fd.set_nonblocking(true)?;
264            Ok(Self {
265                state: Mutex::new(0),
266                read_fd,
267                write_fd,
268            })
269        }
270    }
271    pub fn trigger(&self) -> io::Result<()> {
272        self.trigger_value(1)
273    }
274    pub fn trigger_value(&self, val: i32) -> io::Result<()> {
275        if val == 0 {
276            return Err(io::Error::new(
277                io::ErrorKind::InvalidInput,
278                "value cannot be 0",
279            ));
280        }
281        let mut guard = self.state.lock().unwrap();
282        if *guard != 0 {
283            return Ok(());
284        }
285        *guard = val;
286        let buf: [u8; 8] = 1u64.to_ne_bytes();
287        let res = unsafe {
288            libc::write(
289                self.write_fd.as_raw_fd(),
290                buf.as_ptr() as *const _,
291                buf.len(),
292            )
293        };
294        if res == -1 {
295            let e = io::Error::last_os_error();
296            if e.kind() == io::ErrorKind::WouldBlock {
297                return Ok(());
298            }
299            Err(e)
300        } else {
301            Ok(())
302        }
303    }
304    pub fn is_trigger(&self) -> bool {
305        *self.state.lock().unwrap() != 0
306    }
307    pub fn value(&self) -> i32 {
308        *self.state.lock().unwrap()
309    }
310    pub fn reset(&self) -> io::Result<()> {
311        let mut buf = [0; 8];
312        let mut guard = self.state.lock().unwrap();
313        *guard = 0;
314        loop {
315            unsafe {
316                let res = libc::read(
317                    self.read_fd.as_raw_fd(),
318                    buf.as_mut_ptr() as *mut _,
319                    buf.len(),
320                );
321                if res == -1 {
322                    let error = io::Error::last_os_error();
323                    return if error.kind() == io::ErrorKind::WouldBlock {
324                        Ok(())
325                    } else {
326                        Err(error)
327                    };
328                }
329            }
330        }
331    }
332    fn as_event_fd(&self) -> libc::c_int {
333        self.read_fd.as_raw_fd()
334    }
335}