vortex_ipc/
stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::pin::Pin;
6use std::task::Poll;
7use std::task::ready;
8
9use bytes::Bytes;
10use bytes::BytesMut;
11use futures::AsyncRead;
12use futures::AsyncWrite;
13use futures::AsyncWriteExt;
14use futures::Stream;
15use futures::StreamExt;
16use futures::TryStreamExt;
17use pin_project_lite::pin_project;
18use vortex_array::ArrayRef;
19use vortex_array::session::ArrayRegistry;
20use vortex_array::stream::ArrayStream;
21use vortex_dtype::DType;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_error::vortex_err;
25
26use crate::messages::AsyncMessageReader;
27use crate::messages::DecoderMessage;
28use crate::messages::EncoderMessage;
29use crate::messages::MessageEncoder;
30
31pin_project! {
32    /// An [`ArrayStream`] for reading messages off an async IPC stream.
33    pub struct AsyncIPCReader<R> {
34        #[pin]
35        reader: AsyncMessageReader<R>,
36        dtype: DType,
37    }
38}
39
40impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
41    pub async fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
42        let mut reader = AsyncMessageReader::new(read, registry);
43
44        let dtype = match reader.next().await.transpose()? {
45            Some(msg) => match msg {
46                DecoderMessage::DType(dtype) => dtype,
47                msg => {
48                    vortex_bail!("Expected DType message, got {:?}", msg);
49                }
50            },
51            None => vortex_bail!("Expected DType message, got EOF"),
52        };
53
54        Ok(AsyncIPCReader { reader, dtype })
55    }
56}
57
58impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
59    fn dtype(&self) -> &DType {
60        &self.dtype
61    }
62}
63
64impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
65    type Item = VortexResult<ArrayRef>;
66
67    fn poll_next(
68        self: Pin<&mut Self>,
69        cx: &mut std::task::Context<'_>,
70    ) -> Poll<Option<Self::Item>> {
71        let this = self.project();
72
73        match ready!(this.reader.poll_next(cx)) {
74            None => Poll::Ready(None),
75            Some(msg) => match msg {
76                Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
77                    array_parts
78                        .decode(&ctx, this.dtype, row_count)
79                        .and_then(|array| {
80                            if array.dtype() != this.dtype {
81                                Err(vortex_err!(
82                                    "Array data type mismatch: expected {:?}, got {:?}",
83                                    this.dtype,
84                                    array.dtype()
85                                ))
86                            } else {
87                                Ok(array)
88                            }
89                        }),
90                )),
91                Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
92                    "Expected Array message, got {:?}",
93                    msg
94                )))),
95                Err(e) => Poll::Ready(Some(Err(e))),
96            },
97        }
98    }
99}
100
101/// A trait for converting an [`ArrayStream`] into IPC streams.
102pub trait ArrayStreamIPC {
103    fn into_ipc(self) -> ArrayStreamIPCBytes
104    where
105        Self: Sized;
106
107    fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
108    where
109        Self: Sized;
110}
111
112impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
113    fn into_ipc(self) -> ArrayStreamIPCBytes
114    where
115        Self: Sized,
116    {
117        ArrayStreamIPCBytes {
118            stream: Box::pin(self),
119            encoder: MessageEncoder::default(),
120            buffers: vec![],
121            written_dtype: false,
122        }
123    }
124
125    async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
126    where
127        Self: Sized,
128    {
129        let mut stream = self.into_ipc();
130        while let Some(chunk) = stream.next().await {
131            write.write_all(&chunk?).await?;
132        }
133        Ok(write)
134    }
135}
136
137pub struct ArrayStreamIPCBytes {
138    stream: Pin<Box<dyn ArrayStream + 'static>>,
139    encoder: MessageEncoder,
140    buffers: Vec<Bytes>,
141    written_dtype: bool,
142}
143
144impl ArrayStreamIPCBytes {
145    /// Collects the IPC bytes into a single `Bytes`.
146    pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
147        let buffers: Vec<Bytes> = self.try_collect().await?;
148        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
149        for buf in buffers {
150            buffer.extend_from_slice(buf.as_ref());
151        }
152        Ok(buffer.freeze())
153    }
154}
155
156impl Stream for ArrayStreamIPCBytes {
157    type Item = VortexResult<Bytes>;
158
159    fn poll_next(
160        self: Pin<&mut Self>,
161        cx: &mut std::task::Context<'_>,
162    ) -> Poll<Option<Self::Item>> {
163        let this = self.get_mut();
164
165        // If we haven't written the dtype yet, we write it
166        if !this.written_dtype {
167            this.buffers.extend(
168                this.encoder
169                    .encode(EncoderMessage::DType(this.stream.dtype())),
170            );
171            this.written_dtype = true;
172        }
173
174        // Try to flush any buffers we have
175        if !this.buffers.is_empty() {
176            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
177        }
178
179        // Or else try to serialize the next array
180        match ready!(this.stream.poll_next_unpin(cx)) {
181            None => return Poll::Ready(None),
182            Some(chunk) => match chunk {
183                Ok(chunk) => {
184                    this.buffers
185                        .extend(this.encoder.encode(EncoderMessage::Array(&chunk)));
186                }
187                Err(e) => return Poll::Ready(Some(Err(e))),
188            },
189        }
190
191        // Try to flush any buffers we have again
192        if !this.buffers.is_empty() {
193            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
194        }
195
196        // Otherwise, we're done
197        Poll::Ready(None)
198    }
199}
200
#[cfg(test)]
mod test {
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use futures::io::Cursor;
    use vortex_array::IntoArray as _;
    use vortex_array::ToCanonical;
    use vortex_array::session::ArraySession;
    use vortex_array::stream::ArrayStream;
    use vortex_array::stream::ArrayStreamExt;
    use vortex_buffer::buffer;

