vortex_zstd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::compute::CastKernel;
6use vortex_array::compute::CastKernelAdapter;
7use vortex_array::register_kernel;
8use vortex_dtype::DType;
9use vortex_dtype::Nullability;
10use vortex_error::VortexResult;
11
12use crate::ZstdArray;
13use crate::ZstdVTable;
14
15impl CastKernel for ZstdVTable {
16    fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        if !dtype.eq_ignore_nullability(array.dtype()) {
18            // Type changes can't be handled in ZSTD, need to decode and tweak.
19            // TODO(aduffy): handle trivial conversions like Binary -> UTF8, integer widening, etc.
20            return Ok(None);
21        }
22
23        let src_nullability = array.dtype().nullability();
24        let target_nullability = dtype.nullability();
25
26        match (src_nullability, target_nullability) {
27            // Same type case. This should be handled in the layer above but for
28            // completeness of the match arms we also handle it here.
29            (Nullability::Nullable, Nullability::Nullable)
30            | (Nullability::NonNullable, Nullability::NonNullable) => Ok(Some(array.to_array())),
31            (Nullability::NonNullable, Nullability::Nullable) => Ok(Some(
32                // nonnull => null, trivial cast by altering the validity
33                ZstdArray::new(
34                    array.dictionary.clone(),
35                    array.frames.clone(),
36                    dtype.clone(),
37                    array.metadata.clone(),
38                    array.unsliced_n_rows(),
39                    array.unsliced_validity.clone(),
40                )
41                .slice(array.slice_start()..array.slice_stop()),
42            )),
43            (Nullability::Nullable, Nullability::NonNullable) => {
44                // null => non-null works if there are no nulls in the sliced range
45                let sliced_len = array.slice_stop() - array.slice_start();
46                let has_nulls = !array
47                    .unsliced_validity
48                    .slice(array.slice_start()..array.slice_stop())
49                    .all_valid(sliced_len);
50
51                // We don't attempt to handle casting when there are nulls.
52                if has_nulls {
53                    return Ok(None);
54                }
55
56                // If there are no nulls, the cast is trivial
57                Ok(Some(
58                    ZstdArray::new(
59                        array.dictionary.clone(),
60                        array.frames.clone(),
61                        dtype.clone(),
62                        array.metadata.clone(),
63                        array.unsliced_n_rows(),
64                        array.unsliced_validity.clone(),
65                    )
66                    .slice(array.slice_start()..array.slice_stop()),
67                ))
68            }
69        }
70    }
71}
72
// Pass the `CastKernel` implementation for `ZstdVTable` to `register_kernel!`,
// wrapped in `CastKernelAdapter` and lifted via `.lift()`.
// NOTE(review): presumably this registration is what lets the generic
// `vortex_array::compute::cast` entry point dispatch to the ZSTD kernel —
// confirm against the `register_kernel!` macro definition.
73register_kernel!(CastKernelAdapter(ZstdVTable).lift());
74
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::ZstdArray;

    /// Casting a ZSTD-encoded `i32` array to `i64` through the generic `cast`
    /// function yields the requested dtype and the same numeric values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let result = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let widened = result.to_primitive();
        assert_arrays_eq!(widened, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Casting a non-nullable `u32` array to the nullable `u32` dtype reports
    /// the nullable dtype on the result.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let result = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);
    }

    /// A sliced, nullable array whose rows are all valid can be cast to the
    /// non-nullable dtype and keeps exactly the rows inside the slice.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true; 6]),
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = compressed.slice(1..5);

        let result = cast(window.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        // The decoded values must be exactly the rows inside the slice.
        let canonical = result.to_primitive();
        let rows = canonical.as_slice::<u32>();
        assert_eq!(rows, &[20, 30, 40, 50]);
    }

    /// The only null sits at index 0, outside the `1..5` slice, so the cast to
    /// the non-nullable dtype is expected to succeed with the sliced values.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let compressed = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = compressed.slice(1..5);

        let result = cast(window.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let canonical = result.to_primitive();
        let wanted = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(canonical, wanted);
    }

    // Runs the shared cast conformance suite against ZSTD-encoded inputs of
    // several primitive types and lengths.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0u32..1000).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let compressed = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(compressed.as_ref());
    }
}