vortex_zigzag/compute/
cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::compute::cast;
9use vortex_array::register_kernel;
10use vortex_dtype::DType;
11use vortex_error::VortexResult;
12
13use crate::ZigZagArray;
14use crate::ZigZagVTable;
15
16impl CastKernel for ZigZagVTable {
17    fn cast(&self, array: &ZigZagArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        if !dtype.is_signed_int() {
19            return Ok(None);
20        }
21
22        let new_encoded_dtype =
23            DType::Primitive(dtype.as_ptype().to_unsigned(), dtype.nullability());
24        let new_encoded = cast(array.encoded(), &new_encoded_dtype)?;
25        Ok(Some(ZigZagArray::try_new(new_encoded)?.into_array()))
26    }
27}
28
// Register the ZigZag cast kernel defined above, wrapped in `CastKernelAdapter` and
// `.lift()`-ed, via `register_kernel!` (presumably so `compute::cast` can dispatch to it
// for ZigZag arrays — confirm against the macro's definition).
register_kernel!(CastKernelAdapter(ZigZagVTable).lift());
30
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::ZigZagArray;
    use crate::zigzag_encode;

    /// Encoding ID that casts to signed integer targets are expected to preserve.
    const ZIGZAG_ENCODING_ID: &str = "vortex.zigzag";

    #[test]
    fn test_cast_zigzag_i32_to_i64() {
        let values = PrimitiveArray::from_iter([-100i32, -1, 0, 1, 100]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = cast(
            zigzag.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        // The cast kernel rewraps the re-cast encoded child in a new ZigZagArray, so the
        // result must still carry the ZigZag encoding rather than a decoded primitive.
        assert_eq!(
            casted.encoding().id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Cast should preserve ZigZag encoding"
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([-100i64, -1, 0, 1, 100]));
    }

    #[test]
    fn test_cast_zigzag_width_changes() {
        // Test i32 to i16 (narrowing)
        let values = PrimitiveArray::from_iter([100i32, -50, 0, 25, -100]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = cast(
            zigzag.as_ref(),
            &DType::Primitive(PType::I16, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.encoding().id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Should remain ZigZag encoded"
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_iter([100i16, -50, 0, 25, -100])
        );

        // Test i16 to i64 (widening)
        let values16 = PrimitiveArray::from_iter([1000i16, -500, 0, 250, -1000]);
        let zigzag16 = zigzag_encode(values16).unwrap();

        let casted64 = cast(
            zigzag16.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted64.encoding().id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Should remain ZigZag encoded"
        );

        let decoded64 = casted64.to_primitive();
        assert_arrays_eq!(
            decoded64,
            PrimitiveArray::from_iter([1000i64, -500, 0, 250, -1000])
        );
    }

    #[test]
    fn test_cast_zigzag_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(-10i32), None, Some(0), Some(10), None]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = cast(
            zigzag.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );
        assert_eq!(
            casted.encoding().id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Nullable cast should preserve ZigZag encoding"
        );

        // Both the null positions and the non-null values must survive the cast of the
        // encoded child; checking only the dtype would not catch a lost validity mask.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(-10i64), None, Some(0), Some(10), None])
        );
    }

    #[rstest]
    #[case(zigzag_encode(PrimitiveArray::from_iter([-100i32, -50, -1, 0, 1, 50, 100])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_iter([-1000i64, -1, 0, 1, 1000])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_option_iter([Some(-5i16), None, Some(0), Some(5), None])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_iter([i32::MIN, -1, 0, 1, i32::MAX])).unwrap())]
    fn test_cast_zigzag_conformance(#[case] array: ZigZagArray) {
        test_cast_conformance(array.as_ref());
    }
}