Skip to main content

vortex_zigzag/compute/
cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::dtype::DType;
8use vortex_array::scalar_fn::fns::cast::CastReduce;
9use vortex_error::VortexResult;
10
11use crate::ZigZagArray;
12use crate::ZigZagVTable;
13
14impl CastReduce for ZigZagVTable {
15    fn cast(array: &ZigZagArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        if !dtype.is_signed_int() {
17            return Ok(None);
18        }
19
20        let new_encoded_dtype =
21            DType::Primitive(dtype.as_ptype().to_unsigned(), dtype.nullability());
22        let new_encoded = array.encoded().cast(new_encoded_dtype)?;
23        Ok(Some(ZigZagArray::try_new(new_encoded)?.into_array()))
24    }
25}
26
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::ZigZagArray;
    use crate::zigzag_encode;

    /// Encoding id reported by arrays that are still ZigZag-encoded after a cast.
    const ZIGZAG_ENCODING_ID: &str = "vortex.zigzag";

    #[test]
    fn test_cast_zigzag_i32_to_i64() {
        let values = PrimitiveArray::from_iter([-100i32, -1, 0, 1, 100]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = zigzag
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        // The cast should be pushed into the encoded child rather than decoding,
        // so the top-level encoding must still be ZigZag.
        assert_eq!(
            casted.encoding_id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Cast should preserve ZigZag encoding"
        );

        assert_arrays_eq!(casted, PrimitiveArray::from_iter([-100i64, -1, 0, 1, 100]));
    }

    #[test]
    fn test_cast_zigzag_width_changes() {
        // Test i32 to i16 (narrowing)
        let values = PrimitiveArray::from_iter([100i32, -50, 0, 25, -100]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = zigzag
            .into_array()
            .cast(DType::Primitive(PType::I16, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.encoding_id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Should remain ZigZag encoded"
        );

        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_iter([100i16, -50, 0, 25, -100])
        );

        // Test i16 to i64 (widening)
        let values16 = PrimitiveArray::from_iter([1000i16, -500, 0, 250, -1000]);
        let zigzag16 = zigzag_encode(values16).unwrap();

        let casted64 = zigzag16
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted64.encoding_id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Should remain ZigZag encoded"
        );

        assert_arrays_eq!(
            casted64,
            PrimitiveArray::from_iter([1000i64, -500, 0, 250, -1000])
        );
    }

    #[test]
    fn test_cast_zigzag_nullable() {
        let values =
            PrimitiveArray::from_option_iter([Some(-10i32), None, Some(0), Some(10), None]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = zigzag
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );
        assert_eq!(
            casted.encoding_id().as_ref(),
            ZIGZAG_ENCODING_ID,
            "Cast should preserve ZigZag encoding"
        );

        // Both the non-null values and the null positions must survive the cast.
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([Some(-10i64), None, Some(0), Some(10), None])
        );
    }

    #[test]
    fn test_cast_zigzag_to_unsigned() {
        // Unsigned targets are declined by `CastReduce` (it returns `None`), so this
        // exercises whatever generic cast path handles them; values must still match.
        let values = PrimitiveArray::from_iter([0i32, 1, 100]);
        let zigzag = zigzag_encode(values).unwrap();

        let casted = zigzag
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        assert_arrays_eq!(casted, PrimitiveArray::from_iter([0u32, 1, 100]));
    }

    #[rstest]
    #[case(zigzag_encode(PrimitiveArray::from_iter([-100i32, -50, -1, 0, 1, 50, 100])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_iter([-1000i64, -1, 0, 1, 1000])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_option_iter([Some(-5i16), None, Some(0), Some(5), None])).unwrap())]
    #[case(zigzag_encode(PrimitiveArray::from_iter([i32::MIN, -1, 0, 1, i32::MAX])).unwrap())]
    fn test_cast_zigzag_conformance(#[case] array: ZigZagArray) {
        test_cast_conformance(&array.into_array());
    }
}