vortex_sparse/
serde.rs

1use vortex_array::patches::PatchesMetadata;
2use vortex_array::serde::ArrayParts;
3use vortex_array::vtable::SerdeVTable;
4use vortex_array::{
5    Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayContext, ArrayRef, ArrayVisitorImpl,
6    DeserializeMetadata, RkyvMetadata,
7};
8use vortex_buffer::ByteBufferMut;
9use vortex_dtype::DType;
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11use vortex_scalar::{Scalar, ScalarValue};
12
13use crate::{SparseArray, SparseEncoding};
14
15#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
16#[repr(C)]
17pub struct SparseMetadata {
18    patches: PatchesMetadata,
19}
20
21impl ArrayVisitorImpl<RkyvMetadata<SparseMetadata>> for SparseArray {
22    fn _buffers(&self, visitor: &mut dyn ArrayBufferVisitor) {
23        let fill_value_buffer = self
24            .fill_value
25            .value()
26            .to_flexbytes::<ByteBufferMut>()
27            .freeze();
28        visitor.visit_buffer(&fill_value_buffer);
29    }
30
31    fn _children(&self, visitor: &mut dyn ArrayChildVisitor) {
32        visitor.visit_patches(self.patches())
33    }
34
35    fn _metadata(&self) -> RkyvMetadata<SparseMetadata> {
36        RkyvMetadata(SparseMetadata {
37            patches: self
38                .patches()
39                .to_metadata(self.len(), self.dtype())
40                .vortex_expect("Failed to create patches metadata"),
41        })
42    }
43}
44
45impl SerdeVTable<&SparseArray> for SparseEncoding {
46    fn decode(
47        &self,
48        parts: &ArrayParts,
49        ctx: &ArrayContext,
50        dtype: DType,
51        len: usize,
52    ) -> VortexResult<ArrayRef> {
53        if parts.nchildren() != 2 {
54            vortex_bail!(
55                "Expected 2 children for sparse encoding, found {}",
56                parts.nchildren()
57            )
58        }
59        let metadata = RkyvMetadata::<SparseMetadata>::deserialize(parts.metadata())?;
60        assert_eq!(
61            metadata.patches.offset(),
62            0,
63            "Patches must start at offset 0"
64        );
65
66        let patch_indices = parts.child(0).decode(
67            ctx,
68            metadata.patches.indices_dtype(),
69            metadata.patches.len(),
70        )?;
71        let patch_values = parts
72            .child(1)
73            .decode(ctx, dtype.clone(), metadata.patches.len())?;
74
75        if parts.nbuffers() != 1 {
76            vortex_bail!("Expected 1 buffer, got {}", parts.nbuffers());
77        }
78        let fill_value = Scalar::new(dtype, ScalarValue::from_flexbytes(&parts.buffer(0)?)?);
79
80        Ok(SparseArray::try_new(patch_indices, patch_values, len, fill_value)?.into_array())
81    }
82}