1use std::fmt::{Debug, Formatter};
2use std::iter;
3use std::sync::Arc;
4
5use flatbuffers::{FlatBufferBuilder, Follow, WIPOffset, root};
6use vortex_buffer::{Alignment, ByteBuffer};
7use vortex_dtype::{DType, TryFromBytes};
8use vortex_error::{
9 VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic,
10};
11use vortex_flatbuffers::array::Compression;
12use vortex_flatbuffers::{
13 FlatBuffer, FlatBufferRoot, ReadFlatBuffer, WriteFlatBuffer, array as fba,
14};
15
16use crate::stats::StatsSet;
17use crate::{Array, ArrayContext, ArrayRef, ArrayVisitor, ArrayVisitorExt};
18
/// Options controlling how an array is serialized into a sequence of byte buffers.
#[derive(Default, Debug)]
pub struct SerializeOptions {
    /// Absolute byte position at which the caller will write the first serialized byte.
    /// Used to compute how much zero-padding is needed before each buffer so that it
    /// lands on its required alignment.
    pub offset: usize,
    /// When true, zero-filled padding buffers are inserted so every data buffer (and the
    /// trailing flatbuffer) starts at a position matching its alignment.
    pub include_padding: bool,
}
28
29impl dyn Array + '_ {
30 pub fn serialize(
41 &self,
42 ctx: &ArrayContext,
43 options: &SerializeOptions,
44 ) -> VortexResult<Vec<ByteBuffer>> {
45 let array_buffers = self
47 .depth_first_traversal()
48 .flat_map(|f| f.buffers())
49 .collect::<Vec<_>>();
50
51 let mut buffers = vec![];
53 let mut fb_buffers = Vec::with_capacity(buffers.capacity());
54
55 let max_alignment = array_buffers
57 .iter()
58 .map(|buf| buf.alignment())
59 .chain(iter::once(FlatBuffer::alignment()))
60 .max()
61 .unwrap_or_else(FlatBuffer::alignment);
62
63 let zeros = ByteBuffer::zeroed(*max_alignment);
65
66 buffers.push(ByteBuffer::zeroed_aligned(0, max_alignment));
69
70 let mut pos = options.offset;
72
73 for buffer in array_buffers {
75 let padding = if options.include_padding {
76 let padding = pos.next_multiple_of(*buffer.alignment()) - pos;
77 if padding > 0 {
78 pos += padding;
79 buffers.push(zeros.slice(0..padding));
80 }
81 padding
82 } else {
83 0
84 };
85
86 fb_buffers.push(fba::Buffer::new(
87 u16::try_from(padding).vortex_expect("padding fits into u16"),
88 buffer.alignment().exponent(),
89 Compression::None,
90 u32::try_from(buffer.len())
91 .map_err(|_| vortex_err!("All buffers must fit into u32 for serialization"))?,
92 ));
93
94 pos += buffer.len();
95 buffers.push(buffer.aligned(Alignment::none()));
96 }
97
98 let mut fbb = FlatBufferBuilder::new();
100 let root = ArrayNodeFlatBuffer::try_new(ctx, self)?;
101 let fb_root = root.write_flatbuffer(&mut fbb);
102 let fb_buffers = fbb.create_vector(&fb_buffers);
103 let fb_array = fba::Array::create(
104 &mut fbb,
105 &fba::ArrayArgs {
106 root: Some(fb_root),
107 buffers: Some(fb_buffers),
108 },
109 );
110 fbb.finish_minimal(fb_array);
111 let (fb_vec, fb_start) = fbb.collapse();
112 let fb_end = fb_vec.len();
113 let fb_buffer = ByteBuffer::from(fb_vec).slice(fb_start..fb_end);
114 let fb_length = fb_buffer.len();
115
116 if options.include_padding {
117 let padding = pos.next_multiple_of(*FlatBuffer::alignment()) - pos;
118 if padding > 0 {
119 buffers.push(zeros.slice(0..padding));
120 }
121 }
122 buffers.push(fb_buffer);
123
124 buffers.push(ByteBuffer::from(
126 u32::try_from(fb_length)
127 .map_err(|_| vortex_err!("Array metadata flatbuffer must fit into u32 for serialization. Array encoding tree is too large."))?
128 .to_le_bytes()
129 .to_vec(),
130 ));
131
132 Ok(buffers)
133 }
134}
135
/// Writer that serializes one array node of the tree into an `fba::ArrayNode` flatbuffer.
pub struct ArrayNodeFlatBuffer<'a> {
    // Used to translate each array's encoding into the index recorded in the flatbuffer.
    ctx: &'a ArrayContext,
    // The array node this writer serializes.
    array: &'a dyn Array,
    // Index of this node's first buffer within the depth-first flattened buffer list.
    buffer_idx: u16,
}
142
143impl<'a> ArrayNodeFlatBuffer<'a> {
144 pub fn try_new(ctx: &'a ArrayContext, array: &'a dyn Array) -> VortexResult<Self> {
145 for child in array.depth_first_traversal() {
147 if child.metadata()?.is_none() {
148 vortex_bail!(
149 "Array {} does not support serialization",
150 child.encoding_id()
151 );
152 }
153 }
154 Ok(Self {
155 ctx,
156 array,
157 buffer_idx: 0,
158 })
159 }
160}
161
162impl FlatBufferRoot for ArrayNodeFlatBuffer<'_> {}
163
impl WriteFlatBuffer for ArrayNodeFlatBuffer<'_> {
    type Target<'t> = fba::ArrayNode<'t>;

    /// Write this array node — and, recursively, all of its children — into the flatbuffer.
    ///
    /// Buffer indices are assigned depth-first: this node claims
    /// `[buffer_idx, buffer_idx + nbuffers)`, and each child's subtree starts immediately
    /// after all buffers of the preceding siblings' subtrees.
    fn write_flatbuffer<'fb>(
        &self,
        fbb: &mut FlatBufferBuilder<'fb>,
    ) -> WIPOffset<Self::Target<'fb>> {
        let encoding = self.ctx.encoding_idx(&self.array.encoding());
        // `try_new` has already validated that metadata exists for every node in the tree,
        // so both expects document invariants rather than recoverable failures.
        let metadata = self
            .array
            .metadata()
            .vortex_expect("Failed to serialize metadata")
            .vortex_expect("Validated that all arrays support serialization");
        let metadata = Some(fbb.create_vector(metadata.as_slice()));

        let nbuffers = u16::try_from(self.array.nbuffers())
            .vortex_expect("Array can have at most u16::MAX buffers");
        // First child's buffers start right after this node's own buffers.
        let mut child_buffer_idx = self.buffer_idx + nbuffers;

        let children = &self
            .array
            .children()
            .iter()
            .map(|child| {
                let msg = ArrayNodeFlatBuffer {
                    ctx: self.ctx,
                    array: child,
                    buffer_idx: child_buffer_idx,
                }
                .write_flatbuffer(fbb);
                // Advance past every buffer owned by this child's entire subtree, checking
                // that the running index still fits in u16.
                child_buffer_idx = u16::try_from(child.nbuffers_recursive())
                    .ok()
                    .and_then(|nbuffers| nbuffers.checked_add(child_buffer_idx))
                    .vortex_expect("Too many buffers (u16) for Array");
                msg
            })
            .collect::<Vec<_>>();
        let children = Some(fbb.create_vector(children));

        // Indices of this node's own buffers within the message-wide buffer list.
        let buffers = Some(fbb.create_vector_from_iter((0..nbuffers).map(|i| i + self.buffer_idx)));
        let stats = Some(self.array.statistics().to_owned().write_flatbuffer(fbb));

        fba::ArrayNode::create(
            fbb,
            &fba::ArrayNodeArgs {
                encoding,
                metadata,
                children,
                buffers,
                stats,
            },
        )
    }
}
221
/// Access to the child arrays of an array that is being decoded.
///
/// Implementations resolve a child lazily, given the dtype and length the parent expects.
pub trait ArrayChildren {
    /// Return the child at `index`, decoded with the given `dtype` and `len`.
    fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef>;

