// vortex_ipc/messages/decoder.rs
use std::fmt::Debug;
5
6use bytes::Buf;
7use flatbuffers::root;
8use flatbuffers::root_unchecked;
9use itertools::Itertools;
10use vortex_array::ArrayContext;
11use vortex_array::serde::ArrayParts;
12use vortex_array::session::ArrayRegistry;
13use vortex_buffer::AlignedBuf;
14use vortex_buffer::Alignment;
15use vortex_buffer::ByteBuffer;
16use vortex_dtype::DType;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_err;
21use vortex_flatbuffers::FlatBuffer;
22use vortex_flatbuffers::dtype as fbd;
23use vortex_flatbuffers::message as fb;
24use vortex_flatbuffers::message::MessageHeader;
25use vortex_flatbuffers::message::MessageVersion;
26
27#[derive(Debug)]
33pub enum DecoderMessage {
34 Array((ArrayParts, ArrayContext, usize)),
35 Buffer(ByteBuffer),
36 DType(DType),
37}
38
39#[derive(Default)]
40enum State {
41 #[default]
42 Length,
43 Header(usize),
44 Reading(FlatBuffer),
45}
46
47#[derive(Debug)]
48pub enum PollRead {
49 Some(DecoderMessage),
51 NeedMore(usize),
65}
66
67pub struct MessageDecoder {
73 registry: ArrayRegistry,
74 state: State,
76}
77
78impl MessageDecoder {
79 pub fn new(registry: ArrayRegistry) -> Self {
81 Self {
82 registry,
83 state: State::default(),
84 }
85 }
86
87 pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
93 loop {
94 match &self.state {
95 State::Length => {
96 if bytes.remaining() < 4 {
97 return Ok(PollRead::NeedMore(4));
98 }
99
100 let msg_length = bytes.get_u32_le();
101 self.state = State::Header(msg_length as usize);
102 }
103 State::Header(msg_length) => {
104 if bytes.remaining() < *msg_length {
105 return Ok(PollRead::NeedMore(*msg_length));
106 }
107
108 let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
109 let msg = root::<fb::Message>(msg_bytes.as_ref())?;
110 if msg.version() != MessageVersion::V0 {
111 vortex_bail!("Unsupported message version {:?}", msg.version());
112 }
113
114 self.state = State::Reading(msg_bytes);
115 }
116 State::Reading(msg_bytes) => {
117 let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
119
120 let body_length = usize::try_from(msg.body_size()).map_err(|_| {
122 vortex_err!("body size {} is too large for usize", msg.body_size())
123 })?;
124 if bytes.remaining() < body_length {
125 return Ok(PollRead::NeedMore(body_length));
126 }
127
128 match msg.header_type() {
129 MessageHeader::ArrayMessage => {
130 let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
132 let parts = ArrayParts::try_from(body)?;
133
134 let header = msg
135 .header_as_array_message()
136 .vortex_expect("header is array");
137
138 let encodings: Vec<_> = header
139 .encodings()
140 .iter()
141 .flat_map(|e| e.iter())
142 .map(|id| {
143 self.registry.find(id).ok_or_else(|| {
144 vortex_err!("unknown array encoding id: {}", id)
145 })
146 })
147 .try_collect()?;
148 let ctx = ArrayContext::new(encodings);
149 let row_count = header.row_count() as usize;
150
151 self.state = Default::default();
152 return Ok(PollRead::Some(DecoderMessage::Array((
153 parts, ctx, row_count,
154 ))));
155 }
156 MessageHeader::BufferMessage => {
157 let body = bytes.copy_to_aligned(
158 body_length,
159 Alignment::from_exponent(
160 msg.header_as_buffer_message()
161 .vortex_expect("header is buffer")
162 .alignment_exponent(),
163 ),
164 );
165
166 self.state = Default::default();
167 return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
168 }
169 MessageHeader::DTypeMessage => {
170 let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
171 let fb_dtype = root::<fbd::DType>(body.as_ref())?;
172 let dtype = DType::try_from_view(fb_dtype, body.clone())?;
173
174 self.state = Default::default();
175 return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
176 }
177 _ => {
178 vortex_bail!("Unsupported message header {:?}", msg.header_type());
179 }
180 }
181 }
182 }
183 }
184 }
185}
186
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::EncoderMessage;
    use crate::messages::MessageEncoder;

    /// Encodes `expected` as an IPC array message, decodes it again, and
    /// checks that the length and encoding id survive the round trip.
    fn write_and_read(expected: &dyn Array) {
        // Concatenate every chunk the encoder emits into one byte stream.
        let mut encoder = MessageEncoder::default();
        let mut ipc_bytes = BytesMut::new();
        for chunk in encoder.encode(EncoderMessage::Array(expected)) {
            ipc_bytes.extend_from_slice(chunk.as_ref());
        }

        let mut decoder = MessageDecoder::new(ArraySession::default().registry().clone());

        // The whole message is available up front, so one poll must yield it.
        let mut input = BytesMut::from(ipc_bytes.as_ref());
        let polled = decoder.read_next(&mut input).unwrap();
        let (parts, ctx, row_count) = match polled {
            PollRead::Some(DecoderMessage::Array(decoded)) => decoded,
            other => vortex_panic!("Expected an array, got {:?}", other),
        };

        let actual = parts.decode(&ctx, expected.dtype(), row_count).unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding_id(), actual.encoding_id());
    }

    /// Round-trips an array built from a buffer of four `i32` values.
    #[test]
    fn array_ipc() {
        let array = buffer![0i32, 1, 2, 3].into_array();
        write_and_read(&array);
    }

    /// Round-trips a constant array. Despite the test name, the array is
    /// asserted to report exactly one buffer before the round trip.
    #[test]
    fn array_no_buffers() {
        let constant = ConstantArray::new(10i32, 20);
        assert_eq!(constant.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(constant.as_ref());
    }
}