use std::fmt::{Debug, Formatter};
use flatbuffers::{FlatBufferBuilder, Follow, WIPOffset};
use itertools::Itertools;
use vortex_buffer::ByteBuffer;
use vortex_dtype::DType;
use vortex_error::{vortex_panic, VortexExpect, VortexResult};
use vortex_flatbuffers::{
array as fba, FlatBuffer, FlatBufferRoot, WriteFlatBuffer, WriteFlatBufferExt,
};
use crate::{Array, ContextRef};
/// The serialized components of an [`Array`]: a flatbuffer describing the array, the
/// location of the `fba::Array` table inside that flatbuffer, the row count, and the data
/// buffers the flatbuffer refers to.
pub struct ArrayParts {
    /// Number of rows; handed to `Array::try_new_viewed` when decoding.
    row_count: usize,
    /// Backing bytes that contain the `fba::Array` table.
    flatbuffer: FlatBuffer,
    /// Location of the `fba::Array` table within `flatbuffer`, as passed to `Follow::follow`.
    flatbuffer_loc: usize,
    /// Data buffers, referenced by index from the flatbuffer's `buffers` vectors.
    buffers: Vec<ByteBuffer>,
}
impl Debug for ArrayParts {
    /// Prints a compact summary: the flatbuffer and the data buffers are reported by
    /// length (byte length and buffer count respectively) rather than by content.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let flatbuffer_len = self.flatbuffer.len();
        let buffer_count = self.buffers.len();

        let mut out = f.debug_struct("ArrayParts");
        out.field("row_count", &self.row_count);
        out.field("flatbuffer", &flatbuffer_len);
        out.field("flatbuffer_loc", &self.flatbuffer_loc);
        out.field("buffers", &buffer_count);
        out.finish()
    }
}
impl ArrayParts {
    /// Builds [`ArrayParts`] from an `fba::Array` view, the flatbuffer backing that view,
    /// and the data buffers it references.
    ///
    /// Only the table location of `array` is retained; it is re-resolved against
    /// `flatbuffer` in [`ArrayParts::decode`].
    ///
    /// # Panics
    ///
    /// Panics if the bytes backing `array` are not exactly the full byte range of
    /// `flatbuffer`.
    pub fn new(
        row_count: usize,
        array: fba::Array,
        flatbuffer: FlatBuffer,
        buffers: Vec<ByteBuffer>,
    ) -> Self {
        let whole_buffer = 0..flatbuffer.len();
        let view_range = flatbuffer
            .as_ref()
            .as_slice()
            .subslice_range(array._tab.buf());
        if view_range != Some(whole_buffer) {
            vortex_panic!("Array flatbuffer is not contained within the buffer");
        }

        let flatbuffer_loc = array._tab.loc();
        Self {
            row_count,
            flatbuffer,
            flatbuffer_loc,
            buffers,
        }
    }

