Skip to main content

vortex_ipc/messages/
reader_async.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::task::Context;
6use std::task::Poll;
7use std::task::ready;
8
9use bytes::BytesMut;
10use futures::AsyncRead;
11use futures::Stream;
12use pin_project_lite::pin_project;
13use vortex_error::VortexResult;
14use vortex_error::vortex_err;
15
16use crate::messages::DecoderMessage;
17use crate::messages::MessageDecoder;
18use crate::messages::PollRead;
19
20pin_project! {
21    /// An IPC message reader backed by an `AsyncRead` stream.
22    pub struct AsyncMessageReader<R> {
23        #[pin]
24        read: R,
25        buffer: BytesMut,
26        decoder: MessageDecoder,
27        state: ReadState,
28    }
29}
30
31impl<R> AsyncMessageReader<R> {
32    pub fn new(read: R) -> Self {
33        AsyncMessageReader {
34            read,
35            buffer: BytesMut::new(),
36            decoder: MessageDecoder::default(),
37            state: ReadState::default(),
38        }
39    }
40}
41
/// Where the message reader currently is within its decode/fill cycle.
enum ReadState {
    /// The next step is to ask the decoder for a message, or for how many bytes it needs.
    AwaitingDecoder,
    /// The buffer is being filled from the underlying reader.
    ///
    /// An async reader may hand back fewer bytes than were asked for (a partial read), which is
    /// common on network connections. Filling can therefore span several `poll_next` calls. Once
    /// every byte of the buffer has been written, the state returns to
    /// [`Self::AwaitingDecoder`].
    Filling {
        /// How many bytes at the front of the buffer have been filled so far.
        total_bytes_read: usize,
    },
}

impl Default for ReadState {
    /// A freshly created reader begins by consulting the decoder.
    fn default() -> Self {
        Self::AwaitingDecoder
    }
}
58
/// Outcome of one step of polling the reader to fill the buffer.
///
/// A plain `Copy` value so that it can be compared directly and printed while debugging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FillResult {
    /// Every byte of the buffer has now been written.
    Filled,
    /// The read made progress, but part of the buffer is still unfilled (a partial read).
    Pending,
    /// The reader reported EOF before any byte of the current buffer had been read.
    ///
    /// NOTE(review): the stream treats this as a clean end of input. Whether that always falls on
    /// a message boundary depends on when `MessageDecoder` requests a fresh fill. Confirm that a
    /// fill requested mid-message (e.g. after a length prefix) cannot silently truncate.
    Eof,
}
68
69/// Polls the reader to fill the buffer, handling partial reads.
70fn poll_fill_buffer<R: AsyncRead>(
71    read: Pin<&mut R>,
72    buffer: &mut [u8],
73    total_bytes_read: &mut usize,
74    cx: &mut Context<'_>,
75) -> Poll<VortexResult<FillResult>> {
76    let unfilled = &mut buffer[*total_bytes_read..];
77
78    let bytes_read = ready!(read.poll_read(cx, unfilled))?;
79
80    // `0` bytes read indicates an EOF.
81    Poll::Ready(if bytes_read == 0 {
82        if *total_bytes_read > 0 {
83            Err(vortex_err!(
84                "unexpected EOF during partial read: read {total_bytes_read} of {} expected bytes",
85                buffer.len()
86            ))
87        } else {
88            Ok(FillResult::Eof)
89        }
90    } else {
91        *total_bytes_read += bytes_read;
92        if *total_bytes_read == buffer.len() {
93            Ok(FillResult::Filled)
94        } else {
95            debug_assert!(*total_bytes_read < buffer.len());
96            Ok(FillResult::Pending)
97        }
98    })
99}
100
101impl<R: AsyncRead> Stream for AsyncMessageReader<R> {
102    type Item = VortexResult<DecoderMessage>;
103
104    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105        let mut this = self.project();
106        loop {
107            match this.state {
108                ReadState::AwaitingDecoder => match this.decoder.read_next(this.buffer)? {
109                    PollRead::Some(msg) => return Poll::Ready(Some(Ok(msg))),
110                    PollRead::NeedMore(new_len) => {
111                        this.buffer.resize(new_len, 0x00);
112                        *this.state = ReadState::Filling {
113                            total_bytes_read: 0,
114                        };
115                    }
116                },
117                ReadState::Filling { total_bytes_read } => {
118                    match ready!(poll_fill_buffer(
119                        this.read.as_mut(),
120                        this.buffer,
121                        total_bytes_read,
122                        cx
123                    )) {
124                        Err(e) => return Poll::Ready(Some(Err(e))),
125                        Ok(FillResult::Eof) => return Poll::Ready(None),
126                        Ok(FillResult::Filled) => *this.state = ReadState::AwaitingDecoder,
127                        Ok(FillResult::Pending) => {}
128                    }
129                }
130            }
131        }
132    }
133}