Skip to main content

vortex_zstd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::dtype::DType;
8use vortex_array::dtype::Nullability;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_array::vtable::child_to_validity;
11use vortex_error::VortexResult;
12
13use crate::Zstd;
14use crate::ZstdData;
15
16impl CastReduce for Zstd {
17    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        if !dtype.eq_ignore_nullability(array.dtype()) {
19            // Type changes can't be handled in ZSTD, need to decode and tweak.
20            // TODO(aduffy): handle trivial conversions like Binary -> UTF8, integer widening, etc.
21            return Ok(None);
22        }
23
24        let src_nullability = array.dtype().nullability();
25        let target_nullability = dtype.nullability();
26
27        let new_validity = match (src_nullability, target_nullability) {
28            // Same type case. This should be handled in the layer above but for
29            // completeness of the match arms we also handle it here.
30            (Nullability::Nullable, Nullability::Nullable)
31            | (Nullability::NonNullable, Nullability::NonNullable) => {
32                return Ok(Some(array.array().clone()));
33            }
34            (Nullability::NonNullable, Nullability::Nullable) => {
35                // nonnull => null, trivial cast by altering the validity
36                child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability())
37            }
38            (Nullability::Nullable, Nullability::NonNullable) => {
39                // null => non-null works if there are no nulls in the sliced range
40                let unsliced_validity =
41                    child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
42                let has_nulls = !unsliced_validity
43                    .slice(array.slice_start()..array.slice_stop())?
44                    .no_nulls();
45
46                // We don't attempt to handle casting when there are nulls.
47                if has_nulls {
48                    return Ok(None);
49                }
50                unsliced_validity
51            }
52        };
53
54        // If there are no nulls, the cast is trivial
55        Ok(Some(
56            Zstd::try_new(
57                dtype.clone(),
58                ZstdData::new(
59                    array.dictionary.clone(),
60                    array.frames.clone(),
61                    array.metadata.clone(),
62                    array.unsliced_n_rows(),
63                ),
64                new_validity,
65            )?
66            .into_array()
67            .slice(array.slice_start()..array.slice_stop())?,
68        ))
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Zstd;

    /// Session shared by all tests in this module; each test creates its own
    /// execution context from it.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Casting a ZSTD-encoded i32 array to i64 yields the target dtype and the
    /// same values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let mut ctx = SESSION.create_execution_ctx();
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let input = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let encoded = Zstd::from_primitive(&input, 0, 0, &mut ctx).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Casting a non-nullable u32 array to nullable u32 reports the nullable dtype.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let mut ctx = SESSION.create_execution_ctx();
        let target = DType::Primitive(PType::U32, Nullability::Nullable);

        let input = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let encoded = Zstd::from_primitive(&input, 0, 0, &mut ctx).unwrap();

        let casted = encoded.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// A sliced, nullable array whose values are all valid can be cast to
    /// non-nullable and decodes to the sliced values.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let mut ctx = SESSION.create_execution_ctx();
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);

        let input = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let encoded = Zstd::from_primitive(&input, 0, 128, &mut ctx).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let casted = window.cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        // The decoded values must be exactly the sliced window.
        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// The only null sits at index 0, outside the `1..5` slice, so the cast to
    /// non-nullable still succeeds and decodes to the sliced values.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let mut ctx = SESSION.create_execution_ctx();
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);

        let input = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let encoded = Zstd::from_primitive(&input, 0, 128, &mut ctx).unwrap();
        let window = encoded.slice(1..5).unwrap();

        let casted = window.cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite against ZSTD-encoded primitive arrays.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![1.1f64, 2.2, 3.3, 4.4, 5.5],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42i64],
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        buffer![0u32..1000],
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let mut ctx = SESSION.create_execution_ctx();
        let encoded = Zstd::from_primitive(&values, 0, 0, &mut ctx).unwrap();
        test_cast_conformance(&encoded.into_array());
    }
}