    /// The number of children.
    fn len(&self) -> usize;

    /// Whether there are no children.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
238
/// A lazily-decoded view over one node of a serialized array message.
///
/// Cloning is cheap: the flatbuffer and the buffer list are shared, and child views are
/// produced by re-pointing `flatbuffer_loc` at a child node within the same flatbuffer.
#[derive(Clone)]
pub struct ArrayParts {
    // The flatbuffer holding the whole array tree.
    flatbuffer: FlatBuffer,
    // Byte location of this node's `ArrayNode` table within `flatbuffer`.
    flatbuffer_loc: usize,
    // All data buffers of the message, shared across the whole tree of views.
    buffers: Arc<[ByteBuffer]>,
}
253
254impl Debug for ArrayParts {
255 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256 f.debug_struct("ArrayParts")
257 .field("encoding_id", &self.encoding_id())
258 .field("children", &(0..self.nchildren()).map(|i| self.child(i)))
259 .field(
260 "buffers",
261 &(0..self.nbuffers()).map(|i| self.buffer(i).ok()),
262 )
263 .field("metadata", &self.metadata())
264 .finish()
265 }
266}
267
impl ArrayParts {
    /// Decode this node into an [`ArrayRef`] with the given dtype and length.
    ///
    /// Looks up the encoding in `ctx`, rebuilds the array from its metadata, buffers and
    /// (lazily decoded) children, asserts the result matches the requested length, dtype and
    /// encoding, then restores any serialized statistics onto the decoded array.
    ///
    /// # Errors
    ///
    /// Fails if the encoding is unknown, a buffer index is invalid, the encoding's `build`
    /// fails, or the serialized statistics cannot be read.
    pub fn decode(&self, ctx: &ArrayContext, dtype: &DType, len: usize) -> VortexResult<ArrayRef> {
        let encoding_id = self.flatbuffer().encoding();
        let vtable = ctx
            .lookup_encoding(encoding_id)
            .ok_or_else(|| vortex_err!("Unknown encoding: {}", encoding_id))?;

        let buffers: Vec<_> = (0..self.nbuffers())
            .map(|idx| self.buffer(idx))
            .try_collect()?;

        // Children are decoded on demand through the `ArrayChildren` callback.
        let children = ArrayPartsChildren { parts: self, ctx };

        let decoded = vtable.build(dtype, len, self.metadata(), &buffers, &children)?;

        // A mismatch on any of these indicates a bug in the encoding's `build` impl,
        // not a recoverable error — hence asserts rather than Results.
        assert_eq!(
            decoded.len(),
            len,
            "Array decoded from {} has incorrect length {}, expected {}",
            vtable.id(),
            decoded.len(),
            len
        );
        assert_eq!(
            decoded.dtype(),
            dtype,
            "Array decoded from {} has incorrect dtype {}, expected {}",
            vtable.id(),
            decoded.dtype(),
            dtype,
        );
        assert_eq!(
            decoded.encoding_id(),
            vtable.id(),
            "Array decoded from {} has incorrect encoding {}",
            vtable.id(),
            decoded.encoding_id(),
        );

        // Copy any serialized statistics onto the freshly decoded array.
        if let Some(stats) = self.flatbuffer().stats() {
            let decoded_statistics = decoded.statistics();
            StatsSet::read_flatbuffer(&stats)?
                .into_iter()
                .for_each(|(stat, val)| decoded_statistics.set(stat, val));
        }

        Ok(decoded)
    }

