Skip to main content

vortex_fsst/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::VarBinArray;
9use vortex_array::arrays::varbin::VarBinArrayExt;
10use vortex_array::dtype::DType;
11use vortex_array::scalar_fn::fns::cast::CastKernel;
12use vortex_array::scalar_fn::fns::cast::CastReduce;
13use vortex_array::validity::Validity;
14use vortex_error::VortexResult;
15
16use crate::FSST;
17use crate::FSSTArrayExt;
18
19fn build_with_codes_validity(
20    array: ArrayView<'_, FSST>,
21    dtype: &DType,
22    new_codes_validity: Validity,
23) -> VortexResult<ArrayRef> {
24    let codes = array.codes();
25    let new_codes = VarBinArray::try_new(
26        codes.offsets().clone(),
27        codes.bytes().clone(),
28        codes.dtype().with_nullability(dtype.nullability()),
29        new_codes_validity,
30    )?;
31
32    Ok(unsafe {
33        FSST::new_unchecked(
34            dtype.clone(),
35            array.symbols().clone(),
36            array.symbol_lengths().clone(),
37            new_codes,
38            array.uncompressed_lengths().clone(),
39        )
40    }
41    .into_array())
42}
43
44impl CastReduce for FSST {
45    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
46        if !array.dtype().eq_ignore_nullability(dtype) {
47            return Ok(None);
48        }
49
50        let codes = array.codes();
51        let Some(new_codes_validity) = codes
52            .validity()?
53            .trivial_cast_nullability(dtype.nullability(), codes.len())?
54        else {
55            return Ok(None);
56        };
57
58        Ok(Some(build_with_codes_validity(
59            array,
60            dtype,
61            new_codes_validity,
62        )?))
63    }
64}
65
66impl CastKernel for FSST {
67    fn cast(
68        array: ArrayView<'_, Self>,
69        dtype: &DType,
70        ctx: &mut ExecutionCtx,
71    ) -> VortexResult<Option<ArrayRef>> {
72        if !array.dtype().eq_ignore_nullability(dtype) {
73            return Ok(None);
74        }
75
76        let codes = array.codes();
77        let new_codes_validity =
78            codes
79                .validity()?
80                .cast_nullability(dtype.nullability(), codes.len(), ctx)?;
81
82        Ok(Some(build_with_codes_validity(
83            array,
84            dtype,
85            new_codes_validity,
86        )?))
87    }
88}
89
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::session::ArraySession;
    use vortex_session::VortexSession;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Session shared by every test in this module; each test creates its own execution
    /// context from it.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Casting a non-nullable FSST-encoded UTF-8 array to nullable UTF-8 reports the
    /// nullable dtype.
    #[test]
    fn test_cast_fsst_nullability() {
        let non_nullable = DType::Utf8(Nullability::NonNullable);
        let nullable = DType::Utf8(Nullability::Nullable);

        let input = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("hello world")],
            non_nullable,
        );

        let mut ctx = SESSION.create_execution_ctx();
        let compressor = fsst_train_compressor(&input);
        let row_count = input.len();
        let input_dtype = input.dtype().clone();
        let encoded = fsst_compress(input, row_count, &input_dtype, &compressor, &mut ctx);

        let result = encoded.into_array().cast(nullable.clone()).unwrap();
        assert_eq!(result.dtype(), &nullable);
    }

    /// Runs the shared cast conformance suite against FSST-compressed variants of several
    /// string arrays (non-nullable, nullable with a null, and single-element).
    #[rstest]
    #[case(VarBinArray::from_iter(
        vec![Some("hello"), Some("world"), Some("hello world")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("foo"), None, Some("bar"), Some("foobar")],
        DType::Utf8(Nullability::Nullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("test")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    fn test_cast_fsst_conformance(#[case] array: VarBinArray) {
        let compressor = fsst_train_compressor(&array);
        let mut ctx = SESSION.create_execution_ctx();
        let encoded = fsst_compress(&array, array.len(), array.dtype(), &compressor, &mut ctx)
            .into_array();
        test_cast_conformance(&encoded);
    }
}