vortex_fsst/
serde.rs

1use fsst::{Compressor, Symbol};
2use serde::{Deserialize, Serialize};
3use vortex_array::arrays::VarBinArray;
4use vortex_array::serde::ArrayParts;
5use vortex_array::vtable::EncodingVTable;
6use vortex_array::{
7    Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayContext, ArrayExt, ArrayRef,
8    ArrayVisitorImpl, Canonical, DeserializeMetadata, Encoding, EncodingId, SerdeMetadata,
9};
10use vortex_buffer::Buffer;
11use vortex_dtype::{DType, Nullability, PType};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
13
14use crate::{FSSTArray, FSSTEncoding, fsst_compress, fsst_train_compressor};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct FSSTMetadata {
18    uncompressed_lengths_ptype: PType,
19}
20
21impl EncodingVTable for FSSTEncoding {
22    fn id(&self) -> EncodingId {
23        EncodingId::new_ref("vortex.fsst")
24    }
25
26    fn decode(
27        &self,
28        parts: &ArrayParts,
29        ctx: &ArrayContext,
30        dtype: DType,
31        len: usize,
32    ) -> VortexResult<ArrayRef> {
33        let metadata = SerdeMetadata::<FSSTMetadata>::deserialize(parts.metadata())?;
34
35        if parts.nbuffers() != 2 {
36            vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", parts.nbuffers());
37        }
38        let symbols = Buffer::<Symbol>::from_byte_buffer(parts.buffer(0)?);
39        let symbol_lengths = Buffer::<u8>::from_byte_buffer(parts.buffer(1)?);
40
41        if parts.nchildren() != 2 {
42            vortex_bail!(InvalidArgument: "Expected 2 children, got {}", parts.nchildren());
43        }
44        let codes = parts
45            .child(0)
46            .decode(ctx, DType::Binary(dtype.nullability()), len)?
47            .as_opt::<VarBinArray>()
48            .ok_or_else(|| {
49                vortex_err!(
50                    "Expected VarBinArray for codes, got {:?}",
51                    ctx.lookup_encoding(parts.child(0).encoding_id())
52                )
53            })?
54            .clone();
55        let uncompressed_lengths = parts.child(1).decode(
56            ctx,
57            DType::Primitive(
58                metadata.uncompressed_lengths_ptype,
59                Nullability::NonNullable,
60            ),
61            len,
62        )?;
63
64        Ok(
65            FSSTArray::try_new(dtype, symbols, symbol_lengths, codes, uncompressed_lengths)?
66                .into_array(),
67        )
68    }
69
70    fn encode(
71        &self,
72        input: &Canonical,
73        like: Option<&dyn Array>,
74    ) -> VortexResult<Option<ArrayRef>> {
75        let like = like
76            .map(|like| {
77                like.as_opt::<<Self as Encoding>::Array>().ok_or_else(|| {
78                    vortex_err!(
79                        "Expected {} encoded array but got {}",
80                        self.id(),
81                        like.encoding()
82                    )
83                })
84            })
85            .transpose()?;
86
87        let array = input.clone().into_varbinview()?;
88
89        let compressor = match like {
90            Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
91            None => fsst_train_compressor(&array)?,
92        };
93
94        let fsst = fsst_compress(&array, &compressor)?;
95
96        Ok(Some(fsst.into_array()))
97    }
98}
99
100impl ArrayVisitorImpl<SerdeMetadata<FSSTMetadata>> for FSSTArray {
101    fn _visit_buffers(&self, visitor: &mut dyn ArrayBufferVisitor) {
102        visitor.visit_buffer(&self.symbols().clone().into_byte_buffer());
103        visitor.visit_buffer(&self.symbol_lengths().clone().into_byte_buffer());
104    }
105
106    fn _visit_children(&self, visitor: &mut dyn ArrayChildVisitor) {
107        visitor.visit_child("codes", self.codes());
108        visitor.visit_child("uncompressed_lengths", self.uncompressed_lengths());
109    }
110
111    fn _metadata(&self) -> SerdeMetadata<FSSTMetadata> {
112        SerdeMetadata(FSSTMetadata {
113            uncompressed_lengths_ptype: PType::try_from(self.uncompressed_lengths().dtype())
114                .vortex_expect("Must be a valid PType"),
115        })
116    }
117}
118
#[cfg(test)]
mod test {
    use vortex_array::SerdeMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::FSSTMetadata;

    /// Runs the `check_metadata` harness on FSST metadata with a `U64` lengths ptype,
    /// under the name `fsst.metadata`. Skipped under miri.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64,
        };

        check_metadata("fsst.metadata", SerdeMetadata(metadata));
    }
}
137}