vortex_ipc/
stream.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Poll, ready};
7
8use bytes::{Bytes, BytesMut};
9use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, Stream, StreamExt, TryStreamExt};
10use pin_project_lite::pin_project;
11use vortex_array::stream::ArrayStream;
12use vortex_array::{ArrayRef, ArrayRegistry};
13use vortex_dtype::DType;
14use vortex_error::{VortexResult, vortex_bail, vortex_err};
15
16use crate::messages::{AsyncMessageReader, DecoderMessage, EncoderMessage, MessageEncoder};
17
pin_project! {
    /// An [`ArrayStream`] for reading messages off an async IPC stream.
    ///
    /// Construct one with [`AsyncIPCReader::try_new`], which consumes the leading DType
    /// message; the stream then yields one array per subsequent array message.
    pub struct AsyncIPCReader<R> {
        // Source of decoded IPC messages. Structurally pinned so that the `Stream` impl can
        // be used with readers `R` that are not `Unpin`.
        #[pin]
        reader: AsyncMessageReader<R>,
        // DType taken from the stream's first message; every decoded array is checked
        // against it.
        dtype: DType,
    }
}
26
27impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
28    pub async fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
29        let mut reader = AsyncMessageReader::new(read, registry);
30
31        let dtype = match reader.next().await.transpose()? {
32            Some(msg) => match msg {
33                DecoderMessage::DType(dtype) => dtype,
34                msg => {
35                    vortex_bail!("Expected DType message, got {:?}", msg);
36                }
37            },
38            None => vortex_bail!("Expected DType message, got EOF"),
39        };
40
41        Ok(AsyncIPCReader { reader, dtype })
42    }
43}
44
45impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
46    fn dtype(&self) -> &DType {
47        &self.dtype
48    }
49}
50
51impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
52    type Item = VortexResult<ArrayRef>;
53
54    fn poll_next(
55        self: Pin<&mut Self>,
56        cx: &mut std::task::Context<'_>,
57    ) -> Poll<Option<Self::Item>> {
58        let this = self.project();
59
60        match ready!(this.reader.poll_next(cx)) {
61            None => Poll::Ready(None),
62            Some(msg) => match msg {
63                Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
64                    array_parts
65                        .decode(&ctx, this.dtype, row_count)
66                        .and_then(|array| {
67                            if array.dtype() != this.dtype {
68                                Err(vortex_err!(
69                                    "Array data type mismatch: expected {:?}, got {:?}",
70                                    this.dtype,
71                                    array.dtype()
72                                ))
73                            } else {
74                                Ok(array)
75                            }
76                        }),
77                )),
78                Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
79                    "Expected Array message, got {:?}",
80                    msg
81                )))),
82                Err(e) => Poll::Ready(Some(Err(e))),
83            },
84        }
85    }
86}
87
/// A trait for converting an [`ArrayStream`] into IPC streams.
pub trait ArrayStreamIPC {
    /// Converts this stream into a stream of IPC-encoded byte buffers.
    ///
    /// The returned [`ArrayStreamIPCBytes`] first emits the encoded DType message. It then
    /// emits the encoded array message for each array pulled from this stream, in order.
    fn into_ipc(self) -> ArrayStreamIPCBytes
    where
        Self: Sized;

    /// Encodes this stream as IPC and writes every encoded buffer to `write`.
    ///
    /// Resolves to the writer once the source stream is exhausted.
    ///
    /// # Errors
    /// Fails with the first error from either the source stream or `write`.
    fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
    where
        Self: Sized;
}
98
99impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
100    fn into_ipc(self) -> ArrayStreamIPCBytes
101    where
102        Self: Sized,
103    {
104        ArrayStreamIPCBytes {
105            stream: Box::pin(self),
106            encoder: MessageEncoder::default(),
107            buffers: vec![],
108            written_dtype: false,
109        }
110    }
111
112    async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
113    where
114        Self: Sized,
115    {
116        let mut stream = self.into_ipc();
117        while let Some(chunk) = stream.next().await {
118            write.write_all(&chunk?).await?;
119        }
120        Ok(write)
121    }
122}
123
/// A [`Stream`] of IPC-encoded byte buffers produced from an [`ArrayStream`].
///
/// Created by [`ArrayStreamIPC::into_ipc`]. The encoded DType message is emitted first,
/// followed by the encoded buffers of each array, in the order the source stream yields them.
pub struct ArrayStreamIPCBytes {
    /// The source of arrays to encode.
    stream: Pin<Box<dyn ArrayStream + 'static>>,
    /// Encoder used for both the DType message and every array message.
    encoder: MessageEncoder,
    /// Encoded buffers not yet yielded, in emission order (index 0 is yielded next).
    buffers: Vec<Bytes>,
    /// Whether the DType message has already been encoded into `buffers`.
    written_dtype: bool,
}
130
131impl ArrayStreamIPCBytes {
132    /// Collects the IPC bytes into a single `Bytes`.
133    pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
134        let buffers: Vec<Bytes> = self.try_collect().await?;
135        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
136        for buf in buffers {
137            buffer.extend_from_slice(buf.as_ref());
138        }
139        Ok(buffer.freeze())
140    }
141}
142
143impl Stream for ArrayStreamIPCBytes {
144    type Item = VortexResult<Bytes>;
145
146    fn poll_next(
147        self: Pin<&mut Self>,
148        cx: &mut std::task::Context<'_>,
149    ) -> Poll<Option<Self::Item>> {
150        let this = self.get_mut();
151
152        // If we haven't written the dtype yet, we write it
153        if !this.written_dtype {
154            this.buffers.extend(
155                this.encoder
156                    .encode(EncoderMessage::DType(this.stream.dtype())),
157            );
158            this.written_dtype = true;
159        }
160
161        // Try to flush any buffers we have
162        if !this.buffers.is_empty() {
163            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
164        }
165
166        // Or else try to serialize the next array
167        match ready!(this.stream.poll_next_unpin(cx)) {
168            None => return Poll::Ready(None),
169            Some(chunk) => match chunk {
170                Ok(chunk) => {
171                    this.buffers
172                        .extend(this.encoder.encode(EncoderMessage::Array(&chunk)));
173                }
174                Err(e) => return Poll::Ready(Some(Err(e))),
175            },
176        }
177
178        // Try to flush any buffers we have again
179        if !this.buffers.is_empty() {
180            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
181        }
182
183        // Otherwise, we're done
184        Poll::Ready(None)
185    }
186}
187
#[cfg(test)]
mod test {
    use futures::io::Cursor;
    use vortex_array::stream::{ArrayStream, ArrayStreamExt};
    use vortex_array::{ArraySession, IntoArray as _, ToCanonical};
    use vortex_buffer::buffer;

    use super::*;

    /// Round-trips a small primitive array through the async IPC writer and reader.
    #[tokio::test]
    async fn test_async_stream() {
        let session = ArraySession::default();
        let original = buffer![1, 2, 3].into_array();

        // Serialize the whole stream into one contiguous buffer.
        let encoded = original
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        // Read it back; construction consumes the leading DType message.
        let registry = session.registry().clone();
        let reader = AsyncIPCReader::try_new(Cursor::new(encoded), registry)
            .await
            .unwrap();
        assert_eq!(reader.dtype(), original.dtype());

        let decoded = reader.read_all().await.unwrap().to_primitive();
        let expected = original.to_primitive();
        assert_eq!(expected.as_slice::<i32>(), decoded.as_slice::<i32>());
    }
}