Skip to main content

vortex_ipc/
stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::pin::Pin;
6use std::task::Poll;
7use std::task::ready;
8
9use bytes::Bytes;
10use bytes::BytesMut;
11use futures::AsyncRead;
12use futures::AsyncWrite;
13use futures::AsyncWriteExt;
14use futures::Stream;
15use futures::StreamExt;
16use futures::TryStreamExt;
17use pin_project_lite::pin_project;
18use vortex_array::ArrayRef;
19use vortex_array::dtype::DType;
20use vortex_array::stream::ArrayStream;
21use vortex_error::VortexResult;
22use vortex_error::vortex_bail;
23use vortex_error::vortex_err;
24use vortex_session::VortexSession;
25
26use crate::messages::AsyncMessageReader;
27use crate::messages::DecoderMessage;
28use crate::messages::EncoderMessage;
29use crate::messages::MessageEncoder;
30
31pin_project! {
32    /// An [`ArrayStream`] for reading messages off an async IPC stream.
33    pub struct AsyncIPCReader<R> {
34        #[pin]
35        reader: AsyncMessageReader<R>,
36        dtype: DType,
37        session: VortexSession,
38    }
39}
40
41impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
42    pub async fn try_new(read: R, session: &VortexSession) -> VortexResult<Self> {
43        let mut reader = AsyncMessageReader::new(read);
44
45        let dtype = match reader.next().await.transpose()? {
46            Some(msg) => match msg {
47                DecoderMessage::DType(dtype) => dtype,
48                msg => {
49                    vortex_bail!("Expected DType message, got {:?}", msg);
50                }
51            },
52            None => vortex_bail!("Expected DType message, got EOF"),
53        };
54
55        let dtype = DType::from_flatbuffer(dtype, session)?;
56
57        Ok(AsyncIPCReader {
58            reader,
59            dtype,
60            session: session.clone(),
61        })
62    }
63}
64
65impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
66    fn dtype(&self) -> &DType {
67        &self.dtype
68    }
69}
70
71impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
72    type Item = VortexResult<ArrayRef>;
73
74    fn poll_next(
75        self: Pin<&mut Self>,
76        cx: &mut std::task::Context<'_>,
77    ) -> Poll<Option<Self::Item>> {
78        let this = self.project();
79
80        match ready!(this.reader.poll_next(cx)) {
81            None => Poll::Ready(None),
82            Some(msg) => match msg {
83                Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
84                    array_parts
85                        .decode(this.dtype, row_count, &ctx, this.session)
86                        .and_then(|array| {
87                            if array.dtype() != this.dtype {
88                                Err(vortex_err!(
89                                    "Array data type mismatch: expected {:?}, got {:?}",
90                                    this.dtype,
91                                    array.dtype()
92                                ))
93                            } else {
94                                Ok(array)
95                            }
96                        }),
97                )),
98                Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
99                    "Expected Array message, got {:?}",
100                    msg
101                )))),
102                Err(e) => Poll::Ready(Some(Err(e))),
103            },
104        }
105    }
106}
107
108/// A trait for converting an [`ArrayStream`] into IPC streams.
109pub trait ArrayStreamIPC {
110    fn into_ipc(self) -> ArrayStreamIPCBytes
111    where
112        Self: Sized;
113
114    fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
115    where
116        Self: Sized;
117}
118
119impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
120    fn into_ipc(self) -> ArrayStreamIPCBytes
121    where
122        Self: Sized,
123    {
124        ArrayStreamIPCBytes {
125            stream: Box::pin(self),
126            encoder: MessageEncoder::default(),
127            buffers: vec![],
128            written_dtype: false,
129        }
130    }
131
132    async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
133    where
134        Self: Sized,
135    {
136        let mut stream = self.into_ipc();
137        while let Some(chunk) = stream.next().await {
138            write.write_all(&chunk?).await?;
139        }
140        Ok(write)
141    }
142}
143
144pub struct ArrayStreamIPCBytes {
145    stream: Pin<Box<dyn ArrayStream + 'static>>,
146    encoder: MessageEncoder,
147    buffers: Vec<Bytes>,
148    written_dtype: bool,
149}
150
151impl ArrayStreamIPCBytes {
152    /// Collects the IPC bytes into a single `Bytes`.
153    pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
154        let buffers: Vec<Bytes> = self.try_collect().await?;
155        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
156        for buf in buffers {
157            buffer.extend_from_slice(buf.as_ref());
158        }
159        Ok(buffer.freeze())
160    }
161}
162
163impl Stream for ArrayStreamIPCBytes {
164    type Item = VortexResult<Bytes>;
165
166    fn poll_next(
167        self: Pin<&mut Self>,
168        cx: &mut std::task::Context<'_>,
169    ) -> Poll<Option<Self::Item>> {
170        let this = self.get_mut();
171
172        // If we haven't written the dtype yet, we write it
173        if !this.written_dtype {
174            let Ok(buffers) = this
175                .encoder
176                .encode(EncoderMessage::DType(this.stream.dtype()))
177            else {
178                return Poll::Ready(Some(Err(vortex_err!("Failed to encode DType message"))));
179            };
180            this.buffers.extend(buffers);
181            this.written_dtype = true;
182        }
183
184        // Try to flush any buffers we have
185        if !this.buffers.is_empty() {
186            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
187        }
188
189        // Or else try to serialize the next array
190        match ready!(this.stream.poll_next_unpin(cx)) {
191            None => return Poll::Ready(None),
192            Some(chunk) => match chunk.and_then(|c| this.encoder.encode(EncoderMessage::Array(&c)))
193            {
194                Ok(buffers) => {
195                    this.buffers.extend(buffers);
196                }
197                Err(e) => return Poll::Ready(Some(Err(e))),
198            },
199        }
200
201        // Try to flush any buffers we have again
202        if !this.buffers.is_empty() {
203            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
204        }
205
206        // Otherwise, we're done
207        Poll::Ready(None)
208    }
209}
210
#[cfg(test)]
mod test {
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use futures::io::Cursor;
    use vortex_array::IntoArray as _;
    use vortex_array::assert_arrays_eq;
    use vortex_array::stream::ArrayStream;
    use vortex_array::stream::ArrayStreamExt;
    use vortex_buffer::buffer;

