// shm_primitives/unix/doorbell.rs
use std::io::{self, ErrorKind};
use std::os::unix::io::{AsRawFd, FromRawFd, OwnedFd, RawFd};
use std::sync::atomic::{AtomicBool, Ordering};

use tokio::io::Interest;
use tokio::io::unix::AsyncFd;

/// Outcome of a single non-blocking doorbell `signal` attempt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignalResult {
    /// The wake-up byte was accepted by the socket.
    Sent,
    /// The send would have blocked, or the kernel reported `ENOBUFS`; nothing was written.
    BufferFull,
    /// The send failed for any other reason and the peer is treated as gone.
    PeerDead,
}

/// The transferable end of a doorbell socket pair.
///
/// Owns its file descriptor: dropping the handle closes the fd.
#[derive(Debug)]
pub struct DoorbellHandle(OwnedFd);

impl DoorbellHandle {
    /// Name of the command-line flag intended to carry the fd number produced by
    /// [`to_arg`](Self::to_arg) and consumed by [`from_arg`](Self::from_arg).
    pub const ARG_NAME: &'static str = "--doorbell-fd";

    /// Returns the raw descriptor number without giving up ownership.
    pub fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.0)
    }

    /// Takes ownership of `fd`.
    ///
    /// # Safety
    ///
    /// `fd` must be an open file descriptor that nothing else owns; it will be
    /// closed when the handle is dropped. It must not be negative.
    pub unsafe fn from_raw_fd(fd: RawFd) -> Self {
        // SAFETY: validity and exclusive ownership of `fd` are guaranteed by the caller.
        Self(unsafe { OwnedFd::from_raw_fd(fd) })
    }

    /// Renders the descriptor number as a decimal string (the inverse of
    /// [`from_arg`](Self::from_arg)).
    pub fn to_arg(&self) -> String {
        let raw = self.as_raw_fd();
        raw.to_string()
    }

    /// Parses a decimal descriptor number and takes ownership of that fd.
    ///
    /// # Errors
    ///
    /// Returns the [`std::num::ParseIntError`] if `s` is not a valid `RawFd` integer.
    ///
    /// # Safety
    ///
    /// Same contract as [`from_raw_fd`](Self::from_raw_fd). The parsed number must
    /// name an open, otherwise-unowned, non-negative descriptor.
    pub unsafe fn from_arg(s: &str) -> Result<Self, std::num::ParseIntError> {
        s.parse::<RawFd>().map(|fd| {
            // SAFETY: upheld by the caller per the contract above.
            unsafe { Self::from_raw_fd(fd) }
        })
    }
}

