1use vortex_array::patches::PatchesMetadata;
5use vortex_array::serde::ArrayChildren;
6use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
7use vortex_array::{
8 ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
9};
10use vortex_buffer::{ByteBuffer, ByteBufferMut};
11use vortex_dtype::DType;
12use vortex_error::{VortexResult, vortex_bail};
13use vortex_scalar::{Scalar, ScalarValue};
14
15use crate::{SparseArray, SparseEncoding, SparseVTable};
16
/// Serialized metadata for a [`SparseArray`].
///
/// Only the patches are described here. The fill value is serialized separately as the
/// array's single buffer (written by `visit_buffers`, read back in `build`). The patch
/// indices and patch values are the array's two children.
#[derive(Clone, prost::Message)]
// NOTE(review): `repr(C)` has no evident purpose on a prost message — confirm it is needed.
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata for the patches (stored indices + values). `build` reads its `len`,
    /// `offset` and `indices_dtype` to fetch and type the two children.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
23
24impl SerdeVTable<SparseVTable> for SparseVTable {
25 type Metadata = ProstMetadata<SparseMetadata>;
26
27 fn metadata(array: &SparseArray) -> VortexResult<Option<Self::Metadata>> {
28 Ok(Some(ProstMetadata(SparseMetadata {
29 patches: array.patches().to_metadata(array.len(), array.dtype())?,
30 })))
31 }
32
33 fn build(
34 _encoding: &SparseEncoding,
35 dtype: &DType,
36 len: usize,
37 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
38 buffers: &[ByteBuffer],
39 children: &dyn ArrayChildren,
40 ) -> VortexResult<SparseArray> {
41 if children.len() != 2 {
42 vortex_bail!(
43 "Expected 2 children for sparse encoding, found {}",
44 children.len()
45 )
46 }
47 assert_eq!(
48 metadata.patches.offset(),
49 0,
50 "Patches must start at offset 0"
51 );
52
53 let patch_indices =
54 children.get(0, &metadata.patches.indices_dtype(), metadata.patches.len())?;
55 let patch_values = children.get(1, dtype, metadata.patches.len())?;
56
57 if buffers.len() != 1 {
58 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
59 }
60 let fill_value = Scalar::new(dtype.clone(), ScalarValue::from_protobytes(&buffers[0])?);
61
62 SparseArray::try_new(patch_indices, patch_values, len, fill_value)
63 }
64}
65
66impl EncodeVTable<SparseVTable> for SparseVTable {
67 fn encode(
68 _encoding: &SparseEncoding,
69 input: &Canonical,
70 like: Option<&SparseArray>,
71 ) -> VortexResult<Option<SparseArray>> {
72 let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
74
75 Ok(SparseArray::encode(input.as_ref(), fill_value)?
77 .as_opt::<SparseVTable>()
78 .cloned())
79 }
80}
81
82impl VisitorVTable<SparseVTable> for SparseVTable {
83 fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
84 let fill_value_buffer = array
85 .fill_value
86 .value()
87 .to_protobytes::<ByteBufferMut>()
88 .freeze();
89 visitor.visit_buffer(&fill_value_buffer);
90 }
91
92 fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
93 visitor.visit_patches(array.patches())
94 }
95}