vortex_ipc/messages/
decoder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use bytes::Buf;
7use flatbuffers::{root, root_unchecked};
8use itertools::Itertools;
9use vortex_array::serde::ArrayParts;
10use vortex_array::{ArrayContext, ArrayRegistry};
11use vortex_buffer::{AlignedBuf, Alignment, ByteBuffer};
12use vortex_dtype::DType;
13use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_flatbuffers::message::{MessageHeader, MessageVersion};
15use vortex_flatbuffers::{FlatBuffer, dtype as fbd, message as fb};
16
/// A message decoded from an IPC stream.
///
/// Note that the `Array` variant cannot fully decode into an [`vortex_array::ArrayRef`] without
/// a [`DType`], which is not part of the array message. As such, we partially decode into an
/// [`ArrayParts`] (together with its [`ArrayContext`] and row count) and allow the caller to
/// finish the decoding, e.g. via `ArrayParts::decode`.
#[derive(Debug)]
pub enum DecoderMessage {
    /// A partially-decoded array: its serialized parts, the [`ArrayContext`] built from the
    /// encoding ids listed in the message header, and the row count taken from the header.
    Array((ArrayParts, ArrayContext, usize)),
    /// A raw buffer, copied out with the alignment requested by the message header.
    Buffer(ByteBuffer),
    /// A fully-decoded data type.
    DType(DType),
}
28
/// Position of the decoder within the message currently being read.
#[derive(Default)]
enum State {
    /// Waiting for the 4-byte little-endian `u32` that gives the header length.
    #[default]
    Length,
    /// Waiting for a flatbuffer message header of the given length in bytes.
    Header(usize),
    /// Holding the already-validated message header; waiting for its body.
    Reading(FlatBuffer),
}
36
/// The outcome of a single [`MessageDecoder::read_next`] call.
#[derive(Debug)]
pub enum PollRead {
    /// A complete message was decoded.
    Some(DecoderMessage),
    /// More input is required. Holds the _total_ number of bytes needed to make progress.
    /// Note this is _not_ the incremental number of bytes needed to make progress.
    NeedMore(usize),
}
44
// NOTE(ngates): we should design some trait that the Decoder can take that doesn't require unique
//  ownership of the underlying bytes. The decoder needs to split out bytes, and advance a cursor,
//  but it doesn't need to mutate any bytes. So in theory, we should be able to do this zero-copy
//  over a shared buffer of bytes, instead of requiring unique (`&mut`) access to the buffer
//  (currently `&mut B` where `B: AlignedBuf`).
/// A stateful reader for decoding IPC messages from an arbitrary stream of bytes.
///
/// Each message is framed as a little-endian `u32` header length, a flatbuffer header of that
/// length, and a body whose size is recorded in the header. The decoder remembers how far
/// through the current message it has got, so input may be supplied incrementally across
/// calls to [`MessageDecoder::read_next`].
pub struct MessageDecoder {
    /// Registry used to resolve the array encoding ids listed in array message headers.
    registry: ArrayRegistry,
    /// The current state of the decoder.
    state: State,
}
55
56impl MessageDecoder {
57    /// Create a new message decoder that can lookup encodings in the given registry.
58    pub fn new(registry: ArrayRegistry) -> Self {
59        Self {
60            registry,
61            state: State::default(),
62        }
63    }
64
65    /// Attempt to read the next message from the bytes object.
66    ///
67    /// If the message is incomplete, the function will return `NeedMore` with the _total_ number
68    /// of bytes needed to make progress. The next call to read_next _should_ provide at least
69    /// this number of bytes otherwise it will be given the same `NeedMore` response.
70    pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
71        loop {
72            match &self.state {
73                State::Length => {
74                    if bytes.remaining() < 4 {
75                        return Ok(PollRead::NeedMore(4));
76                    }
77
78                    let msg_length = bytes.get_u32_le();
79                    self.state = State::Header(msg_length as usize);
80                }
81                State::Header(msg_length) => {
82                    if bytes.remaining() < *msg_length {
83                        return Ok(PollRead::NeedMore(*msg_length));
84                    }
85
86                    let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
87                    let msg = root::<fb::Message>(msg_bytes.as_ref())?;
88                    if msg.version() != MessageVersion::V0 {
89                        vortex_bail!("Unsupported message version {:?}", msg.version());
90                    }
91
92                    self.state = State::Reading(msg_bytes);
93                }
94                State::Reading(msg_bytes) => {
95                    // SAFETY: we've already validated the header in the previous state
96                    let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
97
98                    // Now we read the body
99                    let body_length = usize::try_from(msg.body_size()).map_err(|_| {
100                        vortex_err!("body size {} is too large for usize", msg.body_size())
101                    })?;
102                    if bytes.remaining() < body_length {
103                        return Ok(PollRead::NeedMore(body_length));
104                    }
105
106                    match msg.header_type() {
107                        MessageHeader::ArrayMessage => {
108                            // We don't care about alignment here since ArrayParts will handle it.
109                            let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
110                            let parts = ArrayParts::try_from(body)?;
111
112                            let header = msg
113                                .header_as_array_message()
114                                .vortex_expect("header is array");
115
116                            let encodings: Vec<_> = header
117                                .encodings()
118                                .iter()
119                                .flat_map(|e| e.iter())
120                                .map(|id| {
121                                    self.registry.find(id).ok_or_else(|| {
122                                        vortex_err!("unknown array encoding id: {}", id)
123                                    })
124                                })
125                                .try_collect()?;
126                            let ctx = ArrayContext::new(encodings);
127                            let row_count = header.row_count() as usize;
128
129                            self.state = Default::default();
130                            return Ok(PollRead::Some(DecoderMessage::Array((
131                                parts, ctx, row_count,
132                            ))));
133                        }
134                        MessageHeader::BufferMessage => {
135                            let body = bytes.copy_to_aligned(
136                                body_length,
137                                Alignment::from_exponent(
138                                    msg.header_as_buffer_message()
139                                        .vortex_expect("header is buffer")
140                                        .alignment_exponent(),
141                                ),
142                            );
143
144                            self.state = Default::default();
145                            return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
146                        }
147                        MessageHeader::DTypeMessage => {
148                            let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
149                            let fb_dtype = root::<fbd::DType>(body.as_ref())?;
150                            let dtype = DType::try_from_view(fb_dtype, body.clone())?;
151
152                            self.state = Default::default();
153                            return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
154                        }
155                        _ => {
156                            vortex_bail!("Unsupported message header {:?}", msg.header_type());
157                        }
158                    }
159                }
160            }
161        }
162    }
163}
164
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::{Array, ArraySession, IntoArray};
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::{EncoderMessage, MessageEncoder};

    /// Encode `expected` as a single IPC array message, decode it again, and check that the
    /// round-tripped array has the same length and encoding.
    fn write_and_read(expected: &dyn Array) {
        let mut ipc_bytes = BytesMut::new();
        let mut encoder = MessageEncoder::default();
        for buf in encoder.encode(EncoderMessage::Array(expected)) {
            ipc_bytes.extend_from_slice(buf.as_ref());
        }

        let registry = ArraySession::default().registry().clone();
        let mut decoder = MessageDecoder::new(registry);

        // Since we provide all bytes up-front, we should never hit a NeedMore.
        let (array_parts, ctx, row_count) = match decoder.read_next(&mut ipc_bytes).unwrap() {
            PollRead::Some(DecoderMessage::Array(array_parts)) => array_parts,
            otherwise => vortex_panic!("Expected an array, got {:?}", otherwise),
        };

        // Decode the array parts with the context
        let actual = array_parts
            .decode(&ctx, expected.dtype(), row_count)
            .unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding_id(), actual.encoding_id());
    }

    #[test]
    fn array_ipc() {
        write_and_read(&buffer![0i32, 1, 2, 3].into_array());
    }

    #[test]
    fn array_no_buffers() {
        // Despite the test name, a constant array is expected to carry exactly one buffer.
        // The assertion below pins that so the round trip covers a non-primitive array.
        let array = ConstantArray::new(10i32, 20);
        assert_eq!(array.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(array.as_ref());
    }
}