    use super::*;
    use crate::test::SESSION;

    /// Encodes `array` as an IPC stream and returns the concatenated bytes.
    async fn encode_ipc(array: &ArrayRef) -> Bytes {
        array
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap()
    }

    /// Round-trips an array through the IPC writer and reader over an in-memory cursor.
    #[tokio::test]
    async fn test_async_stream() {
        let array = buffer![1, 2, 3].into_array();
        let ipc_buffer = encode_ipc(&array).await;

        let reader = AsyncIPCReader::try_new(Cursor::new(ipc_buffer), &SESSION)
            .await
            .unwrap();

        assert_eq!(reader.dtype(), array.dtype());
        let result = reader.read_all().await.unwrap();
        assert_arrays_eq!(result, array);
    }

    /// Wrapper that limits reads to small chunks to simulate network behavior
    struct ChunkedReader<R> {
        inner: R,
        chunk_size: usize,
    }

    impl<R: AsyncRead + Unpin> AsyncRead for ChunkedReader<R> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            // Never hand the inner reader more than `chunk_size` bytes of the buffer.
            let limit = buf.len().min(this.chunk_size);
            Pin::new(&mut this.inner).poll_read(cx, &mut buf[..limit])
        }
    }

    /// Reads the IPC stream back through a reader that returns at most 3 bytes per read.
    #[tokio::test]
    async fn test_async_stream_chunked() {
        let array = buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();
        let ipc_buffer = encode_ipc(&array).await;

        let chunked = ChunkedReader {
            inner: Cursor::new(ipc_buffer),
            chunk_size: 3,
        };

        let reader = AsyncIPCReader::try_new(chunked, &SESSION).await.unwrap();

        let result = reader.read_all().await.unwrap();
        let expected = buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();
        assert_arrays_eq!(result, expected);
    }

    /// Test with 1-byte chunks to stress-test partial read handling.
    #[tokio::test]
    async fn test_async_stream_single_byte_chunks() {
        let array = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();
        let ipc_buffer = encode_ipc(&array).await;

        let chunked = ChunkedReader {
            inner: Cursor::new(ipc_buffer),
            chunk_size: 1,
        };

        let reader = AsyncIPCReader::try_new(chunked, &SESSION).await.unwrap();

        let result = reader.read_all().await.unwrap();
        let expected = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();
        assert_arrays_eq!(result, expected);
    }
}