1use vortex_array::patches::PatchesMetadata;
2use vortex_array::serde::ArrayParts;
3use vortex_array::vtable::SerdeVTable;
4use vortex_array::{
5 Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl,
6 DeserializeMetadata, RkyvMetadata,
7};
8use vortex_buffer::ByteBufferMut;
9use vortex_dtype::DType;
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11use vortex_scalar::{Scalar, ScalarValue};
12
13use crate::{SparseArray, SparseEncoding};
14
15#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
16#[repr(C)]
17pub struct SparseMetadata {
18 patches: PatchesMetadata,
19}
20
21impl ArrayVisitorImpl<RkyvMetadata<SparseMetadata>> for SparseArray {
22 fn _buffers(&self, visitor: &mut dyn ArrayBufferVisitor) {
23 let fill_value_buffer = self
24 .fill_value
25 .value()
26 .to_flexbytes::<ByteBufferMut>()
27 .freeze();
28 visitor.visit_buffer(&fill_value_buffer);
29 }
30
31 fn _children(&self, visitor: &mut dyn ArrayChildVisitor) {
32 visitor.visit_patches(self.patches())
33 }
34
35 fn _metadata(&self) -> RkyvMetadata<SparseMetadata> {
36 RkyvMetadata(SparseMetadata {
37 patches: self
38 .patches()
39 .to_metadata(self.len(), self.dtype())
40 .vortex_expect("Failed to create patches metadata"),
41 })
42 }
43}
44
45impl SerdeVTable<&SparseArray> for SparseEncoding {
46 fn decode(
47 &self,
48 parts: &ArrayParts,
49 ctx: &ArrayContext,
50 dtype: DType,
51 len: usize,
52 ) -> VortexResult<ArrayRef> {
53 if parts.nchildren() != 2 {
54 vortex_bail!(
55 "Expected 2 children for sparse encoding, found {}",
56 parts.nchildren()
57 )
58 }
59 let metadata = RkyvMetadata::<SparseMetadata>::deserialize(parts.metadata())?;
60 assert_eq!(
61 metadata.patches.offset(),
62 0,
63 "Patches must start at offset 0"
64 );
65
66 let patch_indices = parts.child(0).decode(
67 ctx,
68 metadata.patches.indices_dtype(),
69 metadata.patches.len(),
70 )?;
71 let patch_values = parts
72 .child(1)
73 .decode(ctx, dtype.clone(), metadata.patches.len())?;
74
75 if parts.nbuffers() != 1 {
76 vortex_bail!("Expected 1 buffer, got {}", parts.nbuffers());
77 }
78 let fill_value = Scalar::new(dtype, ScalarValue::from_flexbytes(&parts.buffer(0)?)?);
79
80 Ok(SparseArray::try_new(patch_indices, patch_values, len, fill_value)?.into_array())
81 }
82}