vortex_ipc/messages/
decoder.rs1use std::fmt::Debug;
5use std::sync::Arc;
6
7use bytes::Buf;
8use flatbuffers::root;
9use flatbuffers::root_unchecked;
10use vortex_array::ArrayContext;
11use vortex_array::serde::ArrayParts;
12use vortex_array::vtable::ArrayId;
13use vortex_buffer::AlignedBuf;
14use vortex_buffer::Alignment;
15use vortex_buffer::ByteBuffer;
16use vortex_error::VortexExpect;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20use vortex_flatbuffers::FlatBuffer;
21use vortex_flatbuffers::message as fb;
22use vortex_flatbuffers::message::MessageHeader;
23use vortex_flatbuffers::message::MessageVersion;
24
/// A message decoded from an IPC stream by [`MessageDecoder::read_next`].
#[derive(Debug)]
pub enum DecoderMessage {
    /// An array message, only partially decoded. It holds three values:
    /// - the [`ArrayParts`] parsed from the message body;
    /// - an [`ArrayContext`] built from the encoding ids listed in the message header;
    /// - the array's row count.
    ///
    /// Producing the final array additionally requires the array's dtype (e.g. via
    /// `ArrayParts::decode`).
    Array((ArrayParts, ArrayContext, usize)),
    /// A raw buffer body, copied with the alignment recorded in the buffer message header.
    Buffer(ByteBuffer),
    /// The body of a dtype message, copied into an 8-byte-aligned [`FlatBuffer`].
    DType(FlatBuffer),
}
32
/// Progress of [`MessageDecoder`] through the message it is currently decoding.
#[derive(Default)]
enum State {
    /// Waiting for the 4-byte little-endian length prefix of the next message header.
    #[default]
    Length,
    /// The header length is known; waiting for that many header bytes.
    Header(usize),
    /// The header has been read and validated; waiting for the message body.
    Reading(FlatBuffer),
}
40
/// The outcome of a single [`MessageDecoder::read_next`] call.
#[derive(Debug)]
pub enum PollRead {
    /// A complete message was decoded.
    Some(DecoderMessage),
    /// Not enough bytes are buffered to make progress.
    ///
    /// The value is the _total_ number of bytes that must be remaining in the buffer for the
    /// decoder's current step (length prefix, header, or body). It is _not_ the number of
    /// additional bytes required.
    NeedMore(usize),
}
60
61#[derive(Default)]
67pub struct MessageDecoder {
68 state: State,
70}
71
72impl MessageDecoder {
73 pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
79 loop {
80 match &self.state {
81 State::Length => {
82 if bytes.remaining() < 4 {
83 return Ok(PollRead::NeedMore(4));
84 }
85
86 let msg_length = bytes.get_u32_le();
87 self.state = State::Header(msg_length as usize);
88 }
89 State::Header(msg_length) => {
90 if bytes.remaining() < *msg_length {
91 return Ok(PollRead::NeedMore(*msg_length));
92 }
93
94 let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
95 let msg = root::<fb::Message>(msg_bytes.as_ref())?;
96 if msg.version() != MessageVersion::V0 {
97 vortex_bail!("Unsupported message version {:?}", msg.version());
98 }
99
100 self.state = State::Reading(msg_bytes);
101 }
102 State::Reading(msg_bytes) => {
103 let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
105
106 let body_length = usize::try_from(msg.body_size()).map_err(|_| {
108 vortex_err!("body size {} is too large for usize", msg.body_size())
109 })?;
110 if bytes.remaining() < body_length {
111 return Ok(PollRead::NeedMore(body_length));
112 }
113
114 match msg.header_type() {
115 MessageHeader::ArrayMessage => {
116 let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
118 let parts = ArrayParts::try_from(body)?;
119
120 let header = msg
121 .header_as_array_message()
122 .vortex_expect("header is array");
123
124 let encoding_ids: Vec<_> = header
125 .encodings()
126 .iter()
127 .flat_map(|e| e.iter())
128 .map(|id| ArrayId::new_arc(Arc::from(id.to_string())))
129 .collect();
130
131 let ctx = ArrayContext::new(encoding_ids);
132 let row_count = header.row_count() as usize;
133
134 self.state = Default::default();
135 return Ok(PollRead::Some(DecoderMessage::Array((
136 parts, ctx, row_count,
137 ))));
138 }
139 MessageHeader::BufferMessage => {
140 let body = bytes.copy_to_aligned(
141 body_length,
142 Alignment::from_exponent(
143 msg.header_as_buffer_message()
144 .vortex_expect("header is buffer")
145 .alignment_exponent(),
146 ),
147 );
148
149 self.state = Default::default();
150 return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
151 }
152 MessageHeader::DTypeMessage => {
153 let dtype: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
154 self.state = Default::default();
155 return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
156 }
157 _ => {
158 vortex_bail!("Unsupported message header {:?}", msg.header_type());
159 }
160 }
161 }
162 }
163 }
164 }
165}
166
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::EncoderMessage;
    use crate::messages::MessageEncoder;
    use crate::test::SESSION;

    /// Encodes `expected` as an IPC array message and decodes it again. It then checks that the
    /// round-tripped array has the same length and encoding id as the input.
    fn write_and_read(expected: &ArrayRef) {
        let mut encoder = MessageEncoder::default();
        let encoded = encoder
            .encode(EncoderMessage::Array(expected))
            .unwrap()
            .into_iter()
            .fold(BytesMut::new(), |mut acc, chunk| {
                acc.extend_from_slice(chunk.as_ref());
                acc
            });

        // Decode from a fresh copy of the encoded bytes.
        let mut input = BytesMut::from(encoded.as_ref());
        let mut decoder = MessageDecoder::default();
        let polled = decoder.read_next(&mut input).unwrap();

        let (parts, ctx, row_count) = match polled {
            PollRead::Some(DecoderMessage::Array(triple)) => triple,
            other => vortex_panic!("Expected an array, got {:?}", other),
        };

        let decoded = parts
            .decode(expected.dtype(), row_count, &ctx, &SESSION)
            .unwrap();

        assert_eq!(expected.len(), decoded.len());
        assert_eq!(expected.encoding_id(), decoded.encoding_id());
    }

    #[test]
    fn array_ipc() {
        let array = buffer![0i32, 1, 2, 3].into_array();
        write_and_read(&array);
    }

    // NOTE(review): despite its name, this test asserts the array has exactly one buffer —
    // consider renaming.
    #[test]
    fn array_no_buffers() {
        let constant = ConstantArray::new(10i32, 20);
        assert_eq!(constant.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(&constant.to_array());
    }
}