Skip to main content

vortex_fsst/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::arrays::VarBin;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::dtype::DType;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_error::VortexResult;
11
12use crate::FSST;
13use crate::FSSTArray;
14
15impl CastReduce for FSST {
16    fn cast(array: &FSSTArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // FSST is a string compression encoding.
18        // For nullability changes, we can cast the codes and symbols arrays
19        if array.dtype().eq_ignore_nullability(dtype) {
20            // Cast codes array to handle nullability
21            let new_codes = array
22                .codes()
23                .clone()
24                .into_array()
25                .cast(array.codes().dtype().with_nullability(dtype.nullability()))?;
26
27            Ok(Some(
28                FSSTArray::try_new(
29                    dtype.clone(),
30                    array.symbols().clone(),
31                    array.symbol_lengths().clone(),
32                    new_codes.as_::<VarBin>().clone(),
33                    array.uncompressed_lengths().clone(),
34                )?
35                .into_array(),
36            ))
37        } else {
38            Ok(None)
39        }
40    }
41}
42
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Casting a non-nullable FSST-compressed Utf8 array to nullable Utf8
    /// produces an array whose dtype is the nullable one. Only the resulting
    /// dtype is checked here.
    #[test]
    fn test_cast_fsst_nullability() {
        let input = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("hello world")],
            DType::Utf8(Nullability::NonNullable),
        );

        let compressor = fsst_train_compressor(&input);
        let encoded = fsst_compress(input, &compressor).into_array();

        // Cast to nullable
        let target = DType::Utf8(Nullability::Nullable);
        let widened = encoded.cast(target.clone()).unwrap();

        assert_eq!(widened.dtype(), &target);
    }

    /// Compresses each case with FSST and passes the result to
    /// `test_cast_conformance`.
    #[rstest]
    #[case(VarBinArray::from_iter(
        vec![Some("hello"), Some("world"), Some("hello world")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("foo"), None, Some("bar"), Some("foobar")],
        DType::Utf8(Nullability::Nullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("test")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    fn test_cast_fsst_conformance(#[case] strings: VarBinArray) {
        let compressor = fsst_train_compressor(&strings);
        let encoded = fsst_compress(&strings, &compressor).into_array();
        test_cast_conformance(&encoded);
    }
}