vortex_ipc/messages/
decoder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use bytes::Buf;
7use flatbuffers::root;
8use flatbuffers::root_unchecked;
9use itertools::Itertools;
10use vortex_array::ArrayContext;
11use vortex_array::serde::ArrayParts;
12use vortex_array::session::ArrayRegistry;
13use vortex_buffer::AlignedBuf;
14use vortex_buffer::Alignment;
15use vortex_buffer::ByteBuffer;
16use vortex_dtype::DType;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_err;
21use vortex_flatbuffers::FlatBuffer;
22use vortex_flatbuffers::dtype as fbd;
23use vortex_flatbuffers::message as fb;
24use vortex_flatbuffers::message::MessageHeader;
25use vortex_flatbuffers::message::MessageVersion;
26
27/// A message decoded from an IPC stream.
28///
29/// Note that the `Array` variant cannot fully decode into an [`vortex_array::ArrayRef`] without
30/// a [`vortex_array::ArrayContext`] and a [`DType`]. As such, we partially decode into an
31/// [`ArrayParts`] and allow the caller to finish the decoding.
32#[derive(Debug)]
33pub enum DecoderMessage {
34    Array((ArrayParts, ArrayContext, usize)),
35    Buffer(ByteBuffer),
36    DType(DType),
37}
38
39#[derive(Default)]
40enum State {
41    #[default]
42    Length,
43    Header(usize),
44    Reading(FlatBuffer),
45}
46
47#[derive(Debug)]
48pub enum PollRead {
49    /// A complete message was decoded.
50    Some(DecoderMessage),
51    /// The decoder needs more data to make progress.
52    ///
53    /// The inner value is the **total*k number of bytes the buffer should contain, not the
54    /// incremental amount needed. Callers should:
55    ///
56    /// 1. Resize the buffer to this length.
57    /// 2. Fill the buffer completely (handling partial reads as needed).
58    /// 3. Only then call [`MessageDecoder::read_next`] again.
59    ///
60    /// The decoder checks [`bytes::Buf::remaining`] to determine available data, which for
61    /// [`bytes::BytesMut`] returns the buffer length regardless of how many bytes were actually
62    /// written. Calling `read_next` before the buffer is fully populated will cause the decoder
63    /// to read garbage data.
64    NeedMore(usize),
65}
66
67// NOTE(ngates): we should design some trait that the Decoder can take that doesn't require unique
68//  ownership of the underlying bytes. The decoder needs to split out bytes, and advance a cursor,
69//  but it doesn't need to mutate any bytes. So in theory, we should be able to do this zero-copy
70//  over a shared buffer of bytes, instead of requiring a `BytesMut`.
71/// A stateful reader for decoding IPC messages from an arbitrary stream of bytes.
72pub struct MessageDecoder {
73    registry: ArrayRegistry,
74    /// The current state of the decoder.
75    state: State,
76}
77
78impl MessageDecoder {
79    /// Create a new message decoder that can lookup encodings in the given registry.
80    pub fn new(registry: ArrayRegistry) -> Self {
81        Self {
82            registry,
83            state: State::default(),
84        }
85    }
86
87    /// Attempt to read the next message from the bytes object.
88    ///
89    /// If the message is incomplete, the function will return `NeedMore` with the _total_ number
90    /// of bytes needed to make progress. The next call to read_next _should_ provide at least
91    /// this number of bytes otherwise it will be given the same `NeedMore` response.
92    pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
93        loop {
94            match &self.state {
95                State::Length => {
96                    if bytes.remaining() < 4 {
97                        return Ok(PollRead::NeedMore(4));
98                    }
99
100                    let msg_length = bytes.get_u32_le();
101                    self.state = State::Header(msg_length as usize);
102                }
103                State::Header(msg_length) => {
104                    if bytes.remaining() < *msg_length {
105                        return Ok(PollRead::NeedMore(*msg_length));
106                    }
107
108                    let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
109                    let msg = root::<fb::Message>(msg_bytes.as_ref())?;
110                    if msg.version() != MessageVersion::V0 {
111                        vortex_bail!("Unsupported message version {:?}", msg.version());
112                    }
113
114                    self.state = State::Reading(msg_bytes);
115                }
116                State::Reading(msg_bytes) => {
117                    // SAFETY: we've already validated the header in the previous state
118                    let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
119
120                    // Now we read the body
121                    let body_length = usize::try_from(msg.body_size()).map_err(|_| {
122                        vortex_err!("body size {} is too large for usize", msg.body_size())
123                    })?;
124                    if bytes.remaining() < body_length {
125                        return Ok(PollRead::NeedMore(body_length));
126                    }
127
128                    match msg.header_type() {
129                        MessageHeader::ArrayMessage => {
130                            // We don't care about alignment here since ArrayParts will handle it.
131                            let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
132                            let parts = ArrayParts::try_from(body)?;
133
134                            let header = msg
135                                .header_as_array_message()
136                                .vortex_expect("header is array");
137
138                            let encodings: Vec<_> = header
139                                .encodings()
140                                .iter()
141                                .flat_map(|e| e.iter())
142                                .map(|id| {
143                                    self.registry.find(id).ok_or_else(|| {
144                                        vortex_err!("unknown array encoding id: {}", id)
145                                    })
146                                })
147                                .try_collect()?;
148                            let ctx = ArrayContext::new(encodings);
149                            let row_count = header.row_count() as usize;
150
151                            self.state = Default::default();
152                            return Ok(PollRead::Some(DecoderMessage::Array((
153                                parts, ctx, row_count,
154                            ))));
155                        }
156                        MessageHeader::BufferMessage => {
157                            let body = bytes.copy_to_aligned(
158                                body_length,
159                                Alignment::from_exponent(
160                                    msg.header_as_buffer_message()
161                                        .vortex_expect("header is buffer")
162                                        .alignment_exponent(),
163                                ),
164                            );
165
166                            self.state = Default::default();
167                            return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
168                        }
169                        MessageHeader::DTypeMessage => {
170                            let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
171                            let fb_dtype = root::<fbd::DType>(body.as_ref())?;
172                            let dtype = DType::try_from_view(fb_dtype, body.clone())?;
173
174                            self.state = Default::default();
175                            return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
176                        }
177                        _ => {
178                            vortex_bail!("Unsupported message header {:?}", msg.header_type());
179                        }
180                    }
181                }
182            }
183        }
184    }
185}
186
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::EncoderMessage;
    use crate::messages::MessageEncoder;

    /// Encodes `expected` as an IPC array message, decodes it again, and checks that the
    /// round-tripped array has the same length and encoding id.
    fn write_and_read(expected: &dyn Array) {
        // Concatenate every buffer the encoder emits into one contiguous byte stream.
        let mut encoded = BytesMut::new();
        let mut encoder = MessageEncoder::default();
        encoder
            .encode(EncoderMessage::Array(expected))
            .into_iter()
            .for_each(|chunk| encoded.extend_from_slice(chunk.as_ref()));

        let mut decoder = MessageDecoder::new(ArraySession::default().registry().clone());

        // All bytes are available up-front, so the decoder should never answer `NeedMore`.
        let mut input = BytesMut::from(encoded.as_ref());
        let decoded = decoder.read_next(&mut input).unwrap();
        let (parts, ctx, row_count) = match decoded {
            PollRead::Some(DecoderMessage::Array(payload)) => payload,
            unexpected => vortex_panic!("Expected an array, got {:?}", unexpected),
        };

        // Finish decoding with the context from the message and the dtype of the original.
        let actual = parts.decode(&ctx, expected.dtype(), row_count).unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding_id(), actual.encoding_id());
    }

    #[test]
    fn array_ipc() {
        let array = buffer![0i32, 1, 2, 3].into_array();
        write_and_read(&array);
    }

    #[test]
    fn array_no_buffers() {
        // Despite this test's name, the array under test carries exactly one buffer: constant
        // arrays have a single buffer, which the assertion below pins down.
        let constant = ConstantArray::new(10i32, 20);
        assert_eq!(constant.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(constant.as_ref());
    }
}