Skip to main content

vortex_fastlanes/for/compute/
cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::dtype::PType;
use vortex_array::scalar_fn::fns::cast::CastReduce;
use vortex_error::VortexResult;

use crate::r#for::FoRArray;
use crate::r#for::FoRVTable;
13
14impl CastReduce for FoRVTable {
15    fn cast(array: &FoRArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        // FoR only supports integer types
17        if !dtype.is_int() {
18            return Ok(None);
19        }
20
21        // For type changes between integers, cast the components
22        let casted_child = array.encoded().cast(dtype.clone())?;
23        let casted_reference = array.reference_scalar().cast(dtype)?;
24
25        Ok(Some(
26            FoRArray::try_new(casted_child, casted_reference)?.into_array(),
27        ))
28    }
29}
30
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_buffer::buffer;

    use crate::FoRArray;

    /// Widening a non-nullable i32 FoR array to i64 keeps the decoded values.
    #[test]
    fn test_cast_for_i32_to_i64() {
        let for_array = FoRArray::try_new(
            buffer![0i32, 10, 20, 30, 40].into_array(),
            Scalar::from(100i32),
        )
        .unwrap();

        let casted = for_array
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        // Verify the values after decoding.
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_iter([100i64, 110, 120, 130, 140])
        );
    }

    /// Widening a nullable FoR array keeps both the decoded values and the null positions.
    #[test]
    fn test_cast_for_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(20), Some(30), None]);
        let for_array = FoRArray::try_new(values.into_array(), Scalar::from(50i32)).unwrap();

        let casted = for_array
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Nulls must stay null, and valid entries must decode as `encoded + reference`.
        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_option_iter([Some(50i64), None, Some(70), Some(80), None])
        );
    }

    #[rstest]
    #[case(FoRArray::try_new(
        buffer![0i32, 1, 2, 3, 4].into_array(),
        Scalar::from(100i32)
    ).unwrap())]
    #[case(FoRArray::try_new(
        buffer![0u64, 10, 20, 30].into_array(),
        Scalar::from(1000u64)
    ).unwrap())]
    #[case(FoRArray::try_new(
        PrimitiveArray::from_option_iter([Some(0i16), None, Some(5), Some(10), None]).into_array(),
        Scalar::from(50i16)
    ).unwrap())]
    #[case(FoRArray::try_new(
        buffer![-10i32, -5, 0, 5, 10].into_array(),
        Scalar::from(-100i32)
    ).unwrap())]
    fn test_cast_for_conformance(#[case] array: FoRArray) {
        test_cast_conformance(array.as_ref());
    }
}