1use std::future::Future;
5use std::pin::Pin;
6use std::task::Poll;
7use std::task::ready;
8
9use bytes::Bytes;
10use bytes::BytesMut;
11use futures::AsyncRead;
12use futures::AsyncWrite;
13use futures::AsyncWriteExt;
14use futures::Stream;
15use futures::StreamExt;
16use futures::TryStreamExt;
17use pin_project_lite::pin_project;
18use vortex_array::ArrayRef;
19use vortex_array::session::ArrayRegistry;
20use vortex_array::stream::ArrayStream;
21use vortex_dtype::DType;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_error::vortex_err;
25
26use crate::messages::AsyncMessageReader;
27use crate::messages::DecoderMessage;
28use crate::messages::EncoderMessage;
29use crate::messages::MessageEncoder;
30
31pin_project! {
32 pub struct AsyncIPCReader<R> {
34 #[pin]
35 reader: AsyncMessageReader<R>,
36 dtype: DType,
37 }
38}
39
40impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
41 pub async fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
42 let mut reader = AsyncMessageReader::new(read, registry);
43
44 let dtype = match reader.next().await.transpose()? {
45 Some(msg) => match msg {
46 DecoderMessage::DType(dtype) => dtype,
47 msg => {
48 vortex_bail!("Expected DType message, got {:?}", msg);
49 }
50 },
51 None => vortex_bail!("Expected DType message, got EOF"),
52 };
53
54 Ok(AsyncIPCReader { reader, dtype })
55 }
56}
57
58impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
59 fn dtype(&self) -> &DType {
60 &self.dtype
61 }
62}
63
64impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
65 type Item = VortexResult<ArrayRef>;
66
67 fn poll_next(
68 self: Pin<&mut Self>,
69 cx: &mut std::task::Context<'_>,
70 ) -> Poll<Option<Self::Item>> {
71 let this = self.project();
72
73 match ready!(this.reader.poll_next(cx)) {
74 None => Poll::Ready(None),
75 Some(msg) => match msg {
76 Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
77 array_parts
78 .decode(&ctx, this.dtype, row_count)
79 .and_then(|array| {
80 if array.dtype() != this.dtype {
81 Err(vortex_err!(
82 "Array data type mismatch: expected {:?}, got {:?}",
83 this.dtype,
84 array.dtype()
85 ))
86 } else {
87 Ok(array)
88 }
89 }),
90 )),
91 Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
92 "Expected Array message, got {:?}",
93 msg
94 )))),
95 Err(e) => Poll::Ready(Some(Err(e))),
96 },
97 }
98 }
99}
100
101pub trait ArrayStreamIPC {
103 fn into_ipc(self) -> ArrayStreamIPCBytes
104 where
105 Self: Sized;
106
107 fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
108 where
109 Self: Sized;
110}
111
112impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
113 fn into_ipc(self) -> ArrayStreamIPCBytes
114 where
115 Self: Sized,
116 {
117 ArrayStreamIPCBytes {
118 stream: Box::pin(self),
119 encoder: MessageEncoder::default(),
120 buffers: vec![],
121 written_dtype: false,
122 }
123 }
124
125 async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
126 where
127 Self: Sized,
128 {
129 let mut stream = self.into_ipc();
130 while let Some(chunk) = stream.next().await {
131 write.write_all(&chunk?).await?;
132 }
133 Ok(write)
134 }
135}
136
137pub struct ArrayStreamIPCBytes {
138 stream: Pin<Box<dyn ArrayStream + 'static>>,
139 encoder: MessageEncoder,
140 buffers: Vec<Bytes>,
141 written_dtype: bool,
142}
143
144impl ArrayStreamIPCBytes {
145 pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
147 let buffers: Vec<Bytes> = self.try_collect().await?;
148 let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
149 for buf in buffers {
150 buffer.extend_from_slice(buf.as_ref());
151 }
152 Ok(buffer.freeze())
153 }
154}
155
156impl Stream for ArrayStreamIPCBytes {
157 type Item = VortexResult<Bytes>;
158
159 fn poll_next(
160 self: Pin<&mut Self>,
161 cx: &mut std::task::Context<'_>,
162 ) -> Poll<Option<Self::Item>> {
163 let this = self.get_mut();
164
165 if !this.written_dtype {
167 this.buffers.extend(
168 this.encoder
169 .encode(EncoderMessage::DType(this.stream.dtype())),
170 );
171 this.written_dtype = true;
172 }
173
174 if !this.buffers.is_empty() {
176 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
177 }
178
179 match ready!(this.stream.poll_next_unpin(cx)) {
181 None => return Poll::Ready(None),
182 Some(chunk) => match chunk {
183 Ok(chunk) => {
184 this.buffers
185 .extend(this.encoder.encode(EncoderMessage::Array(&chunk)));
186 }
187 Err(e) => return Poll::Ready(Some(Err(e))),
188 },
189 }
190
191 if !this.buffers.is_empty() {
193 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
194 }
195
196 Poll::Ready(None)
198 }
199}
200
#[cfg(test)]
mod test {
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use futures::io::Cursor;
    use vortex_array::IntoArray as _;
    use vortex_array::ToCanonical;
    use vortex_array::session::ArraySession;
    use vortex_array::stream::ArrayStream;
    use vortex_array::stream::ArrayStreamExt;
    use vortex_buffer::buffer;

