nonblocking/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
//! Turn an async function into a synchronous, non-blocking function by polling its future once.
9
10use std::future::Future;
11use std::io;
12use std::task::Context;
13use std::task::Poll;
14use std::task::RawWaker;
15use std::task::RawWakerVTable;
16use std::task::Waker;
17
18/// Attempt to resolve a `future` without blocking.
19/// Return `WouldBlock` error if the future will block.
20/// Return the resolved value otherwise.
21pub fn non_blocking<F, R>(future: F) -> io::Result<R>
22where
23    F: Future<Output = R>,
24{
25    let waker = waker();
26    let mut cx = Context::from_waker(&waker);
27    let mut future = Box::pin(future);
28    match future.as_mut().poll(&mut cx) {
29        Poll::Ready(result) => Ok(result),
30        Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
31    }
32}
33
34/// Similar to `non_blocking`, but unwraps a level of `Result`.
35pub fn non_blocking_result<F, T, E>(future: F) -> Result<T, E>
36where
37    F: Future<Output = Result<T, E>>,
38    E: From<io::Error>,
39{
40    non_blocking(future)?
41}
42
/// Builds a [`Waker`] whose operations all do nothing.
///
/// `non_blocking` polls a future once and never polls it again, so there is
/// nothing to do when the future asks to be woken.
fn waker() -> Waker {
    let raw_waker = clone(std::ptr::null());
    // SAFETY: every function in the vtable ignores its data pointer, so the
    // null pointer is never dereferenced. None of them has any side effect,
    // so the clone/wake/drop contract of `RawWaker` is trivially upheld and
    // the waker is safe to use from any thread.
    unsafe { Waker::from_raw(raw_waker) }
}

/// Returns the single vtable shared by every no-op waker.
///
/// It lives in an explicit `static`, so all wakers built here point at the
/// same vtable address. `Waker::will_wake` compares data and vtable pointers,
/// so one shared address lets it recognise clones of the no-op waker. The
/// previous `&RawWakerVTable::new(..)` relied on implicit const promotion,
/// which does not promise a unique address.
fn vtable() -> &'static RawWakerVTable {
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    &VTABLE
}

/// `clone` entry of the vtable: produces another no-op waker for `data`.
fn clone(data: *const ()) -> RawWaker {
    RawWaker::new(data, vtable())
}

/// `wake`, `wake_by_ref` and `drop` entries of the vtable: intentionally empty.
fn noop(_: *const ()) {}
57
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_non_blocking_ok() {
        async fn f() -> usize {
            g().await + 4
        }

        async fn g() -> usize {
            3
        }

        async fn h() -> io::Result<usize> {
            Ok(5)
        }

        assert_eq!(non_blocking(async { f().await }).unwrap(), 7);
        assert_eq!(non_blocking_result(h()).unwrap(), 5);
    }

    #[test]
    fn test_non_blocking_err() {
        let (sender, receiver) = futures::channel::oneshot::channel::<usize>();
        assert_eq!(
            non_blocking(receiver).unwrap_err().kind(),
            io::ErrorKind::WouldBlock
        );
        // `sender` is kept alive until after the poll so the channel is still
        // empty when `receiver` is polled.
        drop(sender);
    }

    #[test]
    fn test_non_blocking_ready_channel() {
        // A value sent before polling is observed on the single poll.
        let (sender, receiver) = futures::channel::oneshot::channel::<usize>();
        sender.send(9).unwrap();
        assert_eq!(non_blocking(receiver).unwrap().unwrap(), 9);
    }

    #[test]
    fn test_non_blocking_result_err() {
        async fn fails() -> io::Result<usize> {
            Err(io::Error::new(io::ErrorKind::NotFound, "missing"))
        }

        // A pending future surfaces as `WouldBlock`, converted into `E`.
        let pending = futures::future::pending::<io::Result<usize>>();
        assert_eq!(
            non_blocking_result(pending).unwrap_err().kind(),
            io::ErrorKind::WouldBlock
        );

        // An error produced by a ready future is passed through unchanged.
        assert_eq!(
            non_blocking_result(fails()).unwrap_err().kind(),
            io::ErrorKind::NotFound
        );
    }
}