1use std::future::Future;
5use std::pin::Pin;
6use std::task::Poll;
7use std::task::ready;
8
9use bytes::Bytes;
10use bytes::BytesMut;
11use futures::AsyncRead;
12use futures::AsyncWrite;
13use futures::AsyncWriteExt;
14use futures::Stream;
15use futures::StreamExt;
16use futures::TryStreamExt;
17use pin_project_lite::pin_project;
18use vortex_array::ArrayRef;
19use vortex_array::dtype::DType;
20use vortex_array::stream::ArrayStream;
21use vortex_error::VortexResult;
22use vortex_error::vortex_bail;
23use vortex_error::vortex_err;
24use vortex_session::VortexSession;
25
26use crate::messages::AsyncMessageReader;
27use crate::messages::DecoderMessage;
28use crate::messages::EncoderMessage;
29use crate::messages::MessageEncoder;
30
pin_project! {
    /// An asynchronous reader over a Vortex IPC message stream.
    ///
    /// The stream must begin with a `DType` message, which is consumed by
    /// [`AsyncIPCReader::try_new`]. The remaining `Array` messages are decoded and yielded by
    /// the [`Stream`] implementation.
    pub struct AsyncIPCReader<R> {
        // Underlying message decoder; structurally pinned because `poll_next` polls it through
        // the pinned projection.
        #[pin]
        reader: AsyncMessageReader<R>,
        // Schema decoded from the leading DType message; every decoded array is checked
        // against it.
        dtype: DType,
        // Session handed to DType and array decoding.
        session: VortexSession,
    }
}
40
41impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
42 pub async fn try_new(read: R, session: &VortexSession) -> VortexResult<Self> {
43 let mut reader = AsyncMessageReader::new(read);
44
45 let dtype = match reader.next().await.transpose()? {
46 Some(msg) => match msg {
47 DecoderMessage::DType(dtype) => dtype,
48 msg => {
49 vortex_bail!("Expected DType message, got {:?}", msg);
50 }
51 },
52 None => vortex_bail!("Expected DType message, got EOF"),
53 };
54
55 let dtype = DType::from_flatbuffer(dtype, session)?;
56
57 Ok(AsyncIPCReader {
58 reader,
59 dtype,
60 session: session.clone(),
61 })
62 }
63}
64
65impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
66 fn dtype(&self) -> &DType {
67 &self.dtype
68 }
69}
70
71impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
72 type Item = VortexResult<ArrayRef>;
73
74 fn poll_next(
75 self: Pin<&mut Self>,
76 cx: &mut std::task::Context<'_>,
77 ) -> Poll<Option<Self::Item>> {
78 let this = self.project();
79
80 match ready!(this.reader.poll_next(cx)) {
81 None => Poll::Ready(None),
82 Some(msg) => match msg {
83 Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
84 array_parts
85 .decode(this.dtype, row_count, &ctx, this.session)
86 .and_then(|array| {
87 if array.dtype() != this.dtype {
88 Err(vortex_err!(
89 "Array data type mismatch: expected {:?}, got {:?}",
90 this.dtype,
91 array.dtype()
92 ))
93 } else {
94 Ok(array)
95 }
96 }),
97 )),
98 Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
99 "Expected Array message, got {:?}",
100 msg
101 )))),
102 Err(e) => Poll::Ready(Some(Err(e))),
103 },
104 }
105 }
106}
107
/// Serialization of an array stream into the IPC message format.
///
/// Implemented for every `ArrayStream + 'static` by the blanket impl in this module.
pub trait ArrayStreamIPC {
    /// Converts this array stream into a stream of IPC-encoded byte chunks.
    ///
    /// The first chunk(s) carry the stream's `DType` message; the rest carry one `Array`
    /// message per array produced by the source stream.
    fn into_ipc(self) -> ArrayStreamIPCBytes
    where
        Self: Sized;