    use super::*;

    /// Round-trips a small array through the IPC encoder and `AsyncIPCReader`.
    #[tokio::test]
    async fn test_async_stream() {
        let session = ArraySession::default();
        let source = buffer![1, 2, 3].into_array();
        let encoded = source
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        let input = Cursor::new(encoded);
        let reader = AsyncIPCReader::try_new(input, session.registry().clone())
            .await
            .unwrap();

        assert_eq!(reader.dtype(), source.dtype());
        let decoded = reader.read_all().await.unwrap().to_primitive();
        assert_eq!(
            source.to_primitive().as_slice::<i32>(),
            decoded.as_slice::<i32>()
        );
    }

    /// An `AsyncRead` adapter that caps every read at `chunk_size` bytes, so that
    /// messages arrive split across multiple reads.
    struct ChunkedReader<R> {
        // The reader that actually supplies the bytes.
        inner: R,
        // Upper bound on the number of bytes requested from `inner` per `poll_read`.
        chunk_size: usize,
    }

    impl<R: AsyncRead + Unpin> AsyncRead for ChunkedReader<R> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            // Never request more than the caller's buffer can hold.
            let limit = buf.len().min(this.chunk_size);
            Pin::new(&mut this.inner).poll_read(cx, &mut buf[..limit])
        }
    }

    /// Reads an IPC stream back through a reader that returns at most 3 bytes per read.
    #[tokio::test]
    async fn test_async_stream_chunked() {
        let session = ArraySession::default();
        let source = buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();
        let encoded = source
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        let input = ChunkedReader {
            inner: Cursor::new(encoded),
            chunk_size: 3,
        };

        let reader = AsyncIPCReader::try_new(input, session.registry().clone())
            .await
            .unwrap();

        let decoded = reader.read_all().await.unwrap().to_primitive();
        assert_eq!(
            &[1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            decoded.as_slice::<i32>()
        );
    }

    /// Reads an IPC stream back one byte at a time, the most fragmented case.
    #[tokio::test]
    async fn test_async_stream_single_byte_chunks() {
        let session = ArraySession::default();
        let source = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();
        let encoded = source
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        let input = ChunkedReader {
            inner: Cursor::new(encoded),
            chunk_size: 1,
        };

        let reader = AsyncIPCReader::try_new(input, session.registry().clone())
            .await
            .unwrap();

        let decoded = reader.read_all().await.unwrap().to_primitive();
        assert_eq!(
            &[42i64, -1, 0, i64::MAX, i64::MIN],
            decoded.as_slice::<i64>()
        );
    }
}