vortex_sparse/
serde.rs

1use vortex_array::patches::PatchesMetadata;
2use vortex_array::serde::ArrayChildren;
3use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
4use vortex_array::{
5    ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
6};
7use vortex_buffer::{ByteBuffer, ByteBufferMut};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail};
10use vortex_scalar::{Scalar, ScalarValue};
11
12use crate::{SparseArray, SparseEncoding, SparseVTable};
13
14#[derive(Clone, prost::Message)]
15#[repr(C)]
16pub struct SparseMetadata {
17    #[prost(message, required, tag = "1")]
18    patches: PatchesMetadata,
19}
20
21impl SerdeVTable<SparseVTable> for SparseVTable {
22    type Metadata = ProstMetadata<SparseMetadata>;
23
24    fn metadata(array: &SparseArray) -> VortexResult<Option<Self::Metadata>> {
25        Ok(Some(ProstMetadata(SparseMetadata {
26            patches: array.patches().to_metadata(array.len(), array.dtype())?,
27        })))
28    }
29
30    fn build(
31        _encoding: &SparseEncoding,
32        dtype: &DType,
33        len: usize,
34        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
35        buffers: &[ByteBuffer],
36        children: &dyn ArrayChildren,
37    ) -> VortexResult<SparseArray> {
38        if children.len() != 2 {
39            vortex_bail!(
40                "Expected 2 children for sparse encoding, found {}",
41                children.len()
42            )
43        }
44        assert_eq!(
45            metadata.patches.offset(),
46            0,
47            "Patches must start at offset 0"
48        );
49
50        let patch_indices =
51            children.get(0, &metadata.patches.indices_dtype(), metadata.patches.len())?;
52        let patch_values = children.get(1, dtype, metadata.patches.len())?;
53
54        if buffers.len() != 1 {
55            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
56        }
57        let fill_value = Scalar::new(dtype.clone(), ScalarValue::from_protobytes(&buffers[0])?);
58
59        SparseArray::try_new(patch_indices, patch_values, len, fill_value)
60    }
61}
62
63impl EncodeVTable<SparseVTable> for SparseVTable {
64    fn encode(
65        _encoding: &SparseEncoding,
66        input: &Canonical,
67        like: Option<&SparseArray>,
68    ) -> VortexResult<Option<SparseArray>> {
69        // Try and cast the "like" fill value into the array's type. This is useful for cases where we narrow the arrays type.
70        let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
71
72        // TODO(ngates): encode should only handle arrays that _can_ be made sparse.
73        Ok(SparseArray::encode(input.as_ref(), fill_value)?
74            .as_opt::<SparseVTable>()
75            .cloned())
76    }
77}
78
79impl VisitorVTable<SparseVTable> for SparseVTable {
80    fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
81        let fill_value_buffer = array
82            .fill_value
83            .value()
84            .to_protobytes::<ByteBufferMut>()
85            .freeze();
86        visitor.visit_buffer(&fill_value_buffer);
87    }
88
89    fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
90        visitor.visit_patches(array.patches())
91    }
92}