vortex_ipc/
stream.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Poll, ready};
4
5use bytes::{Bytes, BytesMut};
6use futures_util::{AsyncRead, AsyncWrite, AsyncWriteExt, Stream, StreamExt, TryStreamExt};
7use pin_project_lite::pin_project;
8use vortex_array::stream::ArrayStream;
9use vortex_array::{ArrayRef, ArrayRegistry};
10use vortex_dtype::DType;
11use vortex_error::{VortexResult, vortex_bail, vortex_err};
12
13use crate::messages::{AsyncMessageReader, DecoderMessage, EncoderMessage, MessageEncoder};
14
15pin_project! {
16    /// An [`ArrayStream`] for reading messages off an async IPC stream.
17    pub struct AsyncIPCReader<R> {
18        #[pin]
19        reader: AsyncMessageReader<R>,
20        dtype: DType,
21    }
22}
23
24impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
25    pub async fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
26        let mut reader = AsyncMessageReader::new(read, registry);
27
28        let dtype = match reader.next().await.transpose()? {
29            Some(msg) => match msg {
30                DecoderMessage::DType(dtype) => dtype,
31                msg => {
32                    vortex_bail!("Expected DType message, got {:?}", msg);
33                }
34            },
35            None => vortex_bail!("Expected DType message, got EOF"),
36        };
37
38        Ok(AsyncIPCReader { reader, dtype })
39    }
40}
41
42impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
43    fn dtype(&self) -> &DType {
44        &self.dtype
45    }
46}
47
48impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
49    type Item = VortexResult<ArrayRef>;
50
51    fn poll_next(
52        self: Pin<&mut Self>,
53        cx: &mut std::task::Context<'_>,
54    ) -> Poll<Option<Self::Item>> {
55        let this = self.project();
56
57        match ready!(this.reader.poll_next(cx)) {
58            None => Poll::Ready(None),
59            Some(msg) => match msg {
60                Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
61                    array_parts
62                        .decode(&ctx, this.dtype.clone(), row_count)
63                        .and_then(|array| {
64                            if array.dtype() != this.dtype {
65                                Err(vortex_err!(
66                                    "Array data type mismatch: expected {:?}, got {:?}",
67                                    this.dtype,
68                                    array.dtype()
69                                ))
70                            } else {
71                                Ok(array)
72                            }
73                        }),
74                )),
75                Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
76                    "Expected Array message, got {:?}",
77                    msg
78                )))),
79                Err(e) => Poll::Ready(Some(Err(e))),
80            },
81        }
82    }
83}
84
85/// A trait for convering an [`ArrayStream`] into IPC streams.
86pub trait ArrayStreamIPC {
87    fn into_ipc(self) -> ArrayStreamIPCBytes
88    where
89        Self: Sized;
90
91    fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
92    where
93        Self: Sized;
94}
95
96impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
97    fn into_ipc(self) -> ArrayStreamIPCBytes
98    where
99        Self: Sized,
100    {
101        ArrayStreamIPCBytes {
102            stream: Box::pin(self),
103            encoder: MessageEncoder::default(),
104            buffers: vec![],
105            written_dtype: false,
106        }
107    }
108
109    async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
110    where
111        Self: Sized,
112    {
113        let mut stream = self.into_ipc();
114        while let Some(chunk) = stream.next().await {
115            write.write_all(&chunk?).await?;
116        }
117        Ok(write)
118    }
119}
120
121pub struct ArrayStreamIPCBytes {
122    stream: Pin<Box<dyn ArrayStream + 'static>>,
123    encoder: MessageEncoder,
124    buffers: Vec<Bytes>,
125    written_dtype: bool,
126}
127
128impl ArrayStreamIPCBytes {
129    /// Collects the IPC bytes into a single `Bytes`.
130    pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
131        let buffers: Vec<Bytes> = self.try_collect().await?;
132        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
133        for buf in buffers {
134            buffer.extend_from_slice(buf.as_ref());
135        }
136        Ok(buffer.freeze())
137    }
138}
139
140impl Stream for ArrayStreamIPCBytes {
141    type Item = VortexResult<Bytes>;
142
143    fn poll_next(
144        self: Pin<&mut Self>,
145        cx: &mut std::task::Context<'_>,
146    ) -> Poll<Option<Self::Item>> {
147        let this = self.get_mut();
148
149        // If we haven't written the dtype yet, we write it
150        if !this.written_dtype {
151            this.buffers.extend(
152                this.encoder
153                    .encode(EncoderMessage::DType(this.stream.dtype())),
154            );
155            this.written_dtype = true;
156        }
157
158        // Try to flush any buffers we have
159        if !this.buffers.is_empty() {
160            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
161        }
162
163        // Or else try to serialize the next array
164        match ready!(this.stream.poll_next_unpin(cx)) {
165            None => return Poll::Ready(None),
166            Some(chunk) => match chunk {
167                Ok(chunk) => {
168                    this.buffers
169                        .extend(this.encoder.encode(EncoderMessage::Array(&chunk)));
170                }
171                Err(e) => return Poll::Ready(Some(Err(e))),
172            },
173        }
174
175        // Try to flush any buffers we have again
176        if !this.buffers.is_empty() {
177            return Poll::Ready(Some(Ok(this.buffers.remove(0))));
178        }
179
180        // Otherwise, we're done
181        Poll::Ready(None)
182    }
183}
184
#[cfg(test)]
mod test {
    use futures_util::io::Cursor;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::stream::{ArrayStream, ArrayStreamArrayExt, ArrayStreamExt};
    use vortex_array::{Array, ToCanonical};

    use super::*;

    /// Round-trips a primitive array through the IPC encoder and the async reader.
    #[tokio::test]
    async fn test_async_stream() {
        let expected = PrimitiveArray::from_iter([1, 2, 3]);

        // Encode the array into a single contiguous IPC buffer.
        let encoded = expected
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        // Decode it back through the async reader.
        let registry = ArrayRegistry::canonical_only();
        let reader = AsyncIPCReader::try_new(Cursor::new(encoded), registry)
            .await
            .unwrap();

        assert_eq!(reader.dtype(), expected.dtype());
        let decoded = reader.read_all().await.unwrap().to_primitive().unwrap();
        assert_eq!(expected.as_slice::<i32>(), decoded.as_slice::<i32>());
    }
}
213}