Skip to main content

vortex_zstd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::dtype::DType;
7use vortex_array::dtype::Nullability;
8use vortex_array::scalar_fn::fns::cast::CastReduce;
9use vortex_error::VortexResult;
10
11use crate::ZstdArray;
12use crate::ZstdVTable;
13
14impl CastReduce for ZstdVTable {
15    fn cast(array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        if !dtype.eq_ignore_nullability(array.dtype()) {
17            // Type changes can't be handled in ZSTD, need to decode and tweak.
18            // TODO(aduffy): handle trivial conversions like Binary -> UTF8, integer widening, etc.
19            return Ok(None);
20        }
21
22        let src_nullability = array.dtype().nullability();
23        let target_nullability = dtype.nullability();
24
25        match (src_nullability, target_nullability) {
26            // Same type case. This should be handled in the layer above but for
27            // completeness of the match arms we also handle it here.
28            (Nullability::Nullable, Nullability::Nullable)
29            | (Nullability::NonNullable, Nullability::NonNullable) => {
30                Ok(Some(array.clone().into_array()))
31            }
32            (Nullability::NonNullable, Nullability::Nullable) => {
33                // nonnull => null, trivial cast by altering the validity
34                Ok(Some(
35                    ZstdArray::new(
36                        array.dictionary.clone(),
37                        array.frames.clone(),
38                        dtype.clone(),
39                        array.metadata.clone(),
40                        array.unsliced_n_rows(),
41                        array.unsliced_validity.clone(),
42                    )
43                    .slice(array.slice_start()..array.slice_stop())?,
44                ))
45            }
46            (Nullability::Nullable, Nullability::NonNullable) => {
47                // null => non-null works if there are no nulls in the sliced range
48                let sliced_len = array.slice_stop() - array.slice_start();
49                let has_nulls = !array
50                    .unsliced_validity
51                    .slice(array.slice_start()..array.slice_stop())?
52                    .all_valid(sliced_len)?;
53
54                // We don't attempt to handle casting when there are nulls.
55                if has_nulls {
56                    return Ok(None);
57                }
58
59                // If there are no nulls, the cast is trivial
60                Ok(Some(
61                    ZstdArray::new(
62                        array.dictionary.clone(),
63                        array.frames.clone(),
64                        dtype.clone(),
65                        array.metadata.clone(),
66                        array.unsliced_n_rows(),
67                        array.unsliced_validity.clone(),
68                    )
69                    .slice(array.slice_start()..array.slice_stop())?,
70                ))
71            }
72        }
73    }
74}
75
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::ZstdArray;

    /// Builds the `u32` primitive dtype with the requested nullability.
    fn u32_dtype(nullability: Nullability) -> DType {
        DType::Primitive(PType::U32, nullability)
    }

    /// Casting i32 data to i64 must report the new dtype and yield the same values as i64.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let source = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let encoded = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Casting non-nullable u32 data to nullable u32 must report the nullable dtype.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let target = u32_dtype(Nullability::Nullable);

        let source = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let encoded = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// A slice of a nullable array whose validity is all `true` casts to non-nullable and
    /// keeps exactly the sliced values.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let target = u32_dtype(Nullability::NonNullable);

        let source = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true; 6]),
        );
        let encoded = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let casted = window.cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        // Verify the values are correct
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// The only null sits at index 0, outside the `1..5` slice, so the slice still casts to
    /// non-nullable and keeps exactly the sliced values.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let target = u32_dtype(Nullability::NonNullable);

        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let encoded = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let casted = window.cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite against ZSTD-encoded primitive arrays.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        buffer![0u32..1000],
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let encoded = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(&encoded.into_array());
    }
}
191}