1use std::future::Future;
5use std::pin::Pin;
6use std::task::{Poll, ready};
7
8use bytes::{Bytes, BytesMut};
9use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, Stream, StreamExt, TryStreamExt};
10use pin_project_lite::pin_project;
11use vortex_array::stream::ArrayStream;
12use vortex_array::{ArrayRef, ArrayRegistry};
13use vortex_dtype::DType;
14use vortex_error::{VortexResult, vortex_bail, vortex_err};
15
16use crate::messages::{AsyncMessageReader, DecoderMessage, EncoderMessage, MessageEncoder};
17
18pin_project! {
19 pub struct AsyncIPCReader<R> {
21 #[pin]
22 reader: AsyncMessageReader<R>,
23 dtype: DType,
24 }
25}
26
27impl<R: AsyncRead + Unpin> AsyncIPCReader<R> {
28 pub async fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
29 let mut reader = AsyncMessageReader::new(read, registry);
30
31 let dtype = match reader.next().await.transpose()? {
32 Some(msg) => match msg {
33 DecoderMessage::DType(dtype) => dtype,
34 msg => {
35 vortex_bail!("Expected DType message, got {:?}", msg);
36 }
37 },
38 None => vortex_bail!("Expected DType message, got EOF"),
39 };
40
41 Ok(AsyncIPCReader { reader, dtype })
42 }
43}
44
45impl<R: AsyncRead> ArrayStream for AsyncIPCReader<R> {
46 fn dtype(&self) -> &DType {
47 &self.dtype
48 }
49}
50
51impl<R: AsyncRead> Stream for AsyncIPCReader<R> {
52 type Item = VortexResult<ArrayRef>;
53
54 fn poll_next(
55 self: Pin<&mut Self>,
56 cx: &mut std::task::Context<'_>,
57 ) -> Poll<Option<Self::Item>> {
58 let this = self.project();
59
60 match ready!(this.reader.poll_next(cx)) {
61 None => Poll::Ready(None),
62 Some(msg) => match msg {
63 Ok(DecoderMessage::Array((array_parts, ctx, row_count))) => Poll::Ready(Some(
64 array_parts
65 .decode(&ctx, this.dtype, row_count)
66 .and_then(|array| {
67 if array.dtype() != this.dtype {
68 Err(vortex_err!(
69 "Array data type mismatch: expected {:?}, got {:?}",
70 this.dtype,
71 array.dtype()
72 ))
73 } else {
74 Ok(array)
75 }
76 }),
77 )),
78 Ok(msg) => Poll::Ready(Some(Err(vortex_err!(
79 "Expected Array message, got {:?}",
80 msg
81 )))),
82 Err(e) => Poll::Ready(Some(Err(e))),
83 },
84 }
85 }
86}
87
88pub trait ArrayStreamIPC {
90 fn into_ipc(self) -> ArrayStreamIPCBytes
91 where
92 Self: Sized;
93
94 fn write_ipc<W: AsyncWrite + Unpin>(self, write: W) -> impl Future<Output = VortexResult<W>>
95 where
96 Self: Sized;
97}
98
99impl<S: ArrayStream + 'static> ArrayStreamIPC for S {
100 fn into_ipc(self) -> ArrayStreamIPCBytes
101 where
102 Self: Sized,
103 {
104 ArrayStreamIPCBytes {
105 stream: Box::pin(self),
106 encoder: MessageEncoder::default(),
107 buffers: vec![],
108 written_dtype: false,
109 }
110 }
111
112 async fn write_ipc<W: AsyncWrite + Unpin>(self, mut write: W) -> VortexResult<W>
113 where
114 Self: Sized,
115 {
116 let mut stream = self.into_ipc();
117 while let Some(chunk) = stream.next().await {
118 write.write_all(&chunk?).await?;
119 }
120 Ok(write)
121 }
122}
123
124pub struct ArrayStreamIPCBytes {
125 stream: Pin<Box<dyn ArrayStream + 'static>>,
126 encoder: MessageEncoder,
127 buffers: Vec<Bytes>,
128 written_dtype: bool,
129}
130
131impl ArrayStreamIPCBytes {
132 pub async fn collect_to_buffer(self) -> VortexResult<Bytes> {
134 let buffers: Vec<Bytes> = self.try_collect().await?;
135 let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
136 for buf in buffers {
137 buffer.extend_from_slice(buf.as_ref());
138 }
139 Ok(buffer.freeze())
140 }
141}
142
143impl Stream for ArrayStreamIPCBytes {
144 type Item = VortexResult<Bytes>;
145
146 fn poll_next(
147 self: Pin<&mut Self>,
148 cx: &mut std::task::Context<'_>,
149 ) -> Poll<Option<Self::Item>> {
150 let this = self.get_mut();
151
152 if !this.written_dtype {
154 this.buffers.extend(
155 this.encoder
156 .encode(EncoderMessage::DType(this.stream.dtype())),
157 );
158 this.written_dtype = true;
159 }
160
161 if !this.buffers.is_empty() {
163 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
164 }
165
166 match ready!(this.stream.poll_next_unpin(cx)) {
168 None => return Poll::Ready(None),
169 Some(chunk) => match chunk {
170 Ok(chunk) => {
171 this.buffers
172 .extend(this.encoder.encode(EncoderMessage::Array(&chunk)));
173 }
174 Err(e) => return Poll::Ready(Some(Err(e))),
175 },
176 }
177
178 if !this.buffers.is_empty() {
180 return Poll::Ready(Some(Ok(this.buffers.remove(0))));
181 }
182
183 Poll::Ready(None)
185 }
186}
187
#[cfg(test)]
mod test {
    use futures::io::Cursor;
    use vortex_array::stream::{ArrayStream, ArrayStreamExt};
    use vortex_array::{ArraySession, IntoArray as _, ToCanonical};
    use vortex_buffer::buffer;

    use super::*;

    /// Encodes a small primitive array to IPC bytes, reads it back through
    /// `AsyncIPCReader`, and checks that dtype and values match.
    #[tokio::test]
    async fn test_async_stream() {
        let session = ArraySession::default();
        let original = buffer![1, 2, 3].into_array();

        // Encode: array -> stream -> IPC bytes.
        let encoded = original
            .to_array_stream()
            .into_ipc()
            .collect_to_buffer()
            .await
            .unwrap();

        // Decode: IPC bytes -> reader -> array.
        let cursor = Cursor::new(encoded);
        let reader = AsyncIPCReader::try_new(cursor, session.registry().clone())
            .await
            .unwrap();
        assert_eq!(reader.dtype(), original.dtype());

        let decoded = reader.read_all().await.unwrap().to_primitive();
        assert_eq!(
            original.to_primitive().as_slice::<i32>(),
            decoded.as_slice::<i32>()
        );
    }
}