vortex_fastlanes/for/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::compute::cast;
9use vortex_array::register_kernel;
10use vortex_dtype::DType;
11use vortex_error::VortexResult;
12
13use crate::r#for::FoRArray;
14use crate::r#for::FoRVTable;
15
16impl CastKernel for FoRVTable {
17    fn cast(&self, array: &FoRArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        // FoR only supports integer types
19        if !dtype.is_int() {
20            return Ok(None);
21        }
22
23        // For type changes between integers, cast the components
24        let casted_child = cast(array.encoded(), dtype)?;
25        let casted_reference = array.reference_scalar().cast(dtype)?;
26
27        Ok(Some(
28            FoRArray::try_new(casted_child, casted_reference)?.into_array(),
29        ))
30    }
31}
32
33register_kernel!(CastKernelAdapter(FoRVTable).lift());
34
35#[cfg(test)]
36mod tests {
37    use rstest::rstest;
38    use vortex_array::IntoArray;
39    use vortex_array::ToCanonical;
40    use vortex_array::arrays::PrimitiveArray;
41    use vortex_array::assert_arrays_eq;
42    use vortex_array::compute::cast;
43    use vortex_array::compute::conformance::cast::test_cast_conformance;
44    use vortex_buffer::buffer;
45    use vortex_dtype::DType;
46    use vortex_dtype::Nullability;
47    use vortex_dtype::PType;
48    use vortex_scalar::Scalar;
49
50    use crate::FoRArray;
51
52    #[test]
53    fn test_cast_for_i32_to_i64() {
54        let for_array = FoRArray::try_new(
55            buffer![0i32, 10, 20, 30, 40].into_array(),
56            Scalar::from(100i32),
57        )
58        .unwrap();
59
60        let casted = cast(
61            for_array.as_ref(),
62            &DType::Primitive(PType::I64, Nullability::NonNullable),
63        )
64        .unwrap();
65        assert_eq!(
66            casted.dtype(),
67            &DType::Primitive(PType::I64, Nullability::NonNullable)
68        );
69
70        // Verify the values after decoding
71        let decoded = casted.to_primitive();
72        assert_arrays_eq!(
73            decoded,
74            PrimitiveArray::from_iter([100i64, 110, 120, 130, 140])
75        );
76    }
77
78    #[test]
79    fn test_cast_for_nullable() {
80        let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(20), Some(30), None]);
81        let for_array = FoRArray::try_new(values.into_array(), Scalar::from(50i32)).unwrap();
82
83        let casted = cast(
84            for_array.as_ref(),
85            &DType::Primitive(PType::I64, Nullability::Nullable),
86        )
87        .unwrap();
88        assert_eq!(
89            casted.dtype(),
90            &DType::Primitive(PType::I64, Nullability::Nullable)
91        );
92    }
93
94    #[rstest]
95    #[case(FoRArray::try_new(
96        buffer![0i32, 1, 2, 3, 4].into_array(),
97        Scalar::from(100i32)
98    ).unwrap())]
99    #[case(FoRArray::try_new(
100        buffer![0u64, 10, 20, 30].into_array(),
101        Scalar::from(1000u64)
102    ).unwrap())]
103    #[case(FoRArray::try_new(
104        PrimitiveArray::from_option_iter([Some(0i16), None, Some(5), Some(10), None]).into_array(),
105        Scalar::from(50i16)
106    ).unwrap())]
107    #[case(FoRArray::try_new(
108        buffer![-10i32, -5, 0, 5, 10].into_array(),
109        Scalar::from(-100i32)
110    ).unwrap())]
111    fn test_cast_for_conformance(#[case] array: FoRArray) {
112        test_cast_conformance(array.as_ref());
113    }
114}