vortex_ipc/
iterator.rs

1use std::io::{Read, Write};
2
3use bytes::{Bytes, BytesMut};
4use itertools::Itertools;
5use vortex_array::iter::ArrayIterator;
6use vortex_array::{ArrayRef, ArrayRegistry};
7use vortex_dtype::DType;
8use vortex_error::{VortexResult, vortex_bail, vortex_err};
9
10use crate::messages::{DecoderMessage, EncoderMessage, MessageEncoder, SyncMessageReader};
11
12/// An [`ArrayIterator`] for reading messages off an IPC stream.
13pub struct SyncIPCReader<R: Read> {
14    reader: SyncMessageReader<R>,
15    dtype: DType,
16}
17
18impl<R: Read> SyncIPCReader<R> {
19    pub fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
20        let mut reader = SyncMessageReader::new(read, registry);
21        match reader.next().transpose()? {
22            Some(msg) => match msg {
23                DecoderMessage::DType(dtype) => Ok(SyncIPCReader { reader, dtype }),
24                msg => {
25                    vortex_bail!("Expected DType message, got {:?}", msg);
26                }
27            },
28            None => vortex_bail!("Expected DType message, got EOF"),
29        }
30    }
31}
32
33impl<R: Read> ArrayIterator for SyncIPCReader<R> {
34    fn dtype(&self) -> &DType {
35        &self.dtype
36    }
37}
38
39impl<R: Read> Iterator for SyncIPCReader<R> {
40    type Item = VortexResult<ArrayRef>;
41
42    fn next(&mut self) -> Option<Self::Item> {
43        match self.reader.next()? {
44            Ok(msg) => match msg {
45                DecoderMessage::Array((array_parts, ctx, row_count)) => Some(
46                    array_parts
47                        .decode(&ctx, &self.dtype, row_count)
48                        .and_then(|array| {
49                            if array.dtype() != self.dtype() {
50                                Err(vortex_err!(
51                                    "Array data type mismatch: expected {:?}, got {:?}",
52                                    self.dtype(),
53                                    array.dtype()
54                                ))
55                            } else {
56                                Ok(array)
57                            }
58                        }),
59                ),
60                msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))),
61            },
62            Err(e) => Some(Err(e)),
63        }
64    }
65}
66
67/// A trait for converting an [`ArrayIterator`] into an IPC stream.
68pub trait ArrayIteratorIPC {
69    fn into_ipc(self) -> ArrayIteratorIPCBytes
70    where
71        Self: Sized;
72
73    fn write_ipc<W: Write>(self, write: W) -> VortexResult<W>
74    where
75        Self: Sized;
76}
77
78impl<I: ArrayIterator + 'static> ArrayIteratorIPC for I {
79    fn into_ipc(self) -> ArrayIteratorIPCBytes
80    where
81        Self: Sized,
82    {
83        let mut encoder = MessageEncoder::default();
84        let buffers = encoder.encode(EncoderMessage::DType(self.dtype()));
85        ArrayIteratorIPCBytes {
86            inner: Box::new(self),
87            encoder,
88            buffers,
89        }
90    }
91
92    fn write_ipc<W: Write>(self, mut write: W) -> VortexResult<W>
93    where
94        Self: Sized,
95    {
96        let mut stream = self.into_ipc();
97        for buffer in &mut stream {
98            write.write_all(buffer?.as_ref())?;
99        }
100        Ok(write)
101    }
102}
103
104pub struct ArrayIteratorIPCBytes {
105    inner: Box<dyn ArrayIterator + 'static>,
106    encoder: MessageEncoder,
107    buffers: Vec<Bytes>,
108}
109
110impl ArrayIteratorIPCBytes {
111    /// Collects the IPC bytes into a single `Bytes`.
112    pub fn collect_to_buffer(self) -> VortexResult<Bytes> {
113        let buffers: Vec<Bytes> = self.try_collect()?;
114        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
115        for buf in buffers {
116            buffer.extend_from_slice(buf.as_ref());
117        }
118        Ok(buffer.freeze())
119    }
120}
121
122impl Iterator for ArrayIteratorIPCBytes {
123    type Item = VortexResult<Bytes>;
124
125    fn next(&mut self) -> Option<Self::Item> {
126        // Try to flush any buffers we have
127        if !self.buffers.is_empty() {
128            return Some(Ok(self.buffers.remove(0)));
129        }
130
131        // Or else try to serialize the next array
132        match self.inner.next()? {
133            Ok(chunk) => {
134                self.buffers
135                    .extend(self.encoder.encode(EncoderMessage::Array(&chunk)));
136            }
137            Err(e) => return Some(Err(e)),
138        }
139
140        // Try to flush any buffers we have again
141        if !self.buffers.is_empty() {
142            return Some(Ok(self.buffers.remove(0)));
143        }
144
145        // Otherwise, we're done
146        None
147    }
148}
149
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::iter::{ArrayIterator, ArrayIteratorExt};

    use super::*;

    #[test]
    fn test_sync_stream() {
        let array = PrimitiveArray::from_iter([1i32, 2, 3]);
        let ipc_buffer = array
            .to_array_iterator()
            .into_ipc()
            .collect_to_buffer()
            .unwrap();

        let reader =
            SyncIPCReader::try_new(Cursor::new(ipc_buffer), ArrayRegistry::canonical_only())
                .unwrap();

        assert_eq!(reader.dtype(), array.dtype());
        let result = reader.read_all().unwrap().to_primitive().unwrap();
        assert_eq!(array.as_slice::<i32>(), result.as_slice::<i32>());
    }

    /// Bytes produced through `write_ipc` must round-trip through `SyncIPCReader`.
    #[test]
    fn test_write_ipc_round_trip() {
        let array = PrimitiveArray::from_iter([4i32, 5, 6, 7]);
        let written = array.to_array_iterator().write_ipc(Vec::new()).unwrap();

        let reader =
            SyncIPCReader::try_new(Cursor::new(written), ArrayRegistry::canonical_only())
                .unwrap();

        assert_eq!(reader.dtype(), array.dtype());
        let result = reader.read_all().unwrap().to_primitive().unwrap();
        assert_eq!(array.as_slice::<i32>(), result.as_slice::<i32>());
    }

    /// A stream with no DType header cannot be opened.
    #[test]
    fn test_empty_stream_is_rejected() {
        let result = SyncIPCReader::try_new(
            Cursor::new(Vec::<u8>::new()),
            ArrayRegistry::canonical_only(),
        );
        assert!(result.is_err());
    }
}