vortex_ipc/messages/
decoder.rs1use std::fmt::Debug;
2
3use bytes::Buf;
4use flatbuffers::{root, root_unchecked};
5use vortex_array::serde::ArrayParts;
6use vortex_array::{ArrayContext, ArrayRegistry};
7use vortex_buffer::{AlignedBuf, Alignment, ByteBuffer};
8use vortex_dtype::DType;
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_flatbuffers::message::{MessageHeader, MessageVersion};
11use vortex_flatbuffers::{FlatBuffer, dtype as fbd, message as fb};
12
13#[derive(Debug)]
19pub enum DecoderMessage {
20 Array((ArrayParts, ArrayContext, usize)),
21 Buffer(ByteBuffer),
22 DType(DType),
23}
24
25#[derive(Default)]
26enum State {
27 #[default]
28 Length,
29 Header(usize),
30 Reading(FlatBuffer),
31}
32
33#[derive(Debug)]
34pub enum PollRead {
35 Some(DecoderMessage),
36 NeedMore(usize),
39}
40
41pub struct MessageDecoder {
47 registry: ArrayRegistry,
48 state: State,
50}
51
52impl MessageDecoder {
53 pub fn new(registry: ArrayRegistry) -> Self {
55 Self {
56 registry,
57 state: State::default(),
58 }
59 }
60
61 pub fn read_next<B: AlignedBuf>(&mut self, bytes: &mut B) -> VortexResult<PollRead> {
67 loop {
68 match &self.state {
69 State::Length => {
70 if bytes.remaining() < 4 {
71 return Ok(PollRead::NeedMore(4));
72 }
73
74 let msg_length = bytes.get_u32_le();
75 self.state = State::Header(msg_length as usize);
76 }
77 State::Header(msg_length) => {
78 if bytes.remaining() < *msg_length {
79 return Ok(PollRead::NeedMore(*msg_length));
80 }
81
82 let msg_bytes = bytes.copy_to_const_aligned(*msg_length);
83 let msg = root::<fb::Message>(msg_bytes.as_ref())?;
84 if msg.version() != MessageVersion::V0 {
85 vortex_bail!("Unsupported message version {:?}", msg.version());
86 }
87
88 self.state = State::Reading(msg_bytes);
89 }
90 State::Reading(msg_bytes) => {
91 let msg = unsafe { root_unchecked::<fb::Message>(msg_bytes.as_ref()) };
93
94 let body_length = usize::try_from(msg.body_size()).map_err(|_| {
96 vortex_err!("body size {} is too large for usize", msg.body_size())
97 })?;
98 if bytes.remaining() < body_length {
99 return Ok(PollRead::NeedMore(body_length));
100 }
101
102 match msg.header_type() {
103 MessageHeader::ArrayMessage => {
104 let body = bytes.copy_to_aligned(body_length, Alignment::new(1));
106 let parts = ArrayParts::try_from(body)?;
107
108 let header = msg
109 .header_as_array_message()
110 .vortex_expect("header is array");
111
112 let ctx = self
113 .registry
114 .new_context(header.encodings().iter().flat_map(|e| e.iter()))?;
115 let row_count = header.row_count() as usize;
116
117 self.state = Default::default();
118 return Ok(PollRead::Some(DecoderMessage::Array((
119 parts, ctx, row_count,
120 ))));
121 }
122 MessageHeader::BufferMessage => {
123 let body = bytes.copy_to_aligned(
124 body_length,
125 Alignment::from_exponent(
126 msg.header_as_buffer_message()
127 .vortex_expect("header is buffer")
128 .alignment_exponent(),
129 ),
130 );
131
132 self.state = Default::default();
133 return Ok(PollRead::Some(DecoderMessage::Buffer(body)));
134 }
135 MessageHeader::DTypeMessage => {
136 let body: FlatBuffer = bytes.copy_to_const_aligned::<8>(body_length);
137 let fb_dtype = root::<fbd::DType>(body.as_ref())?;
138 let dtype = DType::try_from_view(fb_dtype, body.clone())?;
139
140 self.state = Default::default();
141 return Ok(PollRead::Some(DecoderMessage::DType(dtype)));
142 }
143 _ => {
144 vortex_bail!("Unsupported message header {:?}", msg.header_type());
145 }
146 }
147 }
148 }
149 }
150 }
151}
152
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::{Array, ArrayVisitor, IntoArray};
    use vortex_buffer::buffer;
    use vortex_error::vortex_panic;

    use super::*;
    use crate::messages::{EncoderMessage, MessageEncoder};

    /// Round-trips `expected` through the IPC array message encoder and [`MessageDecoder`].
    ///
    /// Decoding uses a canonical-only registry. The test then checks that the decoded array has
    /// the same length and encoding as the original.
    fn write_and_read(expected: &dyn Array) {
        let mut encoder = MessageEncoder::default();
        let mut ipc_bytes = BytesMut::new();
        for buf in encoder.encode(EncoderMessage::Array(expected)) {
            ipc_bytes.extend_from_slice(buf.as_ref());
        }

        let mut decoder = MessageDecoder::new(ArrayRegistry::canonical_only());

        // Decode directly from the encoded bytes; a second copy into a fresh `BytesMut` is
        // unnecessary because `ipc_bytes` already has that type.
        let (array_parts, ctx, row_count) = match decoder.read_next(&mut ipc_bytes).unwrap() {
            PollRead::Some(DecoderMessage::Array(array_parts)) => array_parts,
            otherwise => vortex_panic!("Expected an array, got {:?}", otherwise),
        };

        let actual = array_parts
            .decode(&ctx, expected.dtype().clone(), row_count)
            .unwrap();

        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected.encoding(), actual.encoding());
    }

    #[test]
    fn array_ipc() {
        write_and_read(&buffer![0i32, 1, 2, 3].into_array());
    }

    // Despite its name, this test exercises an array with exactly one buffer (the constant's
    // scalar), as the assertion below checks.
    #[test]
    fn array_no_buffers() {
        let array = ConstantArray::new(10i32, 20);
        assert_eq!(array.nbuffers(), 1, "Array should have a single buffer");
        write_and_read(&array);
    }
}
201}