1use vortex_array::patches::PatchesMetadata;
2use vortex_array::serde::ArrayParts;
3use vortex_array::vtable::EncodingVTable;
4use vortex_array::{
5 Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayContext, ArrayExt, ArrayRef,
6 ArrayVisitorImpl, Canonical, DeserializeMetadata, Encoding, EncodingId, ProstMetadata,
7};
8use vortex_buffer::ByteBufferMut;
9use vortex_dtype::DType;
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_scalar::{Scalar, ScalarValue};
12
13use crate::{SparseArray, SparseEncoding};
14
/// Serialized (prost) metadata for a [`SparseArray`].
///
/// Only the patches metadata is stored here; the fill value is carried separately as the
/// array's single buffer.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata for the patches: supplies the indices dtype and the number of patches used
    /// to decode the indices and values children. `decode` requires its offset to be 0.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
21
22impl EncodingVTable for SparseEncoding {
23 fn id(&self) -> EncodingId {
24 EncodingId::new_ref("vortex.sparse")
25 }
26
27 fn decode(
28 &self,
29 parts: &ArrayParts,
30 ctx: &ArrayContext,
31 dtype: DType,
32 len: usize,
33 ) -> VortexResult<ArrayRef> {
34 if parts.nchildren() != 2 {
35 vortex_bail!(
36 "Expected 2 children for sparse encoding, found {}",
37 parts.nchildren()
38 )
39 }
40 let metadata = ProstMetadata::<SparseMetadata>::deserialize(parts.metadata())?;
41 assert_eq!(
42 metadata.patches.offset(),
43 0,
44 "Patches must start at offset 0"
45 );
46
47 let patch_indices = parts.child(0).decode(
48 ctx,
49 metadata.patches.indices_dtype(),
50 metadata.patches.len(),
51 )?;
52 let patch_values = parts
53 .child(1)
54 .decode(ctx, dtype.clone(), metadata.patches.len())?;
55
56 if parts.nbuffers() != 1 {
57 vortex_bail!("Expected 1 buffer, got {}", parts.nbuffers());
58 }
59 let fill_value = Scalar::new(dtype, ScalarValue::from_flexbytes(&parts.buffer(0)?)?);
60
61 Ok(SparseArray::try_new(patch_indices, patch_values, len, fill_value)?.into_array())
62 }
63
64 fn encode(
65 &self,
66 input: &Canonical,
67 like: Option<&dyn Array>,
68 ) -> VortexResult<Option<ArrayRef>> {
69 let like = like
70 .map(|like| {
71 like.as_opt::<<Self as Encoding>::Array>().ok_or_else(|| {
72 vortex_err!(
73 "Expected {} encoded array but got {}",
74 self.id(),
75 like.encoding()
76 )
77 })
78 })
79 .transpose()?;
80
81 let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
83
84 Ok(Some(SparseArray::encode(input.as_ref(), fill_value)?))
85 }
86}
87
88impl ArrayVisitorImpl<ProstMetadata<SparseMetadata>> for SparseArray {
89 fn _visit_buffers(&self, visitor: &mut dyn ArrayBufferVisitor) {
90 let fill_value_buffer = self
91 .fill_value
92 .value()
93 .to_flexbytes::<ByteBufferMut>()
94 .freeze();
95 visitor.visit_buffer(&fill_value_buffer);
96 }
97
98 fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
99 visitor.visit_patches(self.patches())
100 }
101
102 fn _metadata(&self) -> ProstMetadata<SparseMetadata> {
103 ProstMetadata(SparseMetadata {
104 patches: self
105 .patches()
106 .to_metadata(self.len(), self.dtype())
107 .vortex_expect("Failed to create patches metadata"),
108 })
109 }
110}