Skip to main content

vortex_ipc/
iterator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::io::Read;
5use std::io::Write;
6
7use bytes::Bytes;
8use bytes::BytesMut;
9use itertools::Itertools;
10use vortex_array::ArrayRef;
11use vortex_array::dtype::DType;
12use vortex_array::iter::ArrayIterator;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_session::VortexSession;
17
18use crate::messages::DecoderMessage;
19use crate::messages::EncoderMessage;
20use crate::messages::MessageEncoder;
21use crate::messages::SyncMessageReader;
22
23/// An [`ArrayIterator`] for reading messages off an IPC stream.
24pub struct SyncIPCReader<R: Read> {
25    reader: SyncMessageReader<R>,
26    dtype: DType,
27    session: VortexSession,
28}
29
30impl<R: Read> SyncIPCReader<R> {
31    pub fn try_new(read: R, session: &VortexSession) -> VortexResult<Self> {
32        let mut reader = SyncMessageReader::new(read);
33        match reader.next().transpose()? {
34            Some(msg) => match msg {
35                DecoderMessage::DType(fb_dtype) => {
36                    let dtype = DType::from_flatbuffer(fb_dtype, session)?;
37                    Ok(SyncIPCReader {
38                        reader,
39                        dtype,
40                        session: session.clone(),
41                    })
42                }
43                msg => {
44                    vortex_bail!("Expected DType message, got {:?}", msg);
45                }
46            },
47            None => vortex_bail!("Expected DType message, got EOF"),
48        }
49    }
50}
51
52impl<R: Read> ArrayIterator for SyncIPCReader<R> {
53    fn dtype(&self) -> &DType {
54        &self.dtype
55    }
56}
57
58impl<R: Read> Iterator for SyncIPCReader<R> {
59    type Item = VortexResult<ArrayRef>;
60
61    fn next(&mut self) -> Option<Self::Item> {
62        match self.reader.next()? {
63            Ok(msg) => match msg {
64                DecoderMessage::Array((array_parts, ctx, row_count)) => Some(
65                    array_parts
66                        .decode(&self.dtype, row_count, &ctx, &self.session)
67                        .and_then(|array| {
68                            if array.dtype() != self.dtype() {
69                                Err(vortex_err!(
70                                    "Array data type mismatch: expected {:?}, got {:?}",
71                                    self.dtype(),
72                                    array.dtype()
73                                ))
74                            } else {
75                                Ok(array)
76                            }
77                        }),
78                ),
79                msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))),
80            },
81            Err(e) => Some(Err(e)),
82        }
83    }
84}
85
86/// A trait for converting an [`ArrayIterator`] into an IPC stream.
87pub trait ArrayIteratorIPC {
88    fn into_ipc(self) -> VortexResult<ArrayIteratorIPCBytes>
89    where
90        Self: Sized;
91
92    fn write_ipc<W: Write>(self, write: W) -> VortexResult<W>
93    where
94        Self: Sized;
95}
96
97impl<I: ArrayIterator + 'static> ArrayIteratorIPC for I {
98    fn into_ipc(self) -> VortexResult<ArrayIteratorIPCBytes>
99    where
100        Self: Sized,
101    {
102        let mut encoder = MessageEncoder::default();
103        let buffers = encoder.encode(EncoderMessage::DType(self.dtype()))?;
104        Ok(ArrayIteratorIPCBytes {
105            inner: Box::new(self),
106            encoder,
107            buffers,
108        })
109    }
110
111    fn write_ipc<W: Write>(self, mut write: W) -> VortexResult<W>
112    where
113        Self: Sized,
114    {
115        let mut stream = self.into_ipc()?;
116        for buffer in &mut stream {
117            write.write_all(buffer?.as_ref())?;
118        }
119        Ok(write)
120    }
121}
122
123pub struct ArrayIteratorIPCBytes {
124    inner: Box<dyn ArrayIterator + 'static>,
125    encoder: MessageEncoder,
126    buffers: Vec<Bytes>,
127}
128
129impl ArrayIteratorIPCBytes {
130    /// Collects the IPC bytes into a single `Bytes`.
131    pub fn collect_to_buffer(self) -> VortexResult<Bytes> {
132        let buffers: Vec<Bytes> = self.try_collect()?;
133        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
134        for buf in buffers {
135            buffer.extend_from_slice(buf.as_ref());
136        }
137        Ok(buffer.freeze())
138    }
139}
140
141impl Iterator for ArrayIteratorIPCBytes {
142    type Item = VortexResult<Bytes>;
143
144    fn next(&mut self) -> Option<Self::Item> {
145        // Try to flush any buffers we have
146        if !self.buffers.is_empty() {
147            return Some(Ok(self.buffers.remove(0)));
148        }
149
150        // Or else try to serialize the next array
151        match self
152            .inner
153            .next()?
154            .and_then(|chunk| self.encoder.encode(EncoderMessage::Array(&chunk)))
155        {
156            Ok(buffers) => self.buffers.extend(buffers),
157            Err(e) => return Some(Err(e)),
158        }
159
160        // Try to flush any buffers we have again
161        if !self.buffers.is_empty() {
162            return Some(Ok(self.buffers.remove(0)));
163        }
164
165        // Otherwise, we're done
166        None
167    }
168}
169
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use vortex_array::IntoArray as _;
    use vortex_array::assert_arrays_eq;
    use vortex_array::iter::ArrayIterator;
    use vortex_array::iter::ArrayIteratorExt;
    use vortex_buffer::buffer;

    use super::*;
    use crate::test::SESSION;

    /// Round-trips a small `i32` array through the IPC writer and the sync
    /// reader, checking that both the dtype and the contents survive.
    #[test]
    fn test_sync_stream() -> VortexResult<()> {
        let expected = buffer![1i32, 2, 3].into_array();
        let encoded = expected
            .to_array_iterator()
            .into_ipc()?
            .collect_to_buffer()?;

        let ipc_reader = SyncIPCReader::try_new(Cursor::new(encoded), &SESSION)?;
        assert_eq!(ipc_reader.dtype(), expected.dtype());

        let decoded = ipc_reader.read_all()?;
        assert_arrays_eq!(decoded, expected);

        Ok(())
    }
}