1use vortex_array::patches::PatchesMetadata;
2use vortex_array::serde::ArrayChildren;
3use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
4use vortex_array::{
5 ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
6};
7use vortex_buffer::{ByteBuffer, ByteBufferMut};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail};
10use vortex_scalar::{Scalar, ScalarValue};
11
12use crate::{SparseArray, SparseEncoding, SparseVTable};
13
/// Serialized metadata for a sparse array.
///
/// Stored as a prost message with a single required field (tag 1). It carries the
/// metadata of the array's patches. `SerdeVTable::build` uses it to look up the two
/// children: the indices with `indices_dtype()`/`len()`, and the values with `len()`.
// NOTE(review): the reason for `#[repr(C)]` on a prost message is not evident from
// this file — confirm whether it is still needed.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata of the sparse array's patches; `build` requires its offset to be 0.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
20
21impl SerdeVTable<SparseVTable> for SparseVTable {
22 type Metadata = ProstMetadata<SparseMetadata>;
23
24 fn metadata(array: &SparseArray) -> VortexResult<Option<Self::Metadata>> {
25 Ok(Some(ProstMetadata(SparseMetadata {
26 patches: array.patches().to_metadata(array.len(), array.dtype())?,
27 })))
28 }
29
30 fn build(
31 _encoding: &SparseEncoding,
32 dtype: &DType,
33 len: usize,
34 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
35 buffers: &[ByteBuffer],
36 children: &dyn ArrayChildren,
37 ) -> VortexResult<SparseArray> {
38 if children.len() != 2 {
39 vortex_bail!(
40 "Expected 2 children for sparse encoding, found {}",
41 children.len()
42 )
43 }
44 assert_eq!(
45 metadata.patches.offset(),
46 0,
47 "Patches must start at offset 0"
48 );
49
50 let patch_indices =
51 children.get(0, &metadata.patches.indices_dtype(), metadata.patches.len())?;
52 let patch_values = children.get(1, dtype, metadata.patches.len())?;
53
54 if buffers.len() != 1 {
55 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
56 }
57 let fill_value = Scalar::new(dtype.clone(), ScalarValue::from_protobytes(&buffers[0])?);
58
59 SparseArray::try_new(patch_indices, patch_values, len, fill_value)
60 }
61}
62
63impl EncodeVTable<SparseVTable> for SparseVTable {
64 fn encode(
65 _encoding: &SparseEncoding,
66 input: &Canonical,
67 like: Option<&SparseArray>,
68 ) -> VortexResult<Option<SparseArray>> {
69 let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
71
72 Ok(SparseArray::encode(input.as_ref(), fill_value)?
74 .as_opt::<SparseVTable>()
75 .cloned())
76 }
77}
78
79impl VisitorVTable<SparseVTable> for SparseVTable {
80 fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
81 let fill_value_buffer = array
82 .fill_value
83 .value()
84 .to_protobytes::<ByteBufferMut>()
85 .freeze();
86 visitor.visit_buffer(&fill_value_buffer);
87 }
88
89 fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
90 visitor.visit_patches(array.patches())
91 }
92}