    /// Encodes this array stream in the IPC format, writes every chunk to `write`, and returns
    /// the writer once the source stream is exhausted.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding a message or writing to `write` fails.
    fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
    where
        Self: Sized;
}
118
119impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
120 fn into_ipc(self) -> ArrayStreamIPCBytes
121 where
122 Self: Sized,
123 {
124 ArrayStreamIPCBytes {
125 stream: Box::pin(self),
126 encoder: MessageEncoder::default(),
127 buffers: vec![],
128 written_dtype: false,
129 }
130 }
131
132 async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
133 where
134 Self: Sized,
135 {
136 let mut stream = self.into_ipc();
137 while let Some(chunk) = stream.next().await {
138 write.write_all(&chunk?).await?;
139 }
140 Ok(write)
141 }
142}
143
/// A stream of IPC-encoded byte chunks, produced from an array stream by
/// [`ArrayStreamIPC::into_ipc`].
pub struct ArrayStreamIPCBytes {
    /// The source array stream being encoded.
    stream: Pin<Box<dyn ArrayStream + 'static>>,
    /// Encoder that turns the `DType` and each array into IPC message buffers.
    encoder: MessageEncoder,
    /// Encoded buffers not yet yielded to the consumer, in output order (front first).
    buffers: Vec<Bytes>,
    /// Whether the leading `DType` message has already been encoded.
    written_dtype: bool,
}
150
151impl ArrayStreamIPCBytes {
152 pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
154 let buffers: Vec<Bytes> = self.try_collect().await?;
155 let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
156 for buf in buffers {
157 buffer.extend_from_slice(buf.as_ref());
158 }
159 Ok(buffer.freeze())
160 }
161}
162
163impl Stream for ArrayStreamIPCBytes {
164 type Item = VortexResult<Bytes>;
165
166 fn poll_next(
167 self: Pin<&mut Self>,
168 cx: &mut std::task::Context<'_>,
169 ) -> Poll<Option<Self::Item>> {
170 let this = self.get_mut();
171
172 if !this.written_dtype {
174 let Ok(buffers) = this
175 .encoder
176 .encode(EncoderMessage::DType(this.stream.dtype()))
177 else {
178 return Poll::Ready(Some(Err(vortex_err!("Failed to encode DType message"))));
179 };
180 this.buffers.extend(buffers);
181 this.written_dtype = true;
182 }
183
184 if !this.buffers.is_empty() {
186 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
187 }
188
189 match ready!(this.stream.poll_next_unpin(cx)) {
191 None => return Poll::Ready(None),
192 Some(chunk) => match chunk.and_then(|c| this.encoder.encode(EncoderMessage::Array(&c)))
193 {
194 Ok(buffers) => {
195 this.buffers.extend(buffers);
196 }
197 Err(e) => return Poll::Ready(Some(Err(e))),
198 },
199 }
200
201 if !this.buffers.is_empty() {
203 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
204 }
205
206 Poll::Ready(None)
208 }
209}
210
#[cfg(test)]
mod test {
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use futures::io::Cursor;
    use vortex_array::IntoArray as _;
    use vortex_array::assert_arrays_eq;
    use vortex_array::stream::ArrayStream;
    use vortex_array::stream::ArrayStreamExt;
    use vortex_buffer::buffer;

    use super::*;
    use crate::test::SESSION;

    #[tokio::test]
    async fn test_async_stream() {
        let array = buffer![1, 2, 3].into_array();
        let bytes = array
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        let reader = AsyncIPCReader::try_new(Cursor::new(bytes), &SESSION)
            .await
            .unwrap();
        assert_eq!(reader.dtype(), array.dtype());

        let decoded = reader.read_all().await.unwrap();
        assert_arrays_eq!(decoded, array);
    }

    /// An `AsyncRead` adapter that returns at most `chunk_size` bytes per `poll_read`, used to
    /// exercise message decoding across partial reads.
    struct ChunkedReader<R> {
        inner: R,
        chunk_size: usize,
    }

    impl<R: AsyncRead + Unpin> AsyncRead for ChunkedReader<R> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let limit = buf.len().min(this.chunk_size);
            Pin::new(&mut this.inner).poll_read(cx, &mut buf[..limit])
        }
    }

    /// Opens an [`AsyncIPCReader`] over `ipc`, delivered at most `chunk_size` bytes per read.
    async fn open_chunked(
        ipc: Bytes,
        chunk_size: usize,
    ) -> AsyncIPCReader<ChunkedReader<Cursor<Bytes>>> {
        let source = ChunkedReader {
            inner: Cursor::new(ipc),
            chunk_size,
        };
        AsyncIPCReader::try_new(source, &SESSION).await.unwrap()
    }

    #[tokio::test]
    async fn test_async_stream_chunked() {
        let array = buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();
        let bytes = array
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        let decoded = open_chunked(bytes, 3).await.read_all().await.unwrap();
        let expected = buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array();
        assert_arrays_eq!(decoded, expected);
    }

    #[tokio::test]
    async fn test_async_stream_single_byte_chunks() {
        let array = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();
        let bytes = array
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        let decoded = open_chunked(bytes, 1).await.read_all().await.unwrap();
        let expected = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();
        assert_arrays_eq!(decoded, expected);
    }
}