vortex_ipc/
iterator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::io::Read;
5use std::io::Write;
6
7use bytes::Bytes;
8use bytes::BytesMut;
9use itertools::Itertools;
10use vortex_array::ArrayRef;
11use vortex_array::iter::ArrayIterator;
12use vortex_array::session::ArrayRegistry;
13use vortex_dtype::DType;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17
18use crate::messages::DecoderMessage;
19use crate::messages::EncoderMessage;
20use crate::messages::MessageEncoder;
21use crate::messages::SyncMessageReader;
22
23/// An [`ArrayIterator`] for reading messages off an IPC stream.
24pub struct SyncIPCReader<R: Read> {
25    reader: SyncMessageReader<R>,
26    dtype: DType,
27}
28
29impl<R: Read> SyncIPCReader<R> {
30    pub fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
31        let mut reader = SyncMessageReader::new(read, registry);
32        match reader.next().transpose()? {
33            Some(msg) => match msg {
34                DecoderMessage::DType(dtype) => Ok(SyncIPCReader { reader, dtype }),
35                msg => {
36                    vortex_bail!("Expected DType message, got {:?}", msg);
37                }
38            },
39            None => vortex_bail!("Expected DType message, got EOF"),
40        }
41    }
42}
43
44impl<R: Read> ArrayIterator for SyncIPCReader<R> {
45    fn dtype(&self) -> &DType {
46        &self.dtype
47    }
48}
49
50impl<R: Read> Iterator for SyncIPCReader<R> {
51    type Item = VortexResult<ArrayRef>;
52
53    fn next(&mut self) -> Option<Self::Item> {
54        match self.reader.next()? {
55            Ok(msg) => match msg {
56                DecoderMessage::Array((array_parts, ctx, row_count)) => Some(
57                    array_parts
58                        .decode(&ctx, &self.dtype, row_count)
59                        .and_then(|array| {
60                            if array.dtype() != self.dtype() {
61                                Err(vortex_err!(
62                                    "Array data type mismatch: expected {:?}, got {:?}",
63                                    self.dtype(),
64                                    array.dtype()
65                                ))
66                            } else {
67                                Ok(array)
68                            }
69                        }),
70                ),
71                msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))),
72            },
73            Err(e) => Some(Err(e)),
74        }
75    }
76}
77
78/// A trait for converting an [`ArrayIterator`] into an IPC stream.
79pub trait ArrayIteratorIPC {
80    fn into_ipc(self) -> ArrayIteratorIPCBytes
81    where
82        Self: Sized;
83
84    fn write_ipc<W: Write>(self, write: W) -> VortexResult<W>
85    where
86        Self: Sized;
87}
88
89impl<I: ArrayIterator + 'static> ArrayIteratorIPC for I {
90    fn into_ipc(self) -> ArrayIteratorIPCBytes
91    where
92        Self: Sized,
93    {
94        let mut encoder = MessageEncoder::default();
95        let buffers = encoder.encode(EncoderMessage::DType(self.dtype()));
96        ArrayIteratorIPCBytes {
97            inner: Box::new(self),
98            encoder,
99            buffers,
100        }
101    }
102
103    fn write_ipc<W: Write>(self, mut write: W) -> VortexResult<W>
104    where
105        Self: Sized,
106    {
107        let mut stream = self.into_ipc();
108        for buffer in &mut stream {
109            write.write_all(buffer?.as_ref())?;
110        }
111        Ok(write)
112    }
113}
114
115pub struct ArrayIteratorIPCBytes {
116    inner: Box<dyn ArrayIterator + 'static>,
117    encoder: MessageEncoder,
118    buffers: Vec<Bytes>,
119}
120
121impl ArrayIteratorIPCBytes {
122    /// Collects the IPC bytes into a single `Bytes`.
123    pub fn collect_to_buffer(self) -> VortexResult<Bytes> {
124        let buffers: Vec<Bytes> = self.try_collect()?;
125        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
126        for buf in buffers {
127            buffer.extend_from_slice(buf.as_ref());
128        }
129        Ok(buffer.freeze())
130    }
131}
132
133impl Iterator for ArrayIteratorIPCBytes {
134    type Item = VortexResult<Bytes>;
135
136    fn next(&mut self) -> Option<Self::Item> {
137        // Try to flush any buffers we have
138        if !self.buffers.is_empty() {
139            return Some(Ok(self.buffers.remove(0)));
140        }
141
142        // Or else try to serialize the next array
143        match self.inner.next()? {
144            Ok(chunk) => {
145                self.buffers
146                    .extend(self.encoder.encode(EncoderMessage::Array(&chunk)));
147            }
148            Err(e) => return Some(Err(e)),
149        }
150
151        // Try to flush any buffers we have again
152        if !self.buffers.is_empty() {
153            return Some(Ok(self.buffers.remove(0)));
154        }
155
156        // Otherwise, we're done
157        None
158    }
159}
160
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use vortex_array::IntoArray as _;
    use vortex_array::ToCanonical;
    use vortex_array::iter::ArrayIterator;
    use vortex_array::iter::ArrayIteratorExt;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;

    use super::*;

    /// Round-trips a small primitive array through the IPC encoding and the sync reader.
    #[test]
    fn test_sync_stream() {
        let session = ArraySession::default();
        let expected = buffer![1i32, 2, 3].into_array();

        // Encode the array into one contiguous IPC buffer.
        let encoded = expected
            .to_array_iterator()
            .into_ipc()
            .collect_to_buffer()
            .unwrap();

        // Decode it again through the synchronous reader.
        let registry = session.registry().clone();
        let reader = SyncIPCReader::try_new(Cursor::new(encoded), registry).unwrap();

        assert_eq!(reader.dtype(), expected.dtype());
        let actual = reader.read_all().unwrap().to_primitive();
        assert_eq!(
            expected.to_primitive().as_slice::<i32>(),
            actual.as_slice::<i32>()
        );
    }
}