vortex_ipc/messages/
reader_async.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::pin::Pin;
5use std::task::Context;
6use std::task::Poll;
7use std::task::ready;
8
9use bytes::BytesMut;
10use futures::AsyncRead;
11use futures::Stream;
12use pin_project_lite::pin_project;
13use vortex_array::session::ArrayRegistry;
14use vortex_error::VortexResult;
15use vortex_error::vortex_err;
16
17use crate::messages::DecoderMessage;
18use crate::messages::MessageDecoder;
19use crate::messages::PollRead;
20
21pin_project! {
22    /// An IPC message reader backed by an `AsyncRead` stream.
23    pub struct AsyncMessageReader<R> {
24        #[pin]
25        read: R,
26        buffer: BytesMut,
27        decoder: MessageDecoder,
28        state: ReadState,
29    }
30}
31
32impl<R> AsyncMessageReader<R> {
33    pub fn new(read: R, registry: ArrayRegistry) -> Self {
34        AsyncMessageReader {
35            read,
36            buffer: BytesMut::new(),
37            decoder: MessageDecoder::new(registry),
38            state: ReadState::default(),
39        }
40    }
41}
42
43/// The state of an in-progress read operation.
44#[derive(Default)]
45enum ReadState {
46    /// Ready to consult the decoder for the next operation.
47    #[default]
48    AwaitingDecoder,
49    /// Filling the buffer with data from the underlying reader.
50    ///
51    /// Async readers may return fewer bytes than requested (partial reads), especially over network
52    /// connections. This state persists across multiple `poll_next` calls until the buffer is
53    /// completely filled, at which point we transition back to [`Self::AwaitingDecoder`].
54    Filling {
55        /// The number of bytes read into the buffer so far.
56        total_bytes_read: usize,
57    },
58}
59
60/// Result of polling the reader to fill the buffer.
61enum FillResult {
62    /// The buffer has been completely filled.
63    Filled,
64    /// Need more data (partial read occurred).
65    Pending,
66    /// Clean EOF at a message boundary.
67    Eof,
68}
69
70/// Polls the reader to fill the buffer, handling partial reads.
71fn poll_fill_buffer<R: AsyncRead>(
72    read: Pin<&mut R>,
73    buffer: &mut [u8],
74    total_bytes_read: &mut usize,
75    cx: &mut Context<'_>,
76) -> Poll<VortexResult<FillResult>> {
77    let unfilled = &mut buffer[*total_bytes_read..];
78
79    let bytes_read = ready!(read.poll_read(cx, unfilled))?;
80
81    // `0` bytes read indicates an EOF.
82    Poll::Ready(if bytes_read == 0 {
83        if *total_bytes_read > 0 {
84            Err(vortex_err!(
85                "unexpected EOF during partial read: read {total_bytes_read} of {} expected bytes",
86                buffer.len()
87            ))
88        } else {
89            Ok(FillResult::Eof)
90        }
91    } else {
92        *total_bytes_read += bytes_read;
93        if *total_bytes_read == buffer.len() {
94            Ok(FillResult::Filled)
95        } else {
96            debug_assert!(*total_bytes_read < buffer.len());
97            Ok(FillResult::Pending)
98        }
99    })
100}
101
102impl<R: AsyncRead> Stream for AsyncMessageReader<R> {
103    type Item = VortexResult<DecoderMessage>;
104
105    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        let mut this = self.project();
107        loop {
108            match this.state {
109                ReadState::AwaitingDecoder => match this.decoder.read_next(this.buffer)? {
110                    PollRead::Some(msg) => return Poll::Ready(Some(Ok(msg))),
111                    PollRead::NeedMore(new_len) => {
112                        this.buffer.resize(new_len, 0x00);
113                        *this.state = ReadState::Filling {
114                            total_bytes_read: 0,
115                        };
116                    }
117                },
118                ReadState::Filling { total_bytes_read } => {
119                    match ready!(poll_fill_buffer(
120                        this.read.as_mut(),
121                        this.buffer,
122                        total_bytes_read,
123                        cx
124                    )) {
125                        Err(e) => return Poll::Ready(Some(Err(e))),
126                        Ok(FillResult::Eof) => return Poll::Ready(None),
127                        Ok(FillResult::Filled) => *this.state = ReadState::AwaitingDecoder,
128                        Ok(FillResult::Pending) => {}
129                    }
130                }
131            }
132        }
133    }
134}