vortex_fsst/
serde.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use fsst::{Compressor, Symbol};
5use vortex_array::arrays::VarBinVTable;
6use vortex_array::serde::ArrayChildren;
7use vortex_array::vtable::{EncodeVTable, SerdeVTable, VisitorVTable};
8use vortex_array::{
9    Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
10};
11use vortex_buffer::{Buffer, ByteBuffer};
12use vortex_dtype::{DType, Nullability, PType};
13use vortex_error::{VortexResult, vortex_bail, vortex_err};
14
15use crate::{FSSTArray, FSSTEncoding, FSSTVTable, fsst_compress, fsst_train_compressor};
16
17#[derive(Clone, prost::Message)]
18pub struct FSSTMetadata {
19    #[prost(enumeration = "PType", tag = "1")]
20    uncompressed_lengths_ptype: i32,
21}
22
23impl SerdeVTable<FSSTVTable> for FSSTVTable {
24    type Metadata = ProstMetadata<FSSTMetadata>;
25
26    fn metadata(array: &FSSTArray) -> VortexResult<Option<Self::Metadata>> {
27        Ok(Some(ProstMetadata(FSSTMetadata {
28            uncompressed_lengths_ptype: PType::try_from(array.uncompressed_lengths().dtype())?
29                as i32,
30        })))
31    }
32
33    fn build(
34        _encoding: &FSSTEncoding,
35        dtype: &DType,
36        len: usize,
37        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
38        buffers: &[ByteBuffer],
39        children: &dyn ArrayChildren,
40    ) -> VortexResult<FSSTArray> {
41        if buffers.len() != 2 {
42            vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len());
43        }
44        let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone());
45        let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone());
46
47        if children.len() != 2 {
48            vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
49        }
50        let codes = children.get(0, &DType::Binary(dtype.nullability()), len)?;
51        let codes = codes
52            .as_opt::<VarBinVTable>()
53            .ok_or_else(|| {
54                vortex_err!(
55                    "Expected VarBinArray for codes, got {}",
56                    codes.encoding_id()
57                )
58            })?
59            .clone();
60        let uncompressed_lengths = children.get(
61            1,
62            &DType::Primitive(
63                metadata.uncompressed_lengths_ptype(),
64                Nullability::NonNullable,
65            ),
66            len,
67        )?;
68
69        FSSTArray::try_new(
70            dtype.clone(),
71            symbols,
72            symbol_lengths,
73            codes,
74            uncompressed_lengths,
75        )
76    }
77}
78
79impl EncodeVTable<FSSTVTable> for FSSTVTable {
80    fn encode(
81        _encoding: &FSSTEncoding,
82        canonical: &Canonical,
83        like: Option<&FSSTArray>,
84    ) -> VortexResult<Option<FSSTArray>> {
85        let array = canonical.clone().into_varbinview();
86
87        let compressor = match like {
88            Some(like) => Compressor::rebuild_from(like.symbols(), like.symbol_lengths()),
89            None => fsst_train_compressor(&array),
90        };
91
92        Ok(Some(fsst_compress(array, &compressor)))
93    }
94}
95
96impl VisitorVTable<FSSTVTable> for FSSTVTable {
97    fn visit_buffers(array: &FSSTArray, visitor: &mut dyn ArrayBufferVisitor) {
98        visitor.visit_buffer(&array.symbols().clone().into_byte_buffer());
99        visitor.visit_buffer(&array.symbol_lengths().clone().into_byte_buffer());
100    }
101
102    fn visit_children(array: &FSSTArray, visitor: &mut dyn ArrayChildVisitor) {
103        visitor.visit_child("codes", array.codes().as_ref());
104        visitor.visit_child("uncompressed_lengths", array.uncompressed_lengths());
105    }
106}
107
#[cfg(test)]
mod test {
    use vortex_array::ProstMetadata;
    use vortex_array::test_harness::check_metadata;
    use vortex_dtype::PType;

    use crate::serde::FSSTMetadata;

    /// Runs the shared metadata check against the `fsst.metadata` name with a
    /// `U64` lengths ptype.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_fsst_metadata() {
        let metadata = FSSTMetadata {
            uncompressed_lengths_ptype: PType::U64 as i32,
        };
        check_metadata("fsst.metadata", ProstMetadata(metadata));
    }
}