tokio_sync_read_stream/
lib.rs

1use futures::Stream;
2use std::io;
3use std::io::Read;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::Mutex;
7use std::task::Context;
8use std::task::Poll;
9use tokio::runtime::Handle;
10
/// Maximum chunk size, in bytes, used by constructors that do not take an explicit buffer size
/// (16 KiB).
const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
12
13struct State<R: Read + Send + Sync + 'static> {
14  readable: R,
15  res: Option<Result<Option<Vec<u8>>, io::Error>>,
16}
17
18pub struct SyncReadStream<R: Read + Send + Sync + 'static> {
19  // An `Arc<Mutex<>>` is probably unnecessary, but it's fully uncontended and rarely cloned (both assuming `poll_next` is never called until `waker.wake()`), and easy to work with.
20  state: Arc<Mutex<State<R>>>,
21  tokio: Handle,
22  buffer_size: usize,
23}
24
25impl<R: Read + Send + Sync + 'static> SyncReadStream<R> {
26  /// This must be called from within a Tokio runtime context, or else it will panic.
27  pub fn with_tokio_handle_and_buffer_size(tokio: Handle, readable: R, buffer_size: usize) -> Self {
28    Self {
29      tokio,
30      buffer_size,
31      state: Arc::new(Mutex::new(State {
32        readable,
33        res: None,
34      })),
35    }
36  }
37
38  /// This must be called from within a Tokio runtime context, or else it will panic.
39  pub fn with_tokio_handle(tokio: Handle, readable: R) -> Self {
40    Self::with_tokio_handle_and_buffer_size(tokio, readable, DEFAULT_BUFFER_SIZE)
41  }
42
43  /// This must be called from within a Tokio runtime context, or else it will panic.
44  pub fn with_buffer_size(readable: R, buffer_size: usize) -> Self {
45    Self::with_tokio_handle_and_buffer_size(Handle::current(), readable, buffer_size)
46  }
47}
48
49impl<R: Read + Send + Sync> Stream for SyncReadStream<R> {
50  type Item = Result<Vec<u8>, io::Error>;
51
52  fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53    let buffer_size = self.buffer_size;
54    let mut state = self.state.lock().unwrap();
55    if let Some(res) = state.res.take() {
56      return Poll::Ready(res.transpose());
57    };
58    let waker = cx.waker().clone();
59    drop(state);
60    let state = Arc::clone(&self.state);
61    self.tokio.spawn_blocking(move || {
62      let mut state = state.lock().unwrap();
63      let mut buf = vec![0u8; buffer_size];
64      state.res = Some(match state.readable.read(&mut buf) {
65        Ok(n) if n == 0 => Ok(None),
66        Ok(n) => {
67          buf.truncate(n);
68          Ok(Some(buf))
69        }
70        Err(err) => Err(err),
71      });
72      waker.wake();
73    });
74    Poll::Pending
75  }
76}
77
78impl<R: Read + Send + Sync + 'static> From<R> for SyncReadStream<R> {
79  /// This must be called from within a Tokio runtime context, or else it will panic.
80  fn from(value: R) -> Self {
81    Self::with_buffer_size(value, DEFAULT_BUFFER_SIZE)
82  }
83}