// vortex_sparse/serde.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_array::patches::PatchesMetadata;
5use vortex_array::serde::ArrayChildren;
6use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
7use vortex_array::{
8    ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
9};
10use vortex_buffer::{ByteBuffer, ByteBufferMut};
11use vortex_dtype::DType;
12use vortex_error::{VortexResult, vortex_bail};
13use vortex_scalar::{Scalar, ScalarValue};
14
15use crate::{SparseArray, SparseEncoding, SparseVTable};
16
17#[derive(Clone, prost::Message)]
18#[repr(C)]
19pub struct SparseMetadata {
20    #[prost(message, required, tag = "1")]
21    patches: PatchesMetadata,
22}
23
24impl SerdeVTable<SparseVTable> for SparseVTable {
25    type Metadata = ProstMetadata<SparseMetadata>;
26
27    fn metadata(array: &SparseArray) -> VortexResult<Option<Self::Metadata>> {
28        Ok(Some(ProstMetadata(SparseMetadata {
29            patches: array.patches().to_metadata(array.len(), array.dtype())?,
30        })))
31    }
32
33    fn build(
34        _encoding: &SparseEncoding,
35        dtype: &DType,
36        len: usize,
37        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
38        buffers: &[ByteBuffer],
39        children: &dyn ArrayChildren,
40    ) -> VortexResult<SparseArray> {
41        if children.len() != 2 {
42            vortex_bail!(
43                "Expected 2 children for sparse encoding, found {}",
44                children.len()
45            )
46        }
47        assert_eq!(
48            metadata.patches.offset(),
49            0,
50            "Patches must start at offset 0"
51        );
52
53        let patch_indices =
54            children.get(0, &metadata.patches.indices_dtype(), metadata.patches.len())?;
55        let patch_values = children.get(1, dtype, metadata.patches.len())?;
56
57        if buffers.len() != 1 {
58            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
59        }
60        let fill_value = Scalar::new(dtype.clone(), ScalarValue::from_protobytes(&buffers[0])?);
61
62        SparseArray::try_new(patch_indices, patch_values, len, fill_value)
63    }
64}
65
66impl EncodeVTable<SparseVTable> for SparseVTable {
67    fn encode(
68        _encoding: &SparseEncoding,
69        input: &Canonical,
70        like: Option<&SparseArray>,
71    ) -> VortexResult<Option<SparseArray>> {
72        // Try and cast the "like" fill value into the array's type. This is useful for cases where we narrow the arrays type.
73        let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
74
75        // TODO(ngates): encode should only handle arrays that _can_ be made sparse.
76        Ok(SparseArray::encode(input.as_ref(), fill_value)?
77            .as_opt::<SparseVTable>()
78            .cloned())
79    }
80}
81
82impl VisitorVTable<SparseVTable> for SparseVTable {
83    fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
84        let fill_value_buffer = array
85            .fill_value
86            .value()
87            .to_protobytes::<ByteBufferMut>()
88            .freeze();
89        visitor.visit_buffer(&fill_value_buffer);
90    }
91
92    fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
93        visitor.visit_patches(array.patches())
94    }
95}