Skip to main content

vortex_zstd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::dtype::DType;
6use vortex_array::dtype::Nullability;
7use vortex_array::scalar_fn::fns::cast::CastReduce;
8use vortex_error::VortexResult;
9
10use crate::ZstdArray;
11use crate::ZstdVTable;
12
13impl CastReduce for ZstdVTable {
14    fn cast(array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        if !dtype.eq_ignore_nullability(array.dtype()) {
16            // Type changes can't be handled in ZSTD, need to decode and tweak.
17            // TODO(aduffy): handle trivial conversions like Binary -> UTF8, integer widening, etc.
18            return Ok(None);
19        }
20
21        let src_nullability = array.dtype().nullability();
22        let target_nullability = dtype.nullability();
23
24        match (src_nullability, target_nullability) {
25            // Same type case. This should be handled in the layer above but for
26            // completeness of the match arms we also handle it here.
27            (Nullability::Nullable, Nullability::Nullable)
28            | (Nullability::NonNullable, Nullability::NonNullable) => Ok(Some(array.to_array())),
29            (Nullability::NonNullable, Nullability::Nullable) => {
30                // nonnull => null, trivial cast by altering the validity
31                Ok(Some(
32                    ZstdArray::new(
33                        array.dictionary.clone(),
34                        array.frames.clone(),
35                        dtype.clone(),
36                        array.metadata.clone(),
37                        array.unsliced_n_rows(),
38                        array.unsliced_validity.clone(),
39                    )
40                    .slice(array.slice_start()..array.slice_stop())?,
41                ))
42            }
43            (Nullability::Nullable, Nullability::NonNullable) => {
44                // null => non-null works if there are no nulls in the sliced range
45                let sliced_len = array.slice_stop() - array.slice_start();
46                let has_nulls = !array
47                    .unsliced_validity
48                    .slice(array.slice_start()..array.slice_stop())?
49                    .all_valid(sliced_len)?;
50
51                // We don't attempt to handle casting when there are nulls.
52                if has_nulls {
53                    return Ok(None);
54                }
55
56                // If there are no nulls, the cast is trivial
57                Ok(Some(
58                    ZstdArray::new(
59                        array.dictionary.clone(),
60                        array.frames.clone(),
61                        dtype.clone(),
62                        array.metadata.clone(),
63                        array.unsliced_n_rows(),
64                        array.unsliced_validity.clone(),
65                    )
66                    .slice(array.slice_start()..array.slice_stop())?,
67                ))
68            }
69        }
70    }
71}
72
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;

    use crate::ZstdArray;

    /// A primitive type change is not handled in place by ZSTD (the cast reducer returns `None`),
    /// so this exercises the fallback path; the values must survive the widening.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Non-nullable -> nullable is handled in place; verify the dtype AND the decoded values.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = zstd
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        // The dtype alone does not prove that the rebuilt array still decodes to the same data.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(10u32), Some(20), Some(30), Some(40)])
        );
    }

    /// Non-nullable -> nullable on a sliced array must keep the slice window intact.
    #[test]
    fn test_cast_sliced_zstd_nonnullable_to_nullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(20u32), Some(30), Some(40), Some(50)])
        );
    }

    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        // Verify the values are correct
        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// The only null lies outside the slice window, so narrowing to non-nullable must succeed.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let zstd = ZstdArray::from_primitive(&values, 0, 128).unwrap();
        let sliced = zstd.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        let decoded = casted.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0..1000).map(|i| i as u32).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(zstd.as_ref());
    }
}