1use vortex_array::patches::PatchesMetadata;
2use vortex_array::serde::ArrayParts;
3use vortex_array::vtable::EncodingVTable;
4use vortex_array::{
5 Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayContext, ArrayExt, ArrayRef,
6 ArrayVisitorImpl, Canonical, DeserializeMetadata, Encoding, EncodingId, RkyvMetadata,
7};
8use vortex_buffer::ByteBufferMut;
9use vortex_dtype::DType;
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_scalar::{Scalar, ScalarValue};
12
13use crate::{SparseArray, SparseEncoding};
14
15#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
16#[repr(C)]
17pub struct SparseMetadata {
18 patches: PatchesMetadata,
19}
20
21impl EncodingVTable for SparseEncoding {
22 fn id(&self) -> EncodingId {
23 EncodingId::new_ref("vortex.sparse")
24 }
25
26 fn decode(
27 &self,
28 parts: &ArrayParts,
29 ctx: &ArrayContext,
30 dtype: DType,
31 len: usize,
32 ) -> VortexResult<ArrayRef> {
33 if parts.nchildren() != 2 {
34 vortex_bail!(
35 "Expected 2 children for sparse encoding, found {}",
36 parts.nchildren()
37 )
38 }
39 let metadata = RkyvMetadata::<SparseMetadata>::deserialize(parts.metadata())?;
40 assert_eq!(
41 metadata.patches.offset(),
42 0,
43 "Patches must start at offset 0"
44 );
45
46 let patch_indices = parts.child(0).decode(
47 ctx,
48 metadata.patches.indices_dtype(),
49 metadata.patches.len(),
50 )?;
51 let patch_values = parts
52 .child(1)
53 .decode(ctx, dtype.clone(), metadata.patches.len())?;
54
55 if parts.nbuffers() != 1 {
56 vortex_bail!("Expected 1 buffer, got {}", parts.nbuffers());
57 }
58 let fill_value = Scalar::new(dtype, ScalarValue::from_flexbytes(&parts.buffer(0)?)?);
59
60 Ok(SparseArray::try_new(patch_indices, patch_values, len, fill_value)?.into_array())
61 }
62
63 fn encode(
64 &self,
65 input: &Canonical,
66 like: Option<&dyn Array>,
67 ) -> VortexResult<Option<ArrayRef>> {
68 let like = like
69 .map(|like| {
70 like.as_opt::<<Self as Encoding>::Array>().ok_or_else(|| {
71 vortex_err!(
72 "Expected {} encoded array but got {}",
73 self.id(),
74 like.encoding()
75 )
76 })
77 })
78 .transpose()?;
79
80 let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
82
83 Ok(Some(SparseArray::encode(input.as_ref(), fill_value)?))
84 }
85}
86
87impl ArrayVisitorImpl<RkyvMetadata<SparseMetadata>> for SparseArray {
88 fn _visit_buffers(&self, visitor: &mut dyn ArrayBufferVisitor) {
89 let fill_value_buffer = self
90 .fill_value
91 .value()
92 .to_flexbytes::<ByteBufferMut>()
93 .freeze();
94 visitor.visit_buffer(&fill_value_buffer);
95 }
96
97 fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
98 visitor.visit_patches(self.patches())
99 }
100
101 fn _metadata(&self) -> RkyvMetadata<SparseMetadata> {
102 RkyvMetadata(SparseMetadata {
103 patches: self
104 .patches()
105 .to_metadata(self.len(), self.dtype())
106 .vortex_expect("Failed to create patches metadata"),
107 })
108 }
109}