    /// Decodes these parts into an [`Array`] using the given context and dtype.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Array::try_new_viewed` reports.
    pub fn decode(self, ctx: ContextRef, dtype: DType) -> VortexResult<Array> {
        let Self {
            row_count,
            flatbuffer,
            flatbuffer_loc,
            buffers,
        } = self;

        Array::try_new_viewed(
            ctx,
            dtype,
            row_count,
            flatbuffer,
            // SAFETY: `follow` does not verify its input; this relies on `flatbuffer_loc`
            // being the location of an `fba::Array` table within `buf`. `new` takes it
            // from an existing `fba::Array` view over exactly these bytes.
            // NOTE(review): presumably `buf` is the `flatbuffer` passed above — confirm.
            move |buf| unsafe { Ok(fba::Array::follow(buf, flatbuffer_loc)) },
            buffers,
        )
    }
}
impl From<Array> for ArrayParts {
fn from(array: Array) -> Self {
let flatbuffer = ArrayPartsFlatBuffer {
array: &array,
buffer_idx: 0,
}
.write_flatbuffer_bytes();
let mut buffers: Vec<ByteBuffer> = vec![];
for child in array.depth_first_traversal() {
for buffer in child.byte_buffers() {
buffers.push(buffer);
}
}
Self {
row_count: array.len(),
flatbuffer,
flatbuffer_loc: 0,
buffers,
}
}
}
/// Writes an [`Array`] node (and, recursively, its children) as an `fba::Array`
/// flatbuffer table via [`WriteFlatBuffer`].
pub struct ArrayPartsFlatBuffer<'a> {
    /// The array node to serialize.
    array: &'a Array,
    /// Index assigned to this node's first buffer in the tree-wide buffer numbering.
    buffer_idx: u16,
}
impl<'a> ArrayPartsFlatBuffer<'a> {
    /// Creates a writer that treats `array` as the root of the tree, so buffer numbering
    /// starts at index 0.
    pub fn new(array: &'a Array) -> Self {
        let first_buffer: u16 = 0;
        Self {
            buffer_idx: first_buffer,
            array,
        }
    }
}
// Marker impl: allows `ArrayPartsFlatBuffer` to be written as the root of a flatbuffer.
// NOTE(review): presumably this is what enables `WriteFlatBufferExt::write_flatbuffer_bytes`
// (used by `From<Array> for ArrayParts`) — confirm against vortex_flatbuffers.
impl FlatBufferRoot for ArrayPartsFlatBuffer<'_> {}
impl WriteFlatBuffer for ArrayPartsFlatBuffer<'_> {
    type Target<'t> = fba::Array<'t>;

    /// Writes this array node, and recursively all of its children, into `fbb`.
    ///
    /// Buffer indices are assigned in pre-order. This node claims the contiguous range
    /// `buffer_idx..buffer_idx + nbuffers`. Each child subtree then receives the next
    /// contiguous range, sized by that child's `cumulative_nbuffers()`.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - this node has more than `u16::MAX` buffers of its own;
    /// - the running buffer index across the tree would overflow `u16`.
    fn write_flatbuffer<'fb>(
        &self,
        fbb: &mut FlatBufferBuilder<'fb>,
    ) -> WIPOffset<Self::Target<'fb>> {
        let encoding = self.array.encoding().code();
        let metadata = self
            .array
            .metadata_bytes()
            .map(|bytes| fbb.create_vector(bytes));

        let nbuffers = u16::try_from(self.array.nbuffers())
            .vortex_expect("Array can have at most u16::MAX buffers");
        // Checked like the per-child increments below. A plain `+` would panic in debug
        // builds and silently wrap in release builds, producing wrong buffer indices.
        let child_buffer_idx = self
            .buffer_idx
            .checked_add(nbuffers)
            .vortex_expect("Too many buffers (u16) for Array");

        let children = self
            .array
            .children()
            .iter()
            .scan(child_buffer_idx, |buffer_idx, child| {
                let msg = ArrayPartsFlatBuffer {
                    array: child,
                    buffer_idx: *buffer_idx,
                }
                .write_flatbuffer(fbb);
                // Advance past every buffer in this child's subtree.
                *buffer_idx = u16::try_from(child.cumulative_nbuffers())
                    .ok()
                    .and_then(|nbuffers| nbuffers.checked_add(*buffer_idx))
                    .vortex_expect("Too many buffers (u16) for Array");
                Some(msg)
            })
            .collect_vec();
        let children = Some(fbb.create_vector(&children));

        // This node's own buffers occupy the range ending where its first child's begins.
        let buffers = Some(fbb.create_vector_from_iter(self.buffer_idx..child_buffer_idx));
        let stats = Some(self.array.statistics().write_flatbuffer(fbb));

        fba::Array::create(
            fbb,
            &fba::ArrayArgs {
                encoding,
                metadata,
                children,
                buffers,
                stats,
            },
        )
    }
}