70pub struct Doorbell {
75 async_fd: AsyncFd<OwnedFd>,
76 peer_dead_logged: AtomicBool,
78}
79
80fn drain_fd(fd: RawFd, would_block_is_error: bool) -> io::Result<bool> {
81 let mut buf = [0u8; 64];
82 let mut drained = false;
83
84 loop {
85 let ret = unsafe { libc::recv(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0) };
86
87 if ret > 0 {
88 drained = true;
89 continue;
90 }
91
92 if ret == 0 {
93 return Ok(drained);
94 }
95
96 let err = io::Error::last_os_error();
97 if err.kind() == ErrorKind::WouldBlock {
98 if drained {
99 return Ok(true);
100 }
101 return if would_block_is_error {
102 Err(err)
103 } else {
104 Ok(false)
105 };
106 }
107
108 return Err(err);
109 }
110}
111
112impl Doorbell {
113 pub fn create_pair() -> io::Result<(Self, DoorbellHandle)> {
118 let (host_fd, peer_fd) = create_socketpair()?;
119
120 set_nonblocking(host_fd.as_raw_fd())?;
121
122 let async_fd = AsyncFd::new(host_fd)?;
123
124 Ok((
125 Self {
126 async_fd,
127 peer_dead_logged: AtomicBool::new(false),
128 },
129 DoorbellHandle(peer_fd),
130 ))
131 }
132
133 pub fn from_handle(handle: DoorbellHandle) -> io::Result<Self> {
138 use std::os::unix::io::IntoRawFd;
139 Self::from_raw_fd(handle.0.into_raw_fd())
140 }
141
142 pub fn from_raw_fd(fd: RawFd) -> io::Result<Self> {
150 let owned = unsafe { OwnedFd::from_raw_fd(fd) };
151 set_nonblocking(fd)?;
152 let async_fd = AsyncFd::new(owned)?;
153 Ok(Self {
154 async_fd,
155 peer_dead_logged: AtomicBool::new(false),
156 })
157 }
158
159 pub async fn signal(&self) -> SignalResult {
168 let fd = self.async_fd.get_ref().as_raw_fd();
169 let buf = [1u8];
170
171 let ret = unsafe {
172 libc::send(
173 fd,
174 buf.as_ptr() as *const libc::c_void,
175 buf.len(),
176 libc::MSG_DONTWAIT,
177 )
178 };
179
180 if ret > 0 {
181 return SignalResult::Sent;
182 }
183
184 if ret == 0 {
185 return SignalResult::Sent;
187 }
188
189 let err = io::Error::last_os_error();
190 let raw_err = err.raw_os_error();
191
192 let is_buffer_full = err.kind() == ErrorKind::WouldBlock || raw_err == Some(libc::ENOBUFS);
195
196 if is_buffer_full {
197 return SignalResult::BufferFull;
198 }
199
200 match err.kind() {
201 ErrorKind::BrokenPipe | ErrorKind::ConnectionReset | ErrorKind::NotConnected => {
203 SignalResult::PeerDead
204 }
205 _ => {
206 if !self.peer_dead_logged.swap(true, Ordering::Relaxed) {
208 tracing::debug!(fd, error = %err, "doorbell signal failed (peer likely dead)");
209 }
210 SignalResult::PeerDead
211 }
212 }
213 }
214
215 pub fn is_peer_dead(&self) -> bool {
217 self.peer_dead_logged.load(Ordering::Relaxed)
218 }
219
220 pub async fn wait(&self) -> io::Result<()> {
222 if self.try_drain() {
223 return Ok(());
224 }
225
226 loop {
227 let mut guard = self.async_fd.ready(Interest::READABLE).await?;
228
229 let drained = guard.try_io(|inner| {
230 let fd = inner.get_ref().as_raw_fd();
231 drain_fd(fd, true).map(|_| ())
232 });
233
234 match drained {
235 Ok(Ok(())) => return Ok(()),
236 Ok(Err(e)) => return Err(e),
237 Err(_would_block) => continue,
238 }
239 }
240 }
241
242 fn try_drain(&self) -> bool {
243 let fd = self.async_fd.get_ref().as_raw_fd();
244 match drain_fd(fd, false) {
245 Ok(drained) => drained,
246 Err(err) => {
247 tracing::warn!(fd, error = %err, "doorbell drain failed");
248 false
249 }
250 }
251 }
252
253 pub fn drain(&self) {
255 self.try_drain();
256 }
257
258 pub async fn accept(&self) -> io::Result<()> {
263 Ok(())
265 }
266
267 pub fn pending_bytes(&self) -> usize {
269 let fd = self.async_fd.get_ref().as_raw_fd();
270 let mut pending: libc::c_int = 0;
271 let ret = unsafe { libc::ioctl(fd, libc::FIONREAD, &mut pending) };
272 if ret < 0 { 0 } else { pending as usize }
273 }
274}
275
276fn create_socketpair() -> io::Result<(OwnedFd, OwnedFd)> {
277 let mut fds = [0i32; 2];
278
279 #[cfg(target_os = "linux")]
280 let sock_type = libc::SOCK_DGRAM | libc::SOCK_NONBLOCK;
281 #[cfg(not(target_os = "linux"))]
282 let sock_type = libc::SOCK_DGRAM;
283
284 let ret = unsafe { libc::socketpair(libc::AF_UNIX, sock_type, 0, fds.as_mut_ptr()) };
285 if ret < 0 {
286 return Err(io::Error::last_os_error());
287 }
288
289 let fd0 = unsafe { OwnedFd::from_raw_fd(fds[0]) };
290 let fd1 = unsafe { OwnedFd::from_raw_fd(fds[1]) };
291
292 #[cfg(not(target_os = "linux"))]
293 {
294 set_nonblocking(fd0.as_raw_fd())?;
295 set_nonblocking(fd1.as_raw_fd())?;
296 }
297
298 Ok((fd0, fd1))
299}
300
301pub fn set_nonblocking(fd: RawFd) -> io::Result<()> {
303 let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
304 if flags < 0 {
305 return Err(io::Error::last_os_error());
306 }
307 let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
308 if ret < 0 {
309 return Err(io::Error::last_os_error());
310 }
311 Ok(())
312}
313
314pub fn close_peer_fd(fd: RawFd) {
320 unsafe {
321 libc::close(fd);
322 }
323}
324
325pub fn validate_fd(fd: RawFd) -> io::Result<()> {
329 let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
330 if flags < 0 {
331 Err(io::Error::last_os_error())
332 } else {
333 Ok(())
334 }
335}
336
337pub fn clear_cloexec(fd: RawFd) -> io::Result<()> {
343 let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
344 if flags < 0 {
345 return Err(io::Error::last_os_error());
346 }
347
348 let ret = unsafe { libc::fcntl(fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC) };
349 if ret < 0 {
350 Err(io::Error::last_os_error())
351 } else {
352 Ok(())
353 }
354}