vortex_zstd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{ZstdArray, ZstdVTable};
10
11impl CastKernel for ZstdVTable {
12    fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        if !dtype.is_nullable() || !array.all_valid() {
14            // We cannot cast to non-nullable since the validity containing nulls is used to decode
15            // the ZSTD array, this would require rewriting tables.
16            return Ok(None);
17        }
18        // ZstdArray is a general-purpose compression encoding using Zstandard compression.
19        // It can handle nullability changes without decompression by updating the validity
20        // bitmap, but type changes require decompression since the compressed data is
21        // type-specific and Zstd operates on raw bytes.
22        if array.dtype().eq_ignore_nullability(dtype) {
23            // Create a new validity with the target nullability
24            let new_validity = array
25                .unsliced_validity
26                .clone()
27                .cast_nullability(dtype.nullability(), array.len())?;
28
29            return Ok(Some(
30                ZstdArray::new(
31                    array.dictionary.clone(),
32                    array.frames.clone(),
33                    dtype.clone(),
34                    array.metadata.clone(),
35                    array.unsliced_n_rows(),
36                    new_validity,
37                )
38                ._slice(array.slice_start(), array.slice_stop())
39                .into_array(),
40            ));
41        }
42
43        // For other casts (e.g., type changes), decode to canonical and let the underlying array handle it
44        Ok(None)
45    }
46}
47
// Registers the `CastKernel` implementation for `ZstdVTable`, wrapped in
// `CastKernelAdapter`, via the `register_kernel!` macro.
// NOTE(review): presumably this is what makes the generic `cast` compute function
// dispatch to the Zstd kernel — confirm against the macro definition in `vortex_array`.
register_kernel!(CastKernelAdapter(ZstdVTable).lift());
49
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_array::{ToCanonical, assert_arrays_eq};
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::ZstdArray;

    /// A widening cast (i32 -> i64) must yield the target dtype and the same values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let casted = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i64, 2, 3, 4, 5]));
    }

    /// Casting a non-nullable array to the nullable variant of the same type must
    /// report the nullable dtype.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 0).unwrap();

        let casted = cast(compressed.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// A sliced, nullable-but-fully-valid array cast to non-nullable must keep exactly
    /// the values inside the slice window.
    #[test]
    fn test_cast_sliced_zstd_nullable_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true; 6]),
        );
        let compressed = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = compressed.slice(1..5);

        let casted = cast(window.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);

        // Verify the values are correct
        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<u32>(), &[20, 30, 40, 50]);
    }

    /// The only null sits outside the slice window, so the sliced view can still be
    /// cast to non-nullable and must decode to the in-window values.
    #[test]
    fn test_cast_sliced_zstd_part_valid_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let compressed = ZstdArray::from_primitive(&source, 0, 128).unwrap();
        let window = compressed.slice(1..5);

        let casted = cast(window.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// Runs the shared cast conformance suite against several Zstd-compressed inputs.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0u32..1000).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let compressed = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(compressed.as_ref());
    }
}