vortex_zigzag/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{ZigZagArray, ZigZagVTable};
10
11impl CastKernel for ZigZagVTable {
12    fn cast(&self, array: &ZigZagArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        if !dtype.is_signed_int() {
14            return Ok(None);
15        }
16
17        let new_encoded_dtype =
18            DType::Primitive(dtype.as_ptype().to_unsigned(), dtype.nullability());
19        let new_encoded = cast(array.encoded(), &new_encoded_dtype)?;
20        Ok(Some(ZigZagArray::try_new(new_encoded)?.into_array()))
21    }
22}
23
24register_kernel!(CastKernelAdapter(ZigZagVTable).lift());
25
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ZigZagArray, zigzag_encode};

    #[test]
    fn test_cast_zigzag_i32_to_i64() {
        let values = PrimitiveArray::from_iter([-100i32, -1, 0, 1, 100]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = cast(
            zigzag.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        // A signed-integer target should be served by the ZigZag cast kernel, which re-wraps the
        // cast child instead of decoding, so the top-level encoding must still be ZigZag.
        assert_eq!(
            casted.encoding().id().as_ref(),
            "vortex.zigzag",
            "Cast should preserve ZigZag encoding"
        );

        let decoded = casted.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(decoded.as_slice::<i64>(), &[-100i64, -1, 0, 1, 100]);
    }

    #[test]
    fn test_cast_zigzag_width_changes() {
        // Test i32 to i16 (narrowing)
        let values = PrimitiveArray::from_iter([100i32, -50, 0, 25, -100]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = cast(
            zigzag.as_ref(),
            &DType::Primitive(PType::I16, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.encoding().id().as_ref(),
            "vortex.zigzag",
            "Should remain ZigZag encoded"
        );

        let decoded = casted.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(decoded.as_slice::<i16>(), &[100i16, -50, 0, 25, -100]);

        // Test i16 to i64 (widening)
        let values16 = PrimitiveArray::from_iter([1000i16, -500, 0, 250, -1000]);
        let zigzag16 = zigzag_encode(values16).unwrap();

        let casted64 = cast(
            zigzag16.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted64.encoding().id().as_ref(),
            "vortex.zigzag",
            "Should remain ZigZag encoded"
        );

        let decoded64 = casted64.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(decoded64.as_slice::<i64>(), &[1000i64, -500, 0, 250, -1000]);
    }

    #[test]
    fn test_cast_zigzag_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(-10i32), None, Some(0), Some(10), None]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = cast(
            zigzag.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );
        assert_eq!(
            casted.encoding().id().as_ref(),
            "vortex.zigzag",
            "Cast should preserve ZigZag encoding"
        );

        // Nulls must survive the cast in exactly the same positions.
        let expected_validity = [true, false, true, true, false];
        for (idx, expected) in expected_validity.into_iter().enumerate() {
            assert_eq!(
                casted.is_valid(idx).unwrap(),
                expected,
                "validity mismatch at index {idx}"
            );
        }

        // Only compare valid slots: the physical values behind nulls are unspecified.
        let decoded = casted.to_canonical().unwrap().into_primitive().unwrap();
        let slice = decoded.as_slice::<i64>();
        assert_eq!(slice.len(), 5);
        assert_eq!([slice[0], slice[2], slice[3]], [-10i64, 0, 10]);
    }

    #[rstest]
    #[case(zigzag_encode(PrimitiveArray::from_iter([-100i32, -50, -1, 0, 1, 50, 100])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_iter([-1000i64, -1, 0, 1, 1000])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_option_iter([Some(-5i16), None, Some(0), Some(5), None])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_iter([i32::MIN, -1, 0, 1, i32::MAX])).unwrap())]
    fn test_cast_zigzag_conformance(#[case] array: ZigZagArray) {
        test_cast_conformance(array.as_ref());
    }
}
128}