vortex_ipc/messages/
decoder.rs1use std::fmt::Debug;
5
6use bytes::Buf;
7use flatbuffers::{root, root_unchecked};
8use vortex_array::serde::ArrayParts;
9use vortex_array::{ArrayContext, ArrayRegistry};
10use vortex_buffer::{AlignedBuf, Alignment, ByteBuffer};
11use vortex_dtype::DType;
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
13use vortex_flatbuffers::message::{MessageHeader, MessageVersion};
14use vortex_flatbuffers::{FlatBuffer, dtype as fbd, message as fb};
15
16#[derive(Debug)]
22pub enum DecoderMessage {
23 Array((ArrayParts, ArrayContext, usize)),
24 Buffer(ByteBuffer),
25 DType(DType),
26}
27
28#[derive(Default)]
29enum State {
30 #[default]
31 Length,
32 Header(usize),
33 Reading(FlatBuffer),
34}
35
36#[derive(Debug)]
37pub enum PollRead {
38 Some(DecoderMessage),
39 NeedMore(usize),
42}
43
44pub struct MessageDecoder {
50 registry: ArrayRegistry,
51 state: State,
53}
54
55impl MessageDecoder {
56 pub fn new(registry: ArrayRegistry) -> Self {
58 Self {
59 registry,
60 state: State::default(),
61 }
62 }
63
64 pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
70 loop {
71 match &self.state {
72 State::Length => {
73 if bytes.remaining() < 4 {
74 return Ok(PollRead::NeedMore(4));
75 }
76
77 let msg_length = bytes.get_u32_le();
78 self.state = State::Header(msg_length as usize);
79 }
80 State::Header(msg_length) => {
81 if bytes.remaining() < *msg_length {
82 return Ok(PollRead::NeedMore(*msg_length));
83 }
84
85 let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
86 let msg = root::<fb::Message>(msg_bytes.as_ref())?;
87 if msg.version() != MessageVersion::V0 {
88 vortex_bail!("Unsupported message version {:?}", msg.version());
89 }
90
91 self.state = State::Reading(msg_bytes);
92 }
93 State::Reading(msg_bytes) => {
94 let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
96
97 let body_length = usize::try_from(msg.body_size()).map_err(|_| {
99 vortex_err!("body size {} is too large for usize", msg.body_size())
100 })?;
101 if bytes.remaining() < body_length {
102 return Ok(PollRead::NeedMore(body_length));
103 }
104
105 match msg.header_type() {
106 MessageHeader::ArrayMessage => {
107 let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
109 let parts = ArrayParts::try_from(body)?;
110
111 let header = msg
112 .header_as_array_message()
113 .vortex_expect("header is array");
114
115 let ctx = self
116 .registry
117 .new_context(header.encodings().iter().flat_map(|e| e.iter()))?;
118 let row_count = header.row_count() as usize;
119
120 self.state = Default::default();
121 return Ok(PollRead::Some(DecoderMessage::Array((
122 parts, ctx, row_count,
123 ))));
124 }
125 MessageHeader::BufferMessage => {
126 let body = bytes.copy_to_aligned(
127 body_length,
128 Alignment::from_exponent(
129 msg.header_as_buffer_message()
130 .vortex_expect("header is buffer")
131 .alignment_exponent(),
132 ),
133 );
134
135 self.state = Default::default();
136 return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
137 }
138 MessageHeader::DTypeMessage => {
139 let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
140 let fb_dtype = root::<fbd::DType>(body.as_ref())?;
141 let dtype = DType::try_from_view(fb_dtype, body.clone())?;
142
143 self.state = Default::default();
144 return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
145 }
146 _ => {
147 vortex_bail!("Unsupported message header {:?}", msg.header_type());
148 }
149 }
150 }
151 }
152 }
153 }
154}
155
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::{Array, IntoArray};
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::{EncoderMessage, MessageEncoder};

    /// Encodes `expected` as an IPC array message, decodes it back, and checks that the
    /// round-tripped array has the same length and encoding as the input.
    fn write_and_read(expected: &dyn Array) {
        let mut encoder = MessageEncoder::default();
        let mut wire = BytesMut::new();
        for chunk in encoder.encode(EncoderMessage::Array(expected)) {
            wire.extend_from_slice(chunk.as_ref());
        }

        let mut decoder = MessageDecoder::new(ArrayRegistry::canonical_only());
        let decoded = decoder.read_next(&mut wire).unwrap();
        let (parts, ctx, row_count) = match decoded {
            PollRead::Some(DecoderMessage::Array(triple)) => triple,
            other => vortex_panic!("Expected an array, got {:?}", other),
        };

        let actual = parts.decode(&ctx, expected.dtype(), row_count).unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding_id(), actual.encoding_id());
    }

    #[test]
    fn array_ipc() {
        let array = buffer![0i32, 1, 2, 3].into_array();
        write_and_read(&array);
    }

    /// Round-trips a `ConstantArray`. The assertion pins that this array carries exactly
    /// one buffer.
    #[test]
    fn array_no_buffers() {
        let array = ConstantArray::new(10i32, 20);
        assert_eq!(array.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(array.as_ref());
    }
}