1use std::io::{Read, Write};
2
3use bytes::{Bytes, BytesMut};
4use itertools::Itertools;
5use vortex_array::iter::ArrayIterator;
6use vortex_array::{ArrayRef, ArrayRegistry};
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_bail, vortex_err};
9
10use crate::messages::{DecoderMessage, EncoderMessage, MessageEncoder, SyncMessageReader};
11
/// A synchronous reader that decodes a Vortex IPC message stream from any [`Read`] source.
///
/// The stream is expected to begin with a `DType` message (enforced by
/// [`SyncIPCReader::try_new`]); every subsequent message is decoded as an array chunk of
/// that dtype when the reader is iterated.
pub struct SyncIPCReader<R: Read> {
    /// Underlying message decoder over the raw byte source.
    reader: SyncMessageReader<R>,
    /// Dtype read from the first message of the stream; all decoded arrays are checked against it.
    dtype: DType,
}
17
18impl<R: Read> SyncIPCReader<R> {
19 pub fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
20 let mut reader = SyncMessageReader::new(read, registry);
21 match reader.next().transpose()? {
22 Some(msg) => match msg {
23 DecoderMessage::DType(dtype) => Ok(SyncIPCReader { reader, dtype }),
24 msg => {
25 vortex_bail!("Expected DType message, got {:?}", msg);
26 }
27 },
28 None => vortex_bail!("Expected DType message, got EOF"),
29 }
30 }
31}
32
impl<R: Read> ArrayIterator for SyncIPCReader<R> {
    /// Returns the dtype announced by the stream's leading `DType` message.
    fn dtype(&self) -> &DType {
        &self.dtype
    }
}
38
39impl<R: Read> Iterator for SyncIPCReader<R> {
40 type Item = VortexResult<ArrayRef>;
41
42 fn next(&mut self) -> Option<Self::Item> {
43 match self.reader.next()? {
44 Ok(msg) => match msg {
45 DecoderMessage::Array((array_parts, ctx, row_count)) => Some(
46 array_parts
47 .decode(&ctx, &self.dtype, row_count)
48 .and_then(|array| {
49 if array.dtype() != self.dtype() {
50 Err(vortex_err!(
51 "Array data type mismatch: expected {:?}, got {:?}",
52 self.dtype(),
53 array.dtype()
54 ))
55 } else {
56 Ok(array)
57 }
58 }),
59 ),
60 msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))),
61 },
62 Err(e) => Some(Err(e)),
63 }
64 }
65}
66
/// Serialization of an array iterator into the Vortex IPC message format.
pub trait ArrayIteratorIPC {
    /// Converts the iterator into a lazy iterator of encoded IPC byte buffers.
    ///
    /// In the blanket implementation for [`ArrayIterator`]s, the first buffers encode the
    /// iterator's dtype, followed by the encoding of each array chunk in order.
    fn into_ipc(self) -> ArrayIteratorIPCBytes
    where
        Self: Sized;

    /// Encodes the whole stream and writes every buffer to `write`, returning the writer.
    ///
    /// The writer is not flushed; callers that need durability should flush the returned
    /// writer themselves.
    ///
    /// # Errors
    ///
    /// Returns the first error from either the source iterator or the writer.
    fn write_ipc<W: Write>(self, write: W) -> VortexResult<W>
    where
        Self: Sized;
}
77
78impl<I: ArrayIterator + 'static> ArrayIteratorIPC for I {
79 fn into_ipc(self) -> ArrayIteratorIPCBytes
80 where
81 Self: Sized,
82 {
83 let mut encoder = MessageEncoder::default();
84 let buffers = encoder.encode(EncoderMessage::DType(self.dtype()));
85 ArrayIteratorIPCBytes {
86 inner: Box::new(self),
87 encoder,
88 buffers,
89 }
90 }
91
92 fn write_ipc<W: Write>(self, mut write: W) -> VortexResult<W>
93 where
94 Self: Sized,
95 {
96 let mut stream = self.into_ipc();
97 for buffer in &mut stream {
98 write.write_all(buffer?.as_ref())?;
99 }
100 Ok(write)
101 }
102}
103
/// Lazy iterator over the encoded IPC byte buffers of an array stream.
///
/// Produced by [`ArrayIteratorIPC::into_ipc`]; array chunks are pulled from `inner` and
/// encoded only as buffers are requested.
pub struct ArrayIteratorIPCBytes {
    /// Source of array chunks still to be encoded.
    inner: Box<dyn ArrayIterator + 'static>,
    /// Encoder reused for every message in the stream.
    encoder: MessageEncoder,
    /// Encoded buffers not yet yielded, in stream order.
    buffers: Vec<Bytes>,
}
109
110impl ArrayIteratorIPCBytes {
111 pub fn collect_to_buffer(self) -> VortexResult<Bytes> {
113 let buffers: Vec<Bytes> = self.try_collect()?;
114 let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
115 for buf in buffers {
116 buffer.extend_from_slice(buf.as_ref());
117 }
118 Ok(buffer.freeze())
119 }
120}
121
122impl Iterator for ArrayIteratorIPCBytes {
123 type Item = VortexResult<Bytes>;
124
125 fn next(&mut self) -> Option<Self::Item> {
126 if !self.buffers.is_empty() {
128 return Some(Ok(self.buffers.remove(0)));
129 }
130
131 match self.inner.next()? {
133 Ok(chunk) => {
134 self.buffers
135 .extend(self.encoder.encode(EncoderMessage::Array(&chunk)));
136 }
137 Err(e) => return Some(Err(e)),
138 }
139
140 if !self.buffers.is_empty() {
142 return Some(Ok(self.buffers.remove(0)));
143 }
144
145 None
147 }
148}
149
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::iter::{ArrayIterator, ArrayIteratorExt};

    use super::*;

    /// Round-trips a small primitive array through the IPC byte stream and back, checking
    /// that both the dtype and the values survive.
    #[test]
    fn test_sync_stream() {
        let original = PrimitiveArray::from_iter([1i32, 2, 3]);

        let encoded = original
            .to_array_iterator()
            .into_ipc()
            .collect_to_buffer()
            .unwrap();

        let decoder =
            SyncIPCReader::try_new(Cursor::new(encoded), ArrayRegistry::canonical_only())
                .unwrap();
        assert_eq!(decoder.dtype(), original.dtype());

        let decoded = decoder.read_all().unwrap().to_primitive().unwrap();
        assert_eq!(original.as_slice::<i32>(), decoded.as_slice::<i32>());
    }
}