vortex_fsst/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::arrays::VarBinVTable;
7use vortex_array::compute::CastKernel;
8use vortex_array::compute::CastKernelAdapter;
9use vortex_array::compute::cast;
10use vortex_array::register_kernel;
11use vortex_dtype::DType;
12use vortex_error::VortexResult;
13
14use crate::FSSTArray;
15use crate::FSSTVTable;
16
17impl CastKernel for FSSTVTable {
18    fn cast(&self, array: &FSSTArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19        // FSST is a string compression encoding.
20        // For nullability changes, we can cast the codes and symbols arrays
21        if array.dtype().eq_ignore_nullability(dtype) {
22            // Cast codes array to handle nullability
23            let new_codes = cast(
24                array.codes().as_ref(),
25                &array.codes().dtype().with_nullability(dtype.nullability()),
26            )?;
27
28            Ok(Some(
29                FSSTArray::try_new(
30                    dtype.clone(),
31                    array.symbols().clone(),
32                    array.symbol_lengths().clone(),
33                    new_codes.as_::<VarBinVTable>().clone(),
34                    array.uncompressed_lengths().clone(),
35                )?
36                .into_array(),
37            ))
38        } else {
39            Ok(None)
40        }
41    }
42}
43
// Register the `CastKernel` implementation above for `FSSTVTable`, wrapped in a
// `CastKernelAdapter`, with the compute kernel registry.
register_kernel!(CastKernelAdapter(FSSTVTable).lift());
45
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::FSSTVTable;
    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Casting FSST-compressed strings between non-nullable and nullable `Utf8` must keep
    /// the FSST encoding and the array length, in both directions.
    #[test]
    fn test_cast_fsst_nullability() {
        let strings = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("hello world")],
            DType::Utf8(Nullability::NonNullable),
        );

        let compressor = fsst_train_compressor(&strings);
        let fsst = fsst_compress(&strings, &compressor);

        // Non-nullable -> nullable is served by the FSST kernel, so the result stays FSST.
        let nullable = cast(fsst.as_ref(), &DType::Utf8(Nullability::Nullable)).unwrap();
        assert_eq!(nullable.dtype(), &DType::Utf8(Nullability::Nullable));
        assert_eq!(nullable.len(), 3);
        assert!(nullable.is::<FSSTVTable>());

        // There are no nulls, so casting back to non-nullable must succeed as well.
        let non_nullable =
            cast(nullable.as_ref(), &DType::Utf8(Nullability::NonNullable)).unwrap();
        assert_eq!(non_nullable.dtype(), &DType::Utf8(Nullability::NonNullable));
        assert_eq!(non_nullable.len(), 3);
        assert!(non_nullable.is::<FSSTVTable>());
    }

    /// Runs the shared cast conformance suite against FSST arrays built from a few inputs:
    /// non-nullable, nullable with a null, and a single element.
    #[rstest]
    #[case(VarBinArray::from_iter(
        vec![Some("hello"), Some("world"), Some("hello world")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("foo"), None, Some("bar"), Some("foobar")],
        DType::Utf8(Nullability::Nullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("test")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    fn test_cast_fsst_conformance(#[case] array: VarBinArray) {
        let compressor = fsst_train_compressor(&array);
        let fsst = fsst_compress(&array, &compressor);
        test_cast_conformance(fsst.as_ref());
    }
}