1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Poll, ready};
4
5use bytes::{Bytes, BytesMut};
6use futures_util::{AsyncRead, AsyncWrite, AsyncWriteExt, Stream, StreamExt, TryStreamExt};
7use pin_project_lite::pin_project;
8use vortex_array::stream::ArrayStream;
9use vortex_array::{ArrayRef, ArrayRegistry};
10use vortex_dtype::DType;
11use vortex_error::{VortexResult, vortex_bail, vortex_err};
12
13use crate::messages::{AsyncMessageReader, DecoderMessage, EncoderMessage, MessageEncoder};
14
pin_project! {
    /// An [`ArrayStream`] that decodes arrays from an async byte source in the Vortex IPC format.
    ///
    /// The source must begin with a single `DType` message (consumed by
    /// [`AsyncIPCReader::try_new`]) followed by array messages, each of which is decoded
    /// and checked against that dtype as the stream is polled.
    pub struct AsyncIPCReader<R> {
        // Pinned so `poll_next` can poll the message reader through the pin projection.
        #[pin]
        reader: AsyncMessageReader<R>,
        // The dtype announced by the leading `DType` message; every decoded array must match it.
        dtype: DType,
    }
}
23
24impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
25 pub async fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
26 let mut reader = AsyncMessageReader::new(read, registry);
27
28 let dtype = match reader.next().await.transpose()? {
29 Some(msg) => match msg {
30 DecoderMessage::DType(dtype) => dtype,
31 msg => {
32 vortex_bail!("Expected DType message, got {:?}", msg);
33 }
34 },
35 None => vortex_bail!("Expected DType message, got EOF"),
36 };
37
38 Ok(AsyncIPCReader { reader, dtype })
39 }
40}
41
impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
    /// Returns the dtype read from the source's leading `DType` message.
    ///
    /// Every array yielded by this reader is verified to have exactly this dtype.
    fn dtype(&self) -> &DType {
        &self.dtype
    }
}
47
48impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
49 type Item = VortexResult<ArrayRef>;
50
51 fn poll_next(
52 self: Pin<&mut Self>,
53 cx: &mut std::task::Context<'_>,
54 ) -> Poll<Option<Self::Item>> {
55 let this = self.project();
56
57 match ready!(this.reader.poll_next(cx)) {
58 None => Poll::Ready(None),
59 Some(msg) => match msg {
60 Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
61 array_parts
62 .decode(&ctx, this.dtype.clone(), row_count)
63 .and_then(|array| {
64 if array.dtype() != this.dtype {
65 Err(vortex_err!(
66 "Array data type mismatch: expected {:?}, got {:?}",
67 this.dtype,
68 array.dtype()
69 ))
70 } else {
71 Ok(array)
72 }
73 }),
74 )),
75 Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
76 "Expected Array message, got {:?}",
77 msg
78 )))),
79 Err(e) => Poll::Ready(Some(Err(e))),
80 },
81 }
82 }
83}
84
/// Serialization of an [`ArrayStream`] into the Vortex IPC byte format.
///
/// The encoded output starts with the stream's `DType` message followed by one encoded
/// message per array chunk, which is the layout [`AsyncIPCReader`] reads back.
pub trait ArrayStreamIPC {
    /// Converts the stream into an [`ArrayStreamIPCBytes`] stream of encoded byte buffers.
    fn into_ipc(self) -> ArrayStreamIPCBytes
    where
        Self: Sized;

    /// Writes the full IPC encoding of the stream to `write`, returning the writer on success.
    fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
    where
        Self: Sized;
}
95
96impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
97 fn into_ipc(self) -> ArrayStreamIPCBytes
98 where
99 Self: Sized,
100 {
101 ArrayStreamIPCBytes {
102 stream: Box::pin(self),
103 encoder: MessageEncoder::default(),
104 buffers: vec![],
105 written_dtype: false,
106 }
107 }
108
109 async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
110 where
111 Self: Sized,
112 {
113 let mut stream = self.into_ipc();
114 while let Some(chunk) = stream.next().await {
115 write.write_all(&chunk?).await?;
116 }
117 Ok(write)
118 }
119}
120
/// A [`Stream`] of IPC-encoded byte buffers produced from an [`ArrayStream`].
///
/// Created by [`ArrayStreamIPC::into_ipc`]. The first buffers yielded encode the wrapped
/// stream's `DType`; subsequent buffers encode each array chunk in order.
pub struct ArrayStreamIPCBytes {
    /// Source of the array chunks to encode.
    stream: Pin<Box<dyn ArrayStream + 'static>>,
    /// Turns `DType` and array messages into byte buffers.
    encoder: MessageEncoder,
    /// Encoded buffers not yet yielded, in output order.
    buffers: Vec<Bytes>,
    /// Whether the `DType` message has already been encoded into `buffers`.
    written_dtype: bool,
}
127
128impl ArrayStreamIPCBytes {
129 pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
131 let buffers: Vec<Bytes> = self.try_collect().await?;
132 let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
133 for buf in buffers {
134 buffer.extend_from_slice(buf.as_ref());
135 }
136 Ok(buffer.freeze())
137 }
138}
139
140impl Stream for ArrayStreamIPCBytes {
141 type Item = VortexResult<Bytes>;
142
143 fn poll_next(
144 self: Pin<&mut Self>,
145 cx: &mut std::task::Context<'_>,
146 ) -> Poll<Option<Self::Item>> {
147 let this = self.get_mut();
148
149 if !this.written_dtype {
151 this.buffers.extend(
152 this.encoder
153 .encode(EncoderMessage::DType(this.stream.dtype())),
154 );
155 this.written_dtype = true;
156 }
157
158 if !this.buffers.is_empty() {
160 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
161 }
162
163 match ready!(this.stream.poll_next_unpin(cx)) {
165 None => return Poll::Ready(None),
166 Some(chunk) => match chunk {
167 Ok(chunk) => {
168 this.buffers
169 .extend(this.encoder.encode(EncoderMessage::Array(&chunk)));
170 }
171 Err(e) => return Poll::Ready(Some(Err(e))),
172 },
173 }
174
175 if !this.buffers.is_empty() {
177 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
178 }
179
180 Poll::Ready(None)
182 }
183}
184
#[cfg(test)]
mod test {
    use futures_util::io::Cursor;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::stream::{ArrayStream, ArrayStreamArrayExt, ArrayStreamExt};
    use vortex_array::{Array, ToCanonical};

    use super::*;

    /// Round-trips a small primitive array through the IPC byte encoding and back.
    #[tokio::test]
    async fn test_async_stream() {
        let original = PrimitiveArray::from_iter([1, 2, 3]);

        // Serialize: array -> single-chunk array stream -> contiguous IPC bytes.
        let encoded = original.to_array_stream().into_ipc();
        let ipc_buffer = encoded.collect_to_buffer().await.unwrap();

        // Deserialize with a registry restricted to canonical encodings.
        let source = Cursor::new(ipc_buffer);
        let registry = ArrayRegistry::canonical_only();
        let reader = AsyncIPCReader::try_new(source, registry).await.unwrap();
        assert_eq!(reader.dtype(), original.dtype());

        let decoded = reader.read_all().await.unwrap();
        let decoded = decoded.to_primitive().unwrap();
        assert_eq!(original.as_slice::<i32>(), decoded.as_slice::<i32>());
    }
}