vortex_zstd/compute/
cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{ZstdArray, ZstdVTable};
10
11impl CastKernel for ZstdVTable {
12    fn cast(&self, array: &ZstdArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        // ZstdArray is a general-purpose compression encoding using Zstandard compression.
14        // It can handle nullability changes without decompression by updating the validity
15        // bitmap, but type changes require decompression since the compressed data is
16        // type-specific and Zstd operates on raw bytes.
17        if array.dtype().eq_ignore_nullability(dtype) {
18            // Create a new validity with the target nullability
19            let new_validity = array
20                .unsliced_validity
21                .clone()
22                .cast_nullability(dtype.nullability(), array.len())?;
23
24            return Ok(Some(
25                ZstdArray::new(
26                    array.dictionary.clone(),
27                    array.frames.clone(),
28                    dtype.clone(),
29                    array.metadata.clone(),
30                    array.unsliced_n_rows(),
31                    new_validity,
32                )
33                ._slice(array.slice_start(), array.slice_stop())
34                .into_array(),
35            ));
36        }
37
38        // For other casts (e.g., type changes), decode to canonical and let the underlying array handle it
39        Ok(None)
40    }
41}
42
43register_kernel!(CastKernelAdapter(ZstdVTable).lift());
44
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::ZstdArray;

    /// A type-changing cast (i32 -> i64) must still succeed through the generic `cast`
    /// entry point and yield the widened values.
    #[test]
    fn test_cast_zstd_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![1i32, 2, 3, 4, 5]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = cast(zstd.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<i64>(), &[1i64, 2, 3, 4, 5]);
    }

    /// A nullability-only cast must change the dtype and leave the values untouched.
    #[test]
    fn test_cast_zstd_nullability_change() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();

        let casted = cast(zstd.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);

        // Checking the dtype alone would not catch a cast that drops or corrupts data,
        // so also decode and compare the values.
        let decoded = casted.to_primitive();
        assert_eq!(decoded.as_slice::<u32>(), &[10u32, 20, 30, 40]);
    }

    /// Runs the shared cast conformance suite over several primitive inputs.
    #[rstest]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![1.1f64, 2.2, 3.3, 4.4, 5.5]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42i64]),
        Validity::NonNullable,
    ))]
    #[case::large(PrimitiveArray::new(
        Buffer::copy_from((0u32..1000).collect::<Vec<_>>()),
        Validity::NonNullable,
    ))]
    fn test_cast_zstd_conformance(#[case] values: PrimitiveArray) {
        let zstd = ZstdArray::from_primitive(&values, 0, 0).unwrap();
        test_cast_conformance(zstd.as_ref());
    }
}