Skip to main content

vortex_ipc/messages/
decoder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::sync::Arc;
6
7use bytes::Buf;
8use flatbuffers::root;
9use flatbuffers::root_unchecked;
10use vortex_array::ArrayContext;
11use vortex_array::serde::ArrayParts;
12use vortex_array::vtable::ArrayId;
13use vortex_buffer::AlignedBuf;
14use vortex_buffer::Alignment;
15use vortex_buffer::ByteBuffer;
16use vortex_error::VortexExpect;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20use vortex_flatbuffers::FlatBuffer;
21use vortex_flatbuffers::message as fb;
22use vortex_flatbuffers::message::MessageHeader;
23use vortex_flatbuffers::message::MessageVersion;
24
25/// A message decoded from an IPC stream.
26#[derive(Debug)]
27pub enum DecoderMessage {
28    Array((ArrayParts, ArrayContext, usize)),
29    Buffer(ByteBuffer),
30    DType(FlatBuffer),
31}
32
33#[derive(Default)]
34enum State {
35    #[default]
36    Length,
37    Header(usize),
38    Reading(FlatBuffer),
39}
40
41#[derive(Debug)]
42pub enum PollRead {
43    /// A complete message was decoded.
44    Some(DecoderMessage),
45    /// The decoder needs more data to make progress.
46    ///
47    /// The inner value is the **total*k number of bytes the buffer should contain, not the
48    /// incremental amount needed. Callers should:
49    ///
50    /// 1. Resize the buffer to this length.
51    /// 2. Fill the buffer completely (handling partial reads as needed).
52    /// 3. Only then call [`MessageDecoder::read_next`] again.
53    ///
54    /// The decoder checks [`bytes::Buf::remaining`] to determine available data, which for
55    /// [`bytes::BytesMut`] returns the buffer length regardless of how many bytes were actually
56    /// written. Calling `read_next` before the buffer is fully populated will cause the decoder
57    /// to read garbage data.
58    NeedMore(usize),
59}
60
61// NOTE(ngates): we should design some trait that the Decoder can take that doesn't require unique
62//  ownership of the underlying bytes. The decoder needs to split out bytes, and advance a cursor,
63//  but it doesn't need to mutate any bytes. So in theory, we should be able to do this zero-copy
64//  over a shared buffer of bytes, instead of requiring a `BytesMut`.
65/// A stateful reader for decoding IPC messages from an arbitrary stream of bytes.
66#[derive(Default)]
67pub struct MessageDecoder {
68    /// The current state of the decoder.
69    state: State,
70}
71
72impl MessageDecoder {
73    /// Attempt to read the next message from the bytes object.
74    ///
75    /// If the message is incomplete, the function will return `NeedMore` with the _total_ number
76    /// of bytes needed to make progress. The next call to read_next _should_ provide at least
77    /// this number of bytes otherwise it will be given the same `NeedMore` response.
78    pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
79        loop {
80            match &self.state {
81                State::Length => {
82                    if bytes.remaining() < 4 {
83                        return Ok(PollRead::NeedMore(4));
84                    }
85
86                    let msg_length = bytes.get_u32_le();
87                    self.state = State::Header(msg_length as usize);
88                }
89                State::Header(msg_length) => {
90                    if bytes.remaining() < *msg_length {
91                        return Ok(PollRead::NeedMore(*msg_length));
92                    }
93
94                    let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
95                    let msg = root::<fb::Message>(msg_bytes.as_ref())?;
96                    if msg.version() != MessageVersion::V0 {
97                        vortex_bail!("Unsupported message version {:?}", msg.version());
98                    }
99
100                    self.state = State::Reading(msg_bytes);
101                }
102                State::Reading(msg_bytes) => {
103                    // SAFETY: we've already validated the header in the previous state
104                    let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
105
106                    // Now we read the body
107                    let body_length = usize::try_from(msg.body_size()).map_err(|_| {
108                        vortex_err!("body size {} is too large for usize", msg.body_size())
109                    })?;
110                    if bytes.remaining() < body_length {
111                        return Ok(PollRead::NeedMore(body_length));
112                    }
113
114                    match msg.header_type() {
115                        MessageHeader::ArrayMessage => {
116                            // We don't care about alignment here since ArrayParts will handle it.
117                            let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
118                            let parts = ArrayParts::try_from(body)?;
119
120                            let header = msg
121                                .header_as_array_message()
122                                .vortex_expect("header is array");
123
124                            let encoding_ids: Vec<_> = header
125                                .encodings()
126                                .iter()
127                                .flat_map(|e| e.iter())
128                                .map(|id| ArrayId::new_arc(Arc::from(id.to_string())))
129                                .collect();
130
131                            let ctx = ArrayContext::new(encoding_ids);
132                            let row_count = header.row_count() as usize;
133
134                            self.state = Default::default();
135                            return Ok(PollRead::Some(DecoderMessage::Array((
136                                parts, ctx, row_count,
137                            ))));
138                        }
139                        MessageHeader::BufferMessage => {
140                            let body = bytes.copy_to_aligned(
141                                body_length,
142                                Alignment::from_exponent(
143                                    msg.header_as_buffer_message()
144                                        .vortex_expect("header is buffer")
145                                        .alignment_exponent(),
146                                ),
147                            );
148
149                            self.state = Default::default();
150                            return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
151                        }
152                        MessageHeader::DTypeMessage => {
153                            let dtype: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
154                            self.state = Default::default();
155                            return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
156                        }
157                        _ => {
158                            vortex_bail!("Unsupported message header {:?}", msg.header_type());
159                        }
160                    }
161                }
162            }
163        }
164    }
165}
166
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::EncoderMessage;
    use crate::messages::MessageEncoder;
    use crate::test::SESSION;

    /// Encodes `expected` as an IPC array message, decodes it again, and checks that the decoded
    /// array has the same length and encoding id as the input.
    fn write_and_read(expected: &ArrayRef) {
        // Serialize the array into one contiguous run of IPC bytes.
        let mut encoder = MessageEncoder::default();
        let mut ipc_bytes = BytesMut::new();
        encoder
            .encode(EncoderMessage::Array(expected))
            .unwrap()
            .into_iter()
            .for_each(|chunk| ipc_bytes.extend_from_slice(chunk.as_ref()));

        // The whole message is available up-front, so a single poll must yield the array
        // rather than a NeedMore.
        let mut input = BytesMut::from(ipc_bytes.as_ref());
        let mut decoder = MessageDecoder::default();
        let (array_parts, ctx, row_count) = match decoder.read_next(&mut input).unwrap() {
            PollRead::Some(DecoderMessage::Array((parts, ctx, rows))) => (parts, ctx, rows),
            otherwise => vortex_panic!("Expected an array, got {:?}", otherwise),
        };

        // Rebuild the array from its parts using the decoded context.
        let actual = array_parts
            .decode(expected.dtype(), row_count, &ctx, &SESSION)
            .unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding_id(), actual.encoding_id());
    }

    #[test]
    fn array_ipc() {
        let array = buffer![0i32, 1, 2, 3].into_array();
        write_and_read(&array);
    }

    #[test]
    fn array_no_buffers() {
        // Constant arrays have a single buffer
        let array = ConstantArray::new(10i32, 20);
        assert_eq!(array.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(&array.to_array());
    }
}
219}