// tun_rs/platform/unix/interrupt.rs

use crate::platform::unix::Fd;
use std::io;
use std::io::{IoSlice, IoSliceMut};
use std::os::fd::AsRawFd;
use std::sync::{Mutex, MutexGuard, PoisonError};
7impl Fd {
8    pub(crate) fn read_interruptible(
9        &self,
10        buf: &mut [u8],
11        event: &InterruptEvent,
12    ) -> io::Result<usize> {
13        loop {
14            self.wait_readable_interruptible(event)?;
15            return match self.read(buf) {
16                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
17                    continue;
18                }
19                rs => rs,
20            };
21        }
22    }
23    pub(crate) fn readv_interruptible(
24        &self,
25        bufs: &mut [IoSliceMut<'_>],
26        event: &InterruptEvent,
27    ) -> io::Result<usize> {
28        loop {
29            self.wait_readable_interruptible(event)?;
30            return match self.readv(bufs) {
31                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
32                    continue;
33                }
34
35                rs => rs,
36            };
37        }
38    }
39    pub(crate) fn write_interruptible(
40        &self,
41        buf: &[u8],
42        event: &InterruptEvent,
43    ) -> io::Result<usize> {
44        loop {
45            self.wait_writable_interruptible(event)?;
46            return match self.write(buf) {
47                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
48                    continue;
49                }
50                rs => rs,
51            };
52        }
53    }
54    pub fn writev_interruptible(
55        &self,
56        bufs: &[IoSlice<'_>],
57        event: &InterruptEvent,
58    ) -> io::Result<usize> {
59        loop {
60            self.wait_writable_interruptible(event)?;
61            return match self.writev(bufs) {
62                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
63                    continue;
64                }
65                rs => rs,
66            };
67        }
68    }
69    pub fn wait_readable_interruptible(
70        &self,
71        interrupted_event: &InterruptEvent,
72    ) -> io::Result<()> {
73        let fd = self.as_raw_fd() as libc::c_int;
74        let event_fd = interrupted_event.as_event_fd();
75
76        let mut fds = [
77            libc::pollfd {
78                fd,
79                events: libc::POLLIN,
80                revents: 0,
81            },
82            libc::pollfd {
83                fd: event_fd,
84                events: libc::POLLIN,
85                revents: 0,
86            },
87        ];
88
89        let result = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, -1) };
90
91        if result == -1 {
92            return Err(io::Error::last_os_error());
93        }
94        if fds[0].revents & libc::POLLIN != 0 {
95            return Ok(());
96        }
97
98        if fds[1].revents & libc::POLLIN != 0 {
99            return Err(io::Error::new(
100                io::ErrorKind::Interrupted,
101                "trigger interrupt",
102            ));
103        }
104
105        Err(io::Error::other("fd error"))
106    }
107    pub fn wait_writable_interruptible(
108        &self,
109        interrupted_event: &InterruptEvent,
110    ) -> io::Result<()> {
111        let fd = self.as_raw_fd() as libc::c_int;
112        let event_fd = interrupted_event.as_event_fd();
113
114        let mut fds = [
115            libc::pollfd {
116                fd,
117                events: libc::POLLOUT,
118                revents: 0,
119            },
120            libc::pollfd {
121                fd: event_fd,
122                events: libc::POLLIN,
123                revents: 0,
124            },
125        ];
126
127        let result = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, -1) };
128
129        if result == -1 {
130            return Err(io::Error::last_os_error());
131        }
132        if fds[0].revents & libc::POLLOUT != 0 {
133            return Ok(());
134        }
135
136        if fds[1].revents & libc::POLLIN != 0 {
137            return Err(io::Error::new(
138                io::ErrorKind::Interrupted,
139                "trigger interrupt",
140            ));
141        }
142
143        Err(io::Error::other("fd error"))
144    }
145}
146
147pub struct InterruptEvent {
148    state: Mutex<bool>,
149    read_fd: Fd,
150    write_fd: Fd,
151}
152impl InterruptEvent {
153    pub fn new() -> io::Result<Self> {
154        let mut fds: [libc::c_int; 2] = [0; 2];
155
156        unsafe {
157            if libc::pipe(fds.as_mut_ptr()) == -1 {
158                return Err(io::Error::last_os_error());
159            }
160            let read_fd = Fd::new_unchecked(fds[0]);
161            let write_fd = Fd::new_unchecked(fds[1]);
162            read_fd.set_nonblocking(true)?;
163            Ok(Self {
164                state: Default::default(),
165                read_fd,
166                write_fd,
167            })
168        }
169    }
170    pub fn trigger(&self) -> io::Result<()> {
171        let mut guard = self.state.lock().unwrap();
172        *guard = true;
173        let buf: [u8; 8] = 1u64.to_ne_bytes();
174        let res = unsafe {
175            libc::write(
176                self.write_fd.as_raw_fd(),
177                buf.as_ptr() as *const _,
178                buf.len(),
179            )
180        };
181        if res == -1 {
182            Err(io::Error::last_os_error())
183        } else {
184            Ok(())
185        }
186    }
187    pub fn is_trigger(&self) -> bool {
188        *self.state.lock().unwrap()
189    }
190    pub fn reset(&self) -> io::Result<()> {
191        let mut buf = [0; 8];
192        let mut guard = self.state.lock().unwrap();
193        *guard = false;
194        loop {
195            unsafe {
196                let res = libc::read(
197                    self.read_fd.as_raw_fd(),
198                    buf.as_mut_ptr() as *mut _,
199                    buf.len(),
200                );
201                if res == -1 {
202                    let error = io::Error::last_os_error();
203                    return if error.kind() == io::ErrorKind::WouldBlock {
204                        Ok(())
205                    } else {
206                        Err(error)
207                    };
208                }
209            }
210        }
211    }
212    fn as_event_fd(&self) -> libc::c_int {
213        self.read_fd.as_raw_fd()
214    }
215}