    /// The encoding id of this node, as recorded in the flatbuffer.
    pub fn encoding_id(&self) -> u16 {
        self.flatbuffer().encoding()
    }

    /// This node's opaque metadata bytes (empty slice if none were recorded).
    pub fn metadata(&self) -> &[u8] {
        self.flatbuffer()
            .metadata()
            .map(|metadata| metadata.bytes())
            .unwrap_or(&[])
    }

    /// The number of children of this node.
    pub fn nchildren(&self) -> usize {
        self.flatbuffer()
            .children()
            .map_or(0, |children| children.len())
    }

    /// A view over the child at `idx`, sharing this node's flatbuffer and buffers.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of range, or if the node has no children vector at all.
    pub fn child(&self, idx: usize) -> ArrayParts {
        let children = self
            .flatbuffer()
            .children()
            .vortex_expect("Expected array to have children");
        if idx >= children.len() {
            vortex_panic!(
                "Invalid child index {} for array with {} children",
                idx,
                children.len()
            );
        }
        self.with_root(children.get(idx))
    }

    /// The number of buffers referenced by this node (not including its children's buffers).
    pub fn nbuffers(&self) -> usize {
        self.flatbuffer()
            .buffers()
            .map_or(0, |buffers| buffers.len())
    }

    /// Resolve this node's `idx`-th buffer reference against the message-wide buffer list.
    ///
    /// # Errors
    ///
    /// Fails if the node has no buffers vector, or if the recorded index is out of range.
    pub fn buffer(&self, idx: usize) -> VortexResult<ByteBuffer> {
        // The flatbuffer stores indices into the message-wide buffer list, not buffers.
        let buffer_idx = self
            .flatbuffer()
            .buffers()
            .ok_or_else(|| vortex_err!("Array has no buffers"))?
            .get(idx);
        self.buffers
            .get(buffer_idx as usize)
            .cloned()
            .ok_or_else(|| {
                vortex_err!(
                    "Invalid buffer index {} for array with {} buffers",
                    buffer_idx,
                    self.nbuffers()
                )
            })
    }

