vortex_ipc/messages/decoder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use bytes::Buf;
7use flatbuffers::{root, root_unchecked};
8use vortex_array::serde::ArrayParts;
9use vortex_array::{ArrayContext, ArrayRegistry};
10use vortex_buffer::{AlignedBuf, Alignment, ByteBuffer};
11use vortex_dtype::DType;
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
13use vortex_flatbuffers::message::{MessageHeader, MessageVersion};
14use vortex_flatbuffers::{FlatBuffer, dtype as fbd, message as fb};
15
16/// A message decoded from an IPC stream.
17///
18/// Note that the `Array` variant cannot fully decode into an [`vortex_array::ArrayRef`] without
19/// a [`vortex_array::ArrayContext`] and a [`DType`]. As such, we partially decode into an
20/// [`ArrayParts`] and allow the caller to finish the decoding.
21#[derive(Debug)]
22pub enum DecoderMessage {
23    Array((ArrayParts, ArrayContext, usize)),
24    Buffer(ByteBuffer),
25    DType(DType),
26}
27
/// Where the [`MessageDecoder`] is within the current message.
///
/// The decoder moves `Length` -> `Header` -> `Reading` and back to `Length` once a full
/// message has been returned.
#[derive(Default)]
enum State {
    /// Waiting for the 4-byte little-endian length prefix of the next message header.
    #[default]
    Length,
    /// Length prefix consumed; waiting for a header of this many bytes.
    Header(usize),
    /// Header read and verified; waiting for the message body described by it.
    Reading(FlatBuffer),
}
35
/// The outcome of a single [`MessageDecoder::read_next`] call.
#[derive(Debug)]
pub enum PollRead {
    /// A complete message was decoded.
    Some(DecoderMessage),
    /// The _total_ number of bytes that must be available in the buffer to make progress.
    /// Note this is _not_ the incremental number of bytes needed to make progress.
    ///
    /// The count is measured from the buffer's current read position: bytes the decoder has
    /// already consumed (e.g. the length prefix) are not included.
    NeedMore(usize),
}
43
// NOTE(ngates): we should design some trait that the Decoder can take that doesn't require unique
//  ownership of the underlying bytes. The decoder needs to split out bytes, and advance a cursor,
//  but it doesn't need to mutate any bytes. So in theory, we should be able to do this zero-copy
//  over a shared buffer of bytes, instead of requiring a `BytesMut`.
//  (Currently `read_next` takes `&mut B` where `B: AlignedBuf`, and copies the header and body
//  out of it.)
/// A stateful reader for decoding IPC messages from an arbitrary stream of bytes.
///
/// The decoder remembers how far it has progressed through the current message, so
/// [`MessageDecoder::read_next`] can be called repeatedly as more bytes of the stream arrive.
pub struct MessageDecoder {
    /// Registry used to build an `ArrayContext` from the encodings listed in array messages.
    registry: ArrayRegistry,
    /// The current state of the decoder.
    state: State,
}
54
55impl MessageDecoder {
56    /// Create a new message decoder that can lookup encodings in the given registry.
57    pub fn new(registry: ArrayRegistry) -> Self {
58        Self {
59            registry,
60            state: State::default(),
61        }
62    }
63
64    /// Attempt to read the next message from the bytes object.
65    ///
66    /// If the message is incomplete, the function will return `NeedMore` with the _total_ number
67    /// of bytes needed to make progress. The next call to read_next _should_ provide at least
68    /// this number of bytes otherwise it will be given the same `NeedMore` response.
69    pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
70        loop {
71            match &self.state {
72                State::Length => {
73                    if bytes.remaining() < 4 {
74                        return Ok(PollRead::NeedMore(4));
75                    }
76
77                    let msg_length = bytes.get_u32_le();
78                    self.state = State::Header(msg_length as usize);
79                }
80                State::Header(msg_length) => {
81                    if bytes.remaining() < *msg_length {
82                        return Ok(PollRead::NeedMore(*msg_length));
83                    }
84
85                    let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
86                    let msg = root::<fb::Message>(msg_bytes.as_ref())?;
87                    if msg.version() != MessageVersion::V0 {
88                        vortex_bail!("Unsupported message version {:?}", msg.version());
89                    }
90
91                    self.state = State::Reading(msg_bytes);
92                }
93                State::Reading(msg_bytes) => {
94                    // SAFETY: we've already validated the header in the previous state
95                    let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
96
97                    // Now we read the body
98                    let body_length = usize::try_from(msg.body_size()).map_err(|_| {
99                        vortex_err!("body size {} is too large for usize", msg.body_size())
100                    })?;
101                    if bytes.remaining() < body_length {
102                        return Ok(PollRead::NeedMore(body_length));
103                    }
104
105                    match msg.header_type() {
106                        MessageHeader::ArrayMessage => {
107                            // We don't care about alignment here since ArrayParts will handle it.
108                            let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
109                            let parts = ArrayParts::try_from(body)?;
110
111                            let header = msg
112                                .header_as_array_message()
113                                .vortex_expect("header is array");
114
115                            let ctx = self
116                                .registry
117                                .new_context(header.encodings().iter().flat_map(|e| e.iter()))?;
118                            let row_count = header.row_count() as usize;
119
120                            self.state = Default::default();
121                            return Ok(PollRead::Some(DecoderMessage::Array((
122                                parts, ctx, row_count,
123                            ))));
124                        }
125                        MessageHeader::BufferMessage => {
126                            let body = bytes.copy_to_aligned(
127                                body_length,
128                                Alignment::from_exponent(
129                                    msg.header_as_buffer_message()
130                                        .vortex_expect("header is buffer")
131                                        .alignment_exponent(),
132                                ),
133                            );
134
135                            self.state = Default::default();
136                            return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
137                        }
138                        MessageHeader::DTypeMessage => {
139                            let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
140                            let fb_dtype = root::<fbd::DType>(body.as_ref())?;
141                            let dtype = DType::try_from_view(fb_dtype, body.clone())?;
142
143                            self.state = Default::default();
144                            return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
145                        }
146                        _ => {
147                            vortex_bail!("Unsupported message header {:?}", msg.header_type());
148                        }
149                    }
150                }
151            }
152        }
153    }
154}
155
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::{Array, IntoArray};
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::{EncoderMessage, MessageEncoder};

    /// Encodes `array` as a single IPC array message and returns the raw bytes.
    fn encode(array: &dyn Array) -> BytesMut {
        let mut ipc_bytes = BytesMut::new();
        let mut encoder = MessageEncoder::default();
        for buf in encoder.encode(EncoderMessage::Array(array)) {
            ipc_bytes.extend_from_slice(buf.as_ref());
        }
        ipc_bytes
    }

    /// Finishes decoding an array message and checks it against `expected`.
    fn assert_array_matches(expected: &dyn Array, msg: DecoderMessage) {
        let (array_parts, ctx, row_count) = match msg {
            DecoderMessage::Array(array_parts) => array_parts,
            otherwise => vortex_panic!("Expected an array, got {:?}", otherwise),
        };

        // Decode the array parts with the context
        let actual = array_parts
            .decode(&ctx, expected.dtype(), row_count)
            .unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding_id(), actual.encoding_id());
    }

    /// Round-trips `expected` through the encoder and decoder with all bytes available up-front.
    fn write_and_read(expected: &dyn Array) {
        let mut decoder = MessageDecoder::new(ArrayRegistry::canonical_only());

        // Since we provide all bytes up-front, we should never hit a NeedMore.
        let mut buffer = encode(expected);
        match decoder.read_next(&mut buffer).unwrap() {
            PollRead::Some(msg) => assert_array_matches(expected, msg),
            otherwise => vortex_panic!("Expected an array, got {:?}", otherwise),
        }
    }

    #[test]
    fn array_ipc() {
        write_and_read(&buffer![0i32, 1, 2, 3].into_array());
    }

    #[test]
    fn array_single_buffer() {
        // Constant arrays have a single buffer
        let array = ConstantArray::new(10i32, 20);
        assert_eq!(array.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(array.as_ref());
    }

    /// Feeds the decoder only as many bytes as each `NeedMore` asks for, exercising every
    /// intermediate state of the decoder.
    #[test]
    fn array_ipc_incremental() {
        let expected = buffer![0i32, 1, 2, 3].into_array();
        let ipc_bytes = encode(&expected);

        let mut decoder = MessageDecoder::new(ArrayRegistry::canonical_only());
        let mut buffer = BytesMut::new();
        let mut offset = 0;
        let msg = loop {
            match decoder.read_next(&mut buffer).unwrap() {
                PollRead::Some(msg) => break msg,
                PollRead::NeedMore(total) => {
                    // `NeedMore` reports the total bytes required in the buffer, not the
                    // increment, so only top up the difference.
                    assert!(
                        total > buffer.len(),
                        "NeedMore({total}) with {} bytes already buffered",
                        buffer.len()
                    );
                    let extra = total - buffer.len();
                    buffer.extend_from_slice(&ipc_bytes[offset..offset + extra]);
                    offset += extra;
                }
            }
        };

        assert!(buffer.is_empty(), "decoder should consume exactly one message");
        assert_array_matches(&expected, msg);
    }
}