vortex_ipc/messages/
decoder.rs

1use std::fmt::Debug;
2
3use bytes::Buf;
4use flatbuffers::{root, root_unchecked};
5use vortex_array::serde::ArrayParts;
6use vortex_array::{ArrayContext, ArrayRegistry};
7use vortex_buffer::{AlignedBuf, Alignment, ByteBuffer};
8use vortex_dtype::DType;
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_flatbuffers::message::{MessageHeader, MessageVersion};
11use vortex_flatbuffers::{FlatBuffer, dtype as fbd, message as fb};
12
13/// A message decoded from an IPC stream.
14///
15/// Note that the `Array` variant cannot fully decode into an [`vortex_array::ArrayRef`] without
16/// a [`vortex_array::ArrayContext`] and a [`DType`]. As such, we partially decode into an
17/// [`ArrayParts`] and allow the caller to finish the decoding.
18#[derive(Debug)]
19pub enum DecoderMessage {
20    Array((ArrayParts, ArrayContext, usize)),
21    Buffer(ByteBuffer),
22    DType(DType),
23}
24
25#[derive(Default)]
26enum State {
27    #[default]
28    Length,
29    Header(usize),
30    Reading(FlatBuffer),
31}
32
33#[derive(Debug)]
34pub enum PollRead {
35    Some(DecoderMessage),
36    /// Returns the _total_ number of bytes needed to make progress.
37    /// Note this is _not_ the incremental number of bytes needed to make progress.
38    NeedMore(usize),
39}
40
41// NOTE(ngates): we should design some trait that the Decoder can take that doesn't require unique
42//  ownership of the underlying bytes. The decoder needs to split out bytes, and advance a cursor,
43//  but it doesn't need to mutate any bytes. So in theory, we should be able to do this zero-copy
44//  over a shared buffer of bytes, instead of requiring a `BytesMut`.
45/// A stateful reader for decoding IPC messages from an arbitrary stream of bytes.
46pub struct MessageDecoder {
47    registry: ArrayRegistry,
48    /// The current state of the decoder.
49    state: State,
50}
51
52impl MessageDecoder {
53    /// Create a new message decoder that can lookup encodings in the given registry.
54    pub fn new(registry: ArrayRegistry) -> Self {
55        Self {
56            registry,
57            state: State::default(),
58        }
59    }
60
61    /// Attempt to read the next message from the bytes object.
62    ///
63    /// If the message is incomplete, the function will return `NeedMore` with the _total_ number
64    /// of bytes needed to make progress. The next call to read_next _should_ provide at least
65    /// this number of bytes otherwise it will be given the same `NeedMore` response.
66    pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
67        loop {
68            match &self.state {
69                State::Length => {
70                    if bytes.remaining() < 4 {
71                        return Ok(PollRead::NeedMore(4));
72                    }
73
74                    let msg_length = bytes.get_u32_le();
75                    self.state = State::Header(msg_length as usize);
76                }
77                State::Header(msg_length) => {
78                    if bytes.remaining() < *msg_length {
79                        return Ok(PollRead::NeedMore(*msg_length));
80                    }
81
82                    let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
83                    let msg = root::<fb::Message>(msg_bytes.as_ref())?;
84                    if msg.version() != MessageVersion::V0 {
85                        vortex_bail!("Unsupported message version {:?}", msg.version());
86                    }
87
88                    self.state = State::Reading(msg_bytes);
89                }
90                State::Reading(msg_bytes) => {
91                    // SAFETY: we've already validated the header in the previous state
92                    let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
93
94                    // Now we read the body
95                    let body_length = usize::try_from(msg.body_size()).map_err(|_| {
96                        vortex_err!("body size {} is too large for usize", msg.body_size())
97                    })?;
98                    if bytes.remaining() < body_length {
99                        return Ok(PollRead::NeedMore(body_length));
100                    }
101
102                    match msg.header_type() {
103                        MessageHeader::ArrayMessage => {
104                            // We don't care about alignment here since ArrayParts will handle it.
105                            let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
106                            let parts = ArrayParts::try_from(body)?;
107
108                            let header = msg
109                                .header_as_array_message()
110                                .vortex_expect("header is array");
111
112                            let ctx = self
113                                .registry
114                                .new_context(header.encodings().iter().flat_map(|e| e.iter()))?;
115                            let row_count = header.row_count() as usize;
116
117                            self.state = Default::default();
118                            return Ok(PollRead::Some(DecoderMessage::Array((
119                                parts, ctx, row_count,
120                            ))));
121                        }
122                        MessageHeader::BufferMessage => {
123                            let body = bytes.copy_to_aligned(
124                                body_length,
125                                Alignment::from_exponent(
126                                    msg.header_as_buffer_message()
127                                        .vortex_expect("header is buffer")
128                                        .alignment_exponent(),
129                                ),
130                            );
131
132                            self.state = Default::default();
133                            return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
134                        }
135                        MessageHeader::DTypeMessage => {
136                            let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
137                            let fb_dtype = root::<fbd::DType>(body.as_ref())?;
138                            let dtype = DType::try_from_view(fb_dtype, body.clone())?;
139
140                            self.state = Default::default();
141                            return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
142                        }
143                        _ => {
144                            vortex_bail!("Unsupported message header {:?}", msg.header_type());
145                        }
146                    }
147                }
148            }
149        }
150    }
151}
152
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::{Array, ArrayVisitor, IntoArray};
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::{EncoderMessage, MessageEncoder};

    /// Encodes `expected` as an IPC array message, decodes it again, and checks that the
    /// decoded array has the same length and encoding as the input.
    fn write_and_read(expected: &dyn Array) {
        let mut encoder = MessageEncoder::default();
        let mut encoded = BytesMut::new();
        for chunk in encoder.encode(EncoderMessage::Array(expected)) {
            encoded.extend_from_slice(chunk.as_ref());
        }

        let mut decoder = MessageDecoder::new(ArrayRegistry::canonical_only());

        // Every byte is available up-front, so the decoder must never ask for more.
        let mut input = BytesMut::from(encoded.as_ref());
        let polled = decoder.read_next(&mut input).unwrap();
        let (parts, ctx, row_count) = match polled {
            PollRead::Some(DecoderMessage::Array(decoded)) => decoded,
            otherwise => vortex_panic!("Expected an array, got {:?}", otherwise),
        };

        // Finish the decode using the context that came with the message.
        let actual = parts
            .decode(&ctx, expected.dtype().clone(), row_count)
            .unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding(), actual.encoding());
    }

    #[test]
    fn array_ipc() {
        let primitive = buffer![0i32, 1, 2, 3].into_array();
        write_and_read(&primitive);
    }

    #[test]
    fn array_no_buffers() {
        // Constant arrays have a single buffer
        let constant = ConstantArray::new(10i32, 20);
        assert_eq!(constant.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(&constant);
    }
}