1use std::io::{Read, Write};
5
6use bytes::{Bytes, BytesMut};
7use itertools::Itertools;
8use vortex_array::iter::ArrayIterator;
9use vortex_array::{ArrayRef, ArrayRegistry};
10use vortex_dtype::DType;
11use vortex_error::{VortexResult, vortex_bail, vortex_err};
12
13use crate::messages::{DecoderMessage, EncoderMessage, MessageEncoder, SyncMessageReader};
14
15pub struct SyncIPCReader<R: Read> {
17 reader: SyncMessageReader<R>,
18 dtype: DType,
19}
20
21impl<R: Read> SyncIPCReader<R> {
22 pub fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
23 let mut reader = SyncMessageReader::new(read, registry);
24 match reader.next().transpose()? {
25 Some(msg) => match msg {
26 DecoderMessage::DType(dtype) => Ok(SyncIPCReader { reader, dtype }),
27 msg => {
28 vortex_bail!("Expected DType message, got {:?}", msg);
29 }
30 },
31 None => vortex_bail!("Expected DType message, got EOF"),
32 }
33 }
34}
35
36impl<R: Read> ArrayIterator for SyncIPCReader<R> {
37 fn dtype(&self) -> &DType {
38 &self.dtype
39 }
40}
41
42impl<R: Read> Iterator for SyncIPCReader<R> {
43 type Item = VortexResult<ArrayRef>;
44
45 fn next(&mut self) -> Option<Self::Item> {
46 match self.reader.next()? {
47 Ok(msg) => match msg {
48 DecoderMessage::Array((array_parts, ctx, row_count)) => Some(
49 array_parts
50 .decode(&ctx, &self.dtype, row_count)
51 .and_then(|array| {
52 if array.dtype() != self.dtype() {
53 Err(vortex_err!(
54 "Array data type mismatch: expected {:?}, got {:?}",
55 self.dtype(),
56 array.dtype()
57 ))
58 } else {
59 Ok(array)
60 }
61 }),
62 ),
63 msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))),
64 },
65 Err(e) => Some(Err(e)),
66 }
67 }
68}
69
/// Extension methods for turning an array iterator into encoded bytes.
///
/// A blanket implementation in this file covers every `ArrayIterator + 'static`.
pub trait ArrayIteratorIPC {
    /// Consumes the iterator and returns an [`ArrayIteratorIPCBytes`], an iterator over the
    /// encoded byte buffers.
    fn into_ipc(self) -> ArrayIteratorIPCBytes
    where
        Self: Sized;

    /// Consumes the iterator, writes its encoded buffers into `write`, and hands the writer
    /// back on success.
    fn write_ipc<W: Write>(self, write: W) -> VortexResult<W>
    where
        Self: Sized;
}
80
81impl<I: ArrayIterator + 'static> ArrayIteratorIPC for I {
82 fn into_ipc(self) -> ArrayIteratorIPCBytes
83 where
84 Self: Sized,
85 {
86 let mut encoder = MessageEncoder::default();
87 let buffers = encoder.encode(EncoderMessage::DType(self.dtype()));
88 ArrayIteratorIPCBytes {
89 inner: Box::new(self),
90 encoder,
91 buffers,
92 }
93 }
94
95 fn write_ipc<W: Write>(self, mut write: W) -> VortexResult<W>
96 where
97 Self: Sized,
98 {
99 let mut stream = self.into_ipc();
100 for buffer in &mut stream {
101 write.write_all(buffer?.as_ref())?;
102 }
103 Ok(write)
104 }
105}
106
107pub struct ArrayIteratorIPCBytes {
108 inner: Box<dyn ArrayIterator + 'static>,
109 encoder: MessageEncoder,
110 buffers: Vec<Bytes>,
111}
112
113impl ArrayIteratorIPCBytes {
114 pub fn collect_to_buffer(self) -> VortexResult<Bytes> {
116 let buffers: Vec<Bytes> = self.try_collect()?;
117 let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
118 for buf in buffers {
119 buffer.extend_from_slice(buf.as_ref());
120 }
121 Ok(buffer.freeze())
122 }
123}
124
125impl Iterator for ArrayIteratorIPCBytes {
126 type Item = VortexResult<Bytes>;
127
128 fn next(&mut self) -> Option<Self::Item> {
129 if !self.buffers.is_empty() {
131 return Some(Ok(self.buffers.remove(0)));
132 }
133
134 match self.inner.next()? {
136 Ok(chunk) => {
137 self.buffers
138 .extend(self.encoder.encode(EncoderMessage::Array(&chunk)));
139 }
140 Err(e) => return Some(Err(e)),
141 }
142
143 if !self.buffers.is_empty() {
145 return Some(Ok(self.buffers.remove(0)));
146 }
147
148 None
150 }
151}
152
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use vortex_array::iter::{ArrayIterator, ArrayIteratorExt};
    use vortex_array::{IntoArray as _, ToCanonical};
    use vortex_buffer::buffer;

    use super::*;

    /// Round-trips a small primitive array through the byte encoding and the sync reader.
    #[test]
    fn test_sync_stream() {
        let array = buffer![1i32, 2, 3].into_array();

        // Encode the array into one contiguous buffer.
        let encoded = array.to_array_iterator().into_ipc();
        let ipc_buffer = encoded.collect_to_buffer().unwrap();

        // Decode it again through an in-memory cursor.
        let source = Cursor::new(ipc_buffer);
        let reader = SyncIPCReader::try_new(source, ArrayRegistry::canonical_only()).unwrap();
        assert_eq!(reader.dtype(), array.dtype());

        let expected = array.to_primitive();
        let result = reader.read_all().unwrap().to_primitive();
        assert_eq!(expected.as_slice::<i32>(), result.as_slice::<i32>());
    }
}