vortex_ipc/
iterator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::io::{Read, Write};
5
6use bytes::{Bytes, BytesMut};
7use itertools::Itertools;
8use vortex_array::iter::ArrayIterator;
9use vortex_array::{ArrayRef, ArrayRegistry};
10use vortex_dtype::DType;
11use vortex_error::{VortexResult, vortex_bail, vortex_err};
12
13use crate::messages::{DecoderMessage, EncoderMessage, MessageEncoder, SyncMessageReader};
14
15/// An [`ArrayIterator`] for reading messages off an IPC stream.
16pub struct SyncIPCReader<R: Read> {
17    reader: SyncMessageReader<R>,
18    dtype: DType,
19}
20
21impl<R: Read> SyncIPCReader<R> {
22    pub fn try_new(read: R, registry: ArrayRegistry) -> VortexResult<Self> {
23        let mut reader = SyncMessageReader::new(read, registry);
24        match reader.next().transpose()? {
25            Some(msg) => match msg {
26                DecoderMessage::DType(dtype) => Ok(SyncIPCReader { reader, dtype }),
27                msg => {
28                    vortex_bail!("Expected DType message, got {:?}", msg);
29                }
30            },
31            None => vortex_bail!("Expected DType message, got EOF"),
32        }
33    }
34}
35
36impl<R: Read> ArrayIterator for SyncIPCReader<R> {
37    fn dtype(&self) -> &DType {
38        &self.dtype
39    }
40}
41
42impl<R: Read> Iterator for SyncIPCReader<R> {
43    type Item = VortexResult<ArrayRef>;
44
45    fn next(&mut self) -> Option<Self::Item> {
46        match self.reader.next()? {
47            Ok(msg) => match msg {
48                DecoderMessage::Array((array_parts, ctx, row_count)) => Some(
49                    array_parts
50                        .decode(&ctx, &self.dtype, row_count)
51                        .and_then(|array| {
52                            if array.dtype() != self.dtype() {
53                                Err(vortex_err!(
54                                    "Array data type mismatch: expected {:?}, got {:?}",
55                                    self.dtype(),
56                                    array.dtype()
57                                ))
58                            } else {
59                                Ok(array)
60                            }
61                        }),
62                ),
63                msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))),
64            },
65            Err(e) => Some(Err(e)),
66        }
67    }
68}
69
70/// A trait for converting an [`ArrayIterator`] into an IPC stream.
71pub trait ArrayIteratorIPC {
72    fn into_ipc(self) -> ArrayIteratorIPCBytes
73    where
74        Self: Sized;
75
76    fn write_ipc<W: Write>(self, write: W) -> VortexResult<W>
77    where
78        Self: Sized;
79}
80
81impl<I: ArrayIterator + 'static> ArrayIteratorIPC for I {
82    fn into_ipc(self) -> ArrayIteratorIPCBytes
83    where
84        Self: Sized,
85    {
86        let mut encoder = MessageEncoder::default();
87        let buffers = encoder.encode(EncoderMessage::DType(self.dtype()));
88        ArrayIteratorIPCBytes {
89            inner: Box::new(self),
90            encoder,
91            buffers,
92        }
93    }
94
95    fn write_ipc<W: Write>(self, mut write: W) -> VortexResult<W>
96    where
97        Self: Sized,
98    {
99        let mut stream = self.into_ipc();
100        for buffer in &mut stream {
101            write.write_all(buffer?.as_ref())?;
102        }
103        Ok(write)
104    }
105}
106
107pub struct ArrayIteratorIPCBytes {
108    inner: Box<dyn ArrayIterator + 'static>,
109    encoder: MessageEncoder,
110    buffers: Vec<Bytes>,
111}
112
113impl ArrayIteratorIPCBytes {
114    /// Collects the IPC bytes into a single `Bytes`.
115    pub fn collect_to_buffer(self) -> VortexResult<Bytes> {
116        let buffers: Vec<Bytes> = self.try_collect()?;
117        let mut buffer = BytesMut::with_capacity(buffers.iter().map(|b| b.len()).sum());
118        for buf in buffers {
119            buffer.extend_from_slice(buf.as_ref());
120        }
121        Ok(buffer.freeze())
122    }
123}
124
125impl Iterator for ArrayIteratorIPCBytes {
126    type Item = VortexResult<Bytes>;
127
128    fn next(&mut self) -> Option<Self::Item> {
129        // Try to flush any buffers we have
130        if !self.buffers.is_empty() {
131            return Some(Ok(self.buffers.remove(0)));
132        }
133
134        // Or else try to serialize the next array
135        match self.inner.next()? {
136            Ok(chunk) => {
137                self.buffers
138                    .extend(self.encoder.encode(EncoderMessage::Array(&chunk)));
139            }
140            Err(e) => return Some(Err(e)),
141        }
142
143        // Try to flush any buffers we have again
144        if !self.buffers.is_empty() {
145            return Some(Ok(self.buffers.remove(0)));
146        }
147
148        // Otherwise, we're done
149        None
150    }
151}
152
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use vortex_array::iter::{ArrayIterator, ArrayIteratorExt};
    use vortex_array::{IntoArray as _, ToCanonical};
    use vortex_buffer::buffer;

    use super::*;

    /// Round-trips a small primitive array through the IPC encoder and the
    /// synchronous reader, checking both the dtype and the values.
    #[test]
    fn test_sync_stream() {
        let expected = buffer![1i32, 2, 3].into_array();

        // Encode the array into a single in-memory IPC buffer.
        let encoded = expected
            .to_array_iterator()
            .into_ipc()
            .collect_to_buffer()
            .unwrap();

        // Decode it again from a cursor over those bytes.
        let registry = ArrayRegistry::canonical_only();
        let reader = SyncIPCReader::try_new(Cursor::new(encoded), registry).unwrap();
        assert_eq!(reader.dtype(), expected.dtype());

        let decoded = reader.read_all().unwrap().to_primitive();
        assert_eq!(
            expected.to_primitive().as_slice::<i32>(),
            decoded.as_slice::<i32>()
        );
    }
}
183}