Skip to main content

vortex_zstd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::dtype::DType;
6use vortex_array::dtype::Nullability;
7use vortex_array::scalar_fn::fns::cast::CastReduce;
8use vortex_error::VortexResult;
9
10use crate::ZstdArray;
11use crate::ZstdVTable;
12
13impl CastReduce for ZstdVTable {
14    fn cast(array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        if !dtype.eq_ignore_nullability(array.dtype()) {
16            // Type changes can't be handled in ZSTD, need to decode and tweak.
17            // TODO(aduffy): handle trivial conversions like Binary -> UTF8, integer widening, etc.
18            return Ok(None);
19        }
20
21        let src_nullability = array.dtype().nullability();
22        let target_nullability = dtype.nullability();
23
24        match (src_nullability, target_nullability) {
25            // Same type case. This should be handled in the layer above but for
26            // completeness of the match arms we also handle it here.
27            (Nullability::Nullable, Nullability::Nullable)
28            | (Nullability::NonNullable, Nullability::NonNullable) => Ok(Some(array.to_array())),
29            (Nullability::NonNullable, Nullability::Nullable) => {
30                // nonnull => null, trivial cast by altering the validity
31                Ok(Some(
32                    ZstdArray::new(
33                        array.dictionary.clone(),
34                        array.frames.clone(),
35                        dtype.clone(),
36                        array.metadata.clone(),
37                        array.unsliced_n_rows(),
38                        array.unsliced_validity.clone(),
39                    )
40                    .slice(array.slice_start()..array.slice_stop())?,
41                ))
42            }
43            (Nullability::Nullable, Nullability::NonNullable) => {
44                // null => non-null works if there are no nulls in the sliced range
45                let sliced_len = array.slice_stop() - array.slice_start();
46                let has_nulls = !array
47                    .unsliced_validity
48                    .slice(array.slice_start()..array.slice_stop())?
49                    .all_valid(sliced_len)?;
50
51                // We don't attempt to handle casting when there are nulls.
52                if has_nulls {
53                    return Ok(None);
54                }
55
56                // If there are no nulls, the cast is trivial
57                Ok(Some(
58                    ZstdArray::new(
59                        array.dictionary.clone(),
60                        array.frames.clone(),
61                        dtype.clone(),
62                        array.metadata.clone(),
63                        array.unsliced_n_rows(),
64                        array.unsliced_validity.clone(),
65                    )
66                    .slice(array.slice_start()..array.slice_stop())?,
67                ))
68            }
69        }
70    }
71}
72
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::ZstdArray;

    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    #[test]
    fn test_cast_zstd_nullability_change() {
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        // The values must survive the nullability-only cast unchanged.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10u32), Some(20), Some(30), Some(40)])
        );
    }

    #[test]
    fn test_cast_sliced_zstd_nonnullable_to_nullable() {
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40, 50, 60]);
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );
        // The rebuilt array must be re-sliced to the original window.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)])
        );
    }

    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        // Verify the values are correct
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        buffer![0u32..1000],
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(&zstd.to_array());
    }
}