Skip to main content

vortex_fsst/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::arrays::VarBin;
8use vortex_array::builtins::ArrayBuiltins;
9use vortex_array::dtype::DType;
10use vortex_array::scalar_fn::fns::cast::CastReduce;
11use vortex_error::VortexResult;
12
13use crate::FSST;
14use crate::FSSTArrayExt;
15impl CastReduce for FSST {
16    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // FSST is a string compression encoding.
18        // For nullability changes, we can cast the codes and symbols arrays
19        if array.dtype().eq_ignore_nullability(dtype) {
20            // Cast codes array to handle nullability
21            let new_codes = array
22                .codes()
23                .into_array()
24                .cast(array.codes_dtype().with_nullability(dtype.nullability()))?;
25
26            Ok(Some(
27                FSST::try_new(
28                    dtype.clone(),
29                    array.symbols().clone(),
30                    array.symbol_lengths().clone(),
31                    new_codes.as_::<VarBin>().into_owned(),
32                    array.uncompressed_lengths().clone(),
33                )?
34                .into_array(),
35            ))
36        } else {
37            Ok(None)
38        }
39    }
40}
41
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Casting a non-nullable FSST-compressed Utf8 array to nullable Utf8
    /// succeeds and the result reports the nullable dtype.
    #[test]
    fn test_cast_fsst_nullability() {
        let input = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("hello world")],
            DType::Utf8(Nullability::NonNullable),
        );

        let trained = fsst_train_compressor(&input);
        // Capture length and dtype up front: `input` is moved into the compressor call.
        let row_count = input.len();
        let input_dtype = input.dtype().clone();
        let encoded = fsst_compress(input, row_count, &input_dtype, &trained);

        // Cast to nullable and check the resulting dtype.
        let nullable_utf8 = DType::Utf8(Nullability::Nullable);
        let widened = encoded.into_array().cast(nullable_utf8.clone()).unwrap();
        assert_eq!(widened.dtype(), &nullable_utf8);
    }

    /// Runs the shared cast conformance checks against FSST-compressed
    /// versions of several string arrays (non-nullable, nullable with a
    /// null entry, and single-element).
    #[rstest]
    #[case(VarBinArray::from_iter(
        vec![Some("hello"), Some("world"), Some("hello world")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("foo"), None, Some("bar"), Some("foobar")],
        DType::Utf8(Nullability::Nullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("test")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    fn test_cast_fsst_conformance(#[case] strings: VarBinArray) {
        let trained = fsst_train_compressor(&strings);
        let encoded = fsst_compress(&strings, strings.len(), strings.dtype(), &trained);
        let as_array = encoded.into_array();
        test_cast_conformance(&as_array);
    }
}