    /// Re-interpret `flatbuffer_loc` as this node's `ArrayNode` table.
    fn flatbuffer(&self) -> fba::ArrayNode {
        // SAFETY: `flatbuffer_loc` always points at a valid `ArrayNode` table inside
        // `flatbuffer`: it is either the root location captured in `TryFrom` (after the
        // flatbuffer was verified by `root::<fba::Array>`) or a child table location taken
        // from an already-valid parent via `with_root`.
        unsafe { fba::ArrayNode::follow(self.flatbuffer.as_ref(), self.flatbuffer_loc) }
    }

    /// Clone this view, re-pointed at a different node within the same flatbuffer.
    fn with_root(&self, root: fba::ArrayNode) -> Self {
        let mut this = self.clone();
        this.flatbuffer_loc = root._tab.loc();
        this
    }
}
394
/// Adapter exposing an [`ArrayParts`] node's children through the [`ArrayChildren`] trait,
/// so encodings can decode children lazily during `ArrayParts::decode`.
struct ArrayPartsChildren<'a> {
    parts: &'a ArrayParts,
    ctx: &'a ArrayContext,
}
399
impl ArrayChildren for ArrayPartsChildren<'_> {
    /// Decode the `index`-th child with the dtype and length the parent encoding expects.
    fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef> {
        self.parts.child(index).decode(self.ctx, dtype, len)
    }

    /// The number of children of the underlying node.
    fn len(&self) -> usize {
        self.parts.nchildren()
    }
}
409
impl TryFrom<ByteBuffer> for ArrayParts {
    type Error = VortexError;

    /// Parse a serialized array message with layout
    /// `[buffers...][flatbuffer][u32 flatbuffer length]` (as produced by `serialize`).
    fn try_from(value: ByteBuffer) -> Result<Self, Self::Error> {
        // Need at least the trailing 4-byte flatbuffer length.
        if value.len() < 4 {
            vortex_bail!("ArrayParts buffer is too short");
        }

        // Relax the buffer's alignment so the arbitrary slicing below is permitted.
        // NOTE(review): assumes `aligned(Alignment::none())` only drops the alignment
        // requirement without copying — confirm against vortex_buffer docs.
        let value = value.aligned(Alignment::none());

        // The last 4 bytes hold the little-endian length of the flatbuffer.
        let fb_length = u32::try_from_le_bytes(&value.as_slice()[value.len() - 4..])? as usize;
        if value.len() < 4 + fb_length {
            vortex_bail!("ArrayParts buffer is too short for flatbuffer");
        }

        // The flatbuffer sits immediately before the trailing length word.
        let fb_offset = value.len() - 4 - fb_length;
        let fb_buffer = value.slice(fb_offset..fb_offset + fb_length);
        let fb_buffer = FlatBuffer::align_from(fb_buffer);

        // `root` verifies the flatbuffer before we walk it.
        let fb_array = root::<fba::Array>(fb_buffer.as_ref())?;
        let fb_root = fb_array.root().vortex_expect("Array must have a root node");

        // Reconstruct each data buffer by walking forward from the start of the message,
        // using the padding/length/alignment recorded per buffer at serialization time.
        let mut offset = 0;
        let buffers: Arc<[ByteBuffer]> = fb_array
            .buffers()
            .unwrap_or_default()
            .iter()
            .map(|fb_buffer| {
                // Skip the zero-padding that was inserted before this buffer (if any).
                offset += fb_buffer.padding() as usize;

                let buffer_len = fb_buffer.length() as usize;

                // Restore the buffer's original alignment from its recorded exponent.
                let buffer = value
                    .slice(offset..(offset + buffer_len))
                    .aligned(Alignment::from_exponent(fb_buffer.alignment_exponent()));

                offset += buffer_len;
                buffer
            })
            .collect();

        Ok(ArrayParts {
            flatbuffer: fb_buffer.clone(),
            flatbuffer_loc: fb_root._tab.loc(),
            buffers,
        })
    }
}