vortex_fsst/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::VarBinVTable;
5use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
6use vortex_array::{ArrayRef, IntoArray, register_kernel};
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9
10use crate::{FSSTArray, FSSTVTable};
11
12impl CastKernel for FSSTVTable {
13    fn cast(&self, array: &FSSTArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14        // FSST is a string compression encoding.
15        // For nullability changes, we can cast the codes and symbols arrays
16        if array.dtype().eq_ignore_nullability(dtype) {
17            // Cast codes array to handle nullability
18            let new_codes = cast(
19                array.codes().as_ref(),
20                &array.codes().dtype().with_nullability(dtype.nullability()),
21            )?;
22
23            Ok(Some(
24                FSSTArray::try_new(
25                    dtype.clone(),
26                    array.symbols().clone(),
27                    array.symbol_lengths().clone(),
28                    new_codes.as_::<VarBinVTable>().clone(),
29                    array.uncompressed_lengths().clone(),
30                )?
31                .into_array(),
32            ))
33        } else {
34            Ok(None)
35        }
36    }
37}
38
39register_kernel!(CastKernelAdapter(FSSTVTable).lift());
40
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::{DType, Nullability};

    use crate::{fsst_compress, fsst_train_compressor};

    /// Casting a non-nullable FSST-compressed Utf8 array to nullable Utf8 reports the
    /// nullable dtype on the result.
    #[test]
    fn test_cast_fsst_nullability() {
        let input = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("hello world")],
            DType::Utf8(Nullability::NonNullable),
        );
        let trained = fsst_train_compressor(&input);
        let encoded = fsst_compress(input, &trained);

        // Cast to nullable and check the resulting dtype.
        let target = DType::Utf8(Nullability::Nullable);
        let widened = cast(encoded.as_ref(), &target).unwrap();
        assert_eq!(widened.dtype(), &target);
    }

    /// Runs the shared cast conformance checks against FSST-compressed inputs: an
    /// all-valid array, an array containing a null, and a single-element array.
    #[rstest]
    #[case(VarBinArray::from_iter(
        vec![Some("hello"), Some("world"), Some("hello world")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("foo"), None, Some("bar"), Some("foobar")],
        DType::Utf8(Nullability::Nullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("test")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    fn test_cast_fsst_conformance(#[case] strings: VarBinArray) {
        let trained = fsst_train_compressor(&strings);
        let encoded = fsst_compress(&strings, &trained);
        test_cast_conformance(encoded.as_ref());
    }
}