    use super::*;

    /// Round-trips an array through the IPC encoding using an in-memory cursor.
    #[tokio::test]
    async fn test_async_stream() {
        let session = ArraySession::default();
        let array = buffer![1, 2, 3].into_array();

        let ipc_stream = array.to_array_stream().into_ipc();
        let ipc_buffer = ipc_stream.collect_to_buffer().await.unwrap();

        let registry = session.registry().clone();
        let reader = AsyncIPCReader::try_new(Cursor::new(ipc_buffer), registry)
            .await
            .unwrap();

        assert_eq!(reader.dtype(), array.dtype());

        let decoded = reader.read_all().await.unwrap();
        let result = decoded.to_primitive();
        assert_eq!(
            array.to_primitive().as_slice::<i32>(),
            result.as_slice::<i32>()
        );
    }

    /// Reader adapter that never hands more than `chunk_size` bytes of the
    /// destination buffer to the inner reader per call, to simulate short reads.
    struct ChunkedReader<R> {
        inner: R,
        chunk_size: usize,
    }

    impl<R: AsyncRead + Unpin> AsyncRead for ChunkedReader<R> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let limit = buf.len().min(this.chunk_size);
            let window = &mut buf[..limit];
            Pin::new(&mut this.inner).poll_read(cx, window)
        }
    }

    /// Round-trips an array while the reader only sees 3-byte reads.
    #[tokio::test]
    async fn test_async_stream_chunked() {
        let session = ArraySession::default();
        let array = buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();

        let ipc_stream = array.to_array_stream().into_ipc();
        let ipc_buffer = ipc_stream.collect_to_buffer().await.unwrap();

        let source = ChunkedReader {
            inner: Cursor::new(ipc_buffer),
            chunk_size: 3,
        };

        let registry = session.registry().clone();
        let reader = AsyncIPCReader::try_new(source, registry).await.unwrap();

        let decoded = reader.read_all().await.unwrap();
        let result = decoded.to_primitive();
        assert_eq!(
            &[1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            result.as_slice::<i32>()
        );
    }

    /// Test with 1-byte chunks to stress-test partial read handling.
    #[tokio::test]
    async fn test_async_stream_single_byte_chunks() {
        let session = ArraySession::default();
        let array = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();

        let ipc_stream = array.to_array_stream().into_ipc();
        let ipc_buffer = ipc_stream.collect_to_buffer().await.unwrap();

        let source = ChunkedReader {
            inner: Cursor::new(ipc_buffer),
            chunk_size: 1,
        };

        let registry = session.registry().clone();
        let reader = AsyncIPCReader::try_new(source, registry).await.unwrap();

        let decoded = reader.read_all().await.unwrap();
        let result = decoded.to_primitive();
        assert_eq!(
            &[42i64, -1, 0, i64::MAX, i64::MIN],
            result.as_slice::<i64>()
        );
    }
}