Skip to main content

vortex_zstd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::dtype::DType;
8use vortex_array::dtype::Nullability;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_array::validity::Validity;
11use vortex_array::vtable::child_to_validity;
12use vortex_error::VortexResult;
13
14use crate::Zstd;
15use crate::ZstdData;
16impl CastReduce for Zstd {
17    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        if !dtype.eq_ignore_nullability(array.dtype()) {
19            // Type changes can't be handled in ZSTD, need to decode and tweak.
20            // TODO(aduffy): handle trivial conversions like Binary -> UTF8, integer widening, etc.
21            return Ok(None);
22        }
23
24        let src_nullability = array.dtype().nullability();
25        let target_nullability = dtype.nullability();
26
27        match (src_nullability, target_nullability) {
28            // Same type case. This should be handled in the layer above but for
29            // completeness of the match arms we also handle it here.
30            (Nullability::Nullable, Nullability::Nullable)
31            | (Nullability::NonNullable, Nullability::NonNullable) => {
32                Ok(Some(array.array().clone()))
33            }
34            (Nullability::NonNullable, Nullability::Nullable) => {
35                // nonnull => null, trivial cast by altering the validity
36                let unsliced_validity =
37                    child_to_validity(&array.slots()[0], array.dtype().nullability());
38                Ok(Some(
39                    Zstd::try_new(
40                        dtype.clone(),
41                        ZstdData::new(
42                            array.dictionary.clone(),
43                            array.frames.clone(),
44                            array.metadata.clone(),
45                            array.unsliced_n_rows(),
46                        ),
47                        unsliced_validity,
48                    )?
49                    .into_array()
50                    .slice(array.slice_start()..array.slice_stop())?,
51                ))
52            }
53            (Nullability::Nullable, Nullability::NonNullable) => {
54                // null => non-null works if there are no nulls in the sliced range
55                let unsliced_validity =
56                    child_to_validity(&array.slots()[0], array.dtype().nullability());
57                let has_nulls = !matches!(
58                    unsliced_validity.slice(array.slice_start()..array.slice_stop())?,
59                    Validity::AllValid | Validity::NonNullable
60                );
61
62                // We don't attempt to handle casting when there are nulls.
63                if has_nulls {
64                    return Ok(None);
65                }
66
67                // If there are no nulls, the cast is trivial
68                Ok(Some(
69                    Zstd::try_new(
70                        dtype.clone(),
71                        ZstdData::new(
72                            array.dictionary.clone(),
73                            array.frames.clone(),
74                            array.metadata.clone(),
75                            array.unsliced_n_rows(),
76                        ),
77                        unsliced_validity,
78                    )?
79                    .into_array()
80                    .slice(array.slice_start()..array.slice_stop())?,
81                ))
82            }
83        }
84    }
85}
86
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Zstd;

    /// Casting a Zstd-encoded i32 array to non-nullable i64 reports the i64 dtype and
    /// decodes to the same numbers as i64.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let source = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let encoded = Zstd::from_primitive(&source, 0, 0).unwrap();

        let widened = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(widened.dtype(), &target);

        let decoded = widened.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Casting a non-nullable u32 array to nullable u32 reports the nullable dtype.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let source = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let encoded = Zstd::from_primitive(&source, 0, 0).unwrap();

        let relaxed = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(relaxed.dtype(), &target);
    }

    /// A sliced, nullable array whose entries are all valid can be cast to non-nullable
    /// and still decodes to exactly the sliced values.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let encoded = Zstd::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let tightened = window.cast(target.clone()).unwrap();
        assert_eq!(tightened.dtype(), &target);

        // Verify the values are correct
        let decoded = tightened.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// The only null sits at index 0, outside the `1..5` slice, so the cast to
    /// non-nullable must still succeed and yield the sliced values.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let encoded = Zstd::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let tightened = window.cast(target.clone()).unwrap();
        assert_eq!(tightened.dtype(), &target);

        let decoded = tightened.to_primitive();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite against Zstd-encoded primitive arrays.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        buffer![0u32..1000],
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let encoded = Zstd::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(&encoded.into_array());
    }
}