vortex_ipc/messages/
decoder.rs1use std::fmt::Debug;
5
6use bytes::Buf;
7use flatbuffers::{root, root_unchecked};
8use itertools::Itertools;
9use vortex_array::serde::ArrayParts;
10use vortex_array::{ArrayContext, ArrayRegistry};
11use vortex_buffer::{AlignedBuf, Alignment, ByteBuffer};
12use vortex_dtype::DType;
13use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_flatbuffers::message::{MessageHeader, MessageVersion};
15use vortex_flatbuffers::{FlatBuffer, dtype as fbd, message as fb};
16
17#[derive(Debug)]
23pub enum DecoderMessage {
24 Array((ArrayParts, ArrayContext, usize)),
25 Buffer(ByteBuffer),
26 DType(DType),
27}
28
29#[derive(Default)]
30enum State {
31 #[default]
32 Length,
33 Header(usize),
34 Reading(FlatBuffer),
35}
36
37#[derive(Debug)]
38pub enum PollRead {
39 Some(DecoderMessage),
40 NeedMore(usize),
43}
44
45pub struct MessageDecoder {
51 registry: ArrayRegistry,
52 state: State,
54}
55
56impl MessageDecoder {
57 pub fn new(registry: ArrayRegistry) -> Self {
59 Self {
60 registry,
61 state: State::default(),
62 }
63 }
64
65 pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
71 loop {
72 match &self.state {
73 State::Length => {
74 if bytes.remaining() < 4 {
75 return Ok(PollRead::NeedMore(4));
76 }
77
78 let msg_length = bytes.get_u32_le();
79 self.state = State::Header(msg_length as usize);
80 }
81 State::Header(msg_length) => {
82 if bytes.remaining() < *msg_length {
83 return Ok(PollRead::NeedMore(*msg_length));
84 }
85
86 let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
87 let msg = root::<fb::Message>(msg_bytes.as_ref())?;
88 if msg.version() != MessageVersion::V0 {
89 vortex_bail!("Unsupported message version {:?}", msg.version());
90 }
91
92 self.state = State::Reading(msg_bytes);
93 }
94 State::Reading(msg_bytes) => {
95 let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
97
98 let body_length = usize::try_from(msg.body_size()).map_err(|_| {
100 vortex_err!("body size {} is too large for usize", msg.body_size())
101 })?;
102 if bytes.remaining() < body_length {
103 return Ok(PollRead::NeedMore(body_length));
104 }
105
106 match msg.header_type() {
107 MessageHeader::ArrayMessage => {
108 let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
110 let parts = ArrayParts::try_from(body)?;
111
112 let header = msg
113 .header_as_array_message()
114 .vortex_expect("header is array");
115
116 let encodings: Vec<_> = header
117 .encodings()
118 .iter()
119 .flat_map(|e| e.iter())
120 .map(|id| {
121 self.registry.find(id).ok_or_else(|| {
122 vortex_err!("unknown array encoding id: {}", id)
123 })
124 })
125 .try_collect()?;
126 let ctx = ArrayContext::new(encodings);
127 let row_count = header.row_count() as usize;
128
129 self.state = Default::default();
130 return Ok(PollRead::Some(DecoderMessage::Array((
131 parts, ctx, row_count,
132 ))));
133 }
134 MessageHeader::BufferMessage => {
135 let body = bytes.copy_to_aligned(
136 body_length,
137 Alignment::from_exponent(
138 msg.header_as_buffer_message()
139 .vortex_expect("header is buffer")
140 .alignment_exponent(),
141 ),
142 );
143
144 self.state = Default::default();
145 return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
146 }
147 MessageHeader::DTypeMessage => {
148 let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
149 let fb_dtype = root::<fbd::DType>(body.as_ref())?;
150 let dtype = DType::try_from_view(fb_dtype, body.clone())?;
151
152 self.state = Default::default();
153 return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
154 }
155 _ => {
156 vortex_bail!("Unsupported message header {:?}", msg.header_type());
157 }
158 }
159 }
160 }
161 }
162 }
163}
164
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::{Array, ArraySession, IntoArray};
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::{EncoderMessage, MessageEncoder};

    /// Encodes `expected` as an array message, decodes it again with a fresh
    /// [`MessageDecoder`], and checks that the length and encoding id survive the round trip.
    fn write_and_read(expected: &dyn Array) {
        let mut encoder = MessageEncoder::default();
        let encoded = encoder
            .encode(EncoderMessage::Array(expected))
            .into_iter()
            .fold(BytesMut::new(), |mut acc, chunk| {
                acc.extend_from_slice(chunk.as_ref());
                acc
            });

        let mut decoder = MessageDecoder::new(ArraySession::default().registry().clone());

        // Decode from an independent copy of the encoded bytes.
        let mut input = BytesMut::from(encoded.as_ref());
        let polled = decoder.read_next(&mut input).unwrap();
        let (parts, ctx, row_count) = match polled {
            PollRead::Some(DecoderMessage::Array(decoded)) => decoded,
            other => vortex_panic!("Expected an array, got {:?}", other),
        };

        let actual = parts.decode(&ctx, expected.dtype(), row_count).unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding_id(), actual.encoding_id());
    }

    #[test]
    fn array_ipc() {
        let primitive = buffer![0i32, 1, 2, 3].into_array();
        write_and_read(&primitive);
    }

    #[test]
    fn array_no_buffers() {
        // NOTE(review): the test name says "no buffers", but the assertion below expects
        // exactly one buffer. Confirm which is intended.
        let constant = ConstantArray::new(10i32, 20);
        assert_eq!(constant.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(constant.as_ref());
    }
}