1use std::io::Read;
5use std::io::Write;
6
7use bytes::Bytes;
8use bytes::BytesMut;
9use itertools::Itertools;
10use vortex_array::ArrayRef;
11use vortex_array::iter::ArrayIterator;
12use vortex_array::session::ArrayRegistry;
13use vortex_dtype::DType;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17
18use crate::messages::DecoderMessage;
19use crate::messages::EncoderMessage;
20use crate::messages::MessageEncoder;
21use crate::messages::SyncMessageReader;
22
23pub struct SyncIPCReader<R: Read> {
25 reader: SyncMessageReader<R>,
26 dtype: DType,
27}
28
29impl<R: Read> SyncIPCReader<R> {
30 pub fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
31 let mut reader = SyncMessageReader::new(read, registry);
32 match reader.next().transpose()? {
33 Some(msg) => match msg {
34 DecoderMessage::DType(dtype) => Ok(SyncIPCReader { reader, dtype }),
35 msg => {
36 vortex_bail!("Expected DType message, got {:?}", msg);
37 }
38 },
39 None => vortex_bail!("Expected DType message, got EOF"),
40 }
41 }
42}
43
44impl<R: Read> ArrayIterator for SyncIPCReader<R> {
45 fn dtype(&self) -> &DType {
46 &self.dtype
47 }
48}
49
50impl<R: Read> Iterator for SyncIPCReader<R> {
51 type Item = VortexResult<ArrayRef>;
52
53 fn next(&mut self) -> Option<Self::Item> {
54 match self.reader.next()? {
55 Ok(msg) => match msg {
56 DecoderMessage::Array((array_parts, ctx, row_count)) => Some(
57 array_parts
58 .decode(&ctx, &self.dtype, row_count)
59 .and_then(|array| {
60 if array.dtype() != self.dtype() {
61 Err(vortex_err!(
62 "Array data type mismatch: expected {:?}, got {:?}",
63 self.dtype(),
64 array.dtype()
65 ))
66 } else {
67 Ok(array)
68 }
69 }),
70 ),
71 msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))),
72 },
73 Err(e) => Some(Err(e)),
74 }
75 }
76}
77
78pub trait ArrayIteratorIPC {
80 fn into_ipc(self) -> ArrayIteratorIPCBytes
81 where
82 Self: Sized;
83
84 fn write_ipc<W: Write>(self, write: W) -> VortexResult<W>
85 where
86 Self: Sized;
87}
88
89impl<I: ArrayIterator + 'static> ArrayIteratorIPC for I {
90 fn into_ipc(self) -> ArrayIteratorIPCBytes
91 where
92 Self: Sized,
93 {
94 let mut encoder = MessageEncoder::default();
95 let buffers = encoder.encode(EncoderMessage::DType(self.dtype()));
96 ArrayIteratorIPCBytes {
97 inner: Box::new(self),
98 encoder,
99 buffers,
100 }
101 }
102
103 fn write_ipc<W: Write>(self, mut write: W) -> VortexResult<W>
104 where
105 Self: Sized,
106 {
107 let mut stream = self.into_ipc();
108 for buffer in &mut stream {
109 write.write_all(buffer?.as_ref())?;
110 }
111 Ok(write)
112 }
113}
114
115pub struct ArrayIteratorIPCBytes {
116 inner: Box<dyn ArrayIterator + 'static>,
117 encoder: MessageEncoder,
118 buffers: Vec<Bytes>,
119}
120
121impl ArrayIteratorIPCBytes {
122 pub fn collect_to_buffer(self) -> VortexResult<Bytes> {
124 let buffers: Vec<Bytes> = self.try_collect()?;
125 let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
126 for buf in buffers {
127 buffer.extend_from_slice(buf.as_ref());
128 }
129 Ok(buffer.freeze())
130 }
131}
132
133impl Iterator for ArrayIteratorIPCBytes {
134 type Item = VortexResult<Bytes>;
135
136 fn next(&mut self) -> Option<Self::Item> {
137 if !self.buffers.is_empty() {
139 return Some(Ok(self.buffers.remove(0)));
140 }
141
142 match self.inner.next()? {
144 Ok(chunk) => {
145 self.buffers
146 .extend(self.encoder.encode(EncoderMessage::Array(&chunk)));
147 }
148 Err(e) => return Some(Err(e)),
149 }
150
151 if !self.buffers.is_empty() {
153 return Some(Ok(self.buffers.remove(0)));
154 }
155
156 None
158 }
159}
160
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use vortex_array::IntoArray as _;
    use vortex_array::ToCanonical;
    use vortex_array::iter::ArrayIterator;
    use vortex_array::iter::ArrayIteratorExt;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;

    use super::*;

    /// Round-trips a small primitive array through the synchronous IPC writer and reader.
    #[test]
    fn test_sync_stream() {
        let session = ArraySession::default();
        let original = buffer![1i32, 2, 3].into_array();

        let encoded = original
            .to_array_iterator()
            .into_ipc()
            .collect_to_buffer()
            .unwrap();

        let registry = session.registry().clone();
        let reader = SyncIPCReader::try_new(Cursor::new(encoded), registry).unwrap();
        assert_eq!(reader.dtype(), original.dtype());

        let decoded = reader.read_all().unwrap().to_primitive();
        let expected = original.to_primitive();
        assert_eq!(expected.as_slice::<i32>(), decoded.as_slice::<i32>());
    }
}