Skip to main content

vortex_array/arrays/constant/compute/cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::array::ArrayView;
9use crate::arrays::Constant;
10use crate::arrays::ConstantArray;
11use crate::dtype::DType;
12use crate::scalar_fn::fns::cast::CastReduce;
13
14impl CastReduce for Constant {
15    fn cast(array: ArrayView<'_, Constant>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        match array.scalar().cast(dtype) {
17            Ok(scalar) => Ok(Some(ConstantArray::new(scalar, array.len()).into_array())),
18            Err(_) => Ok(None),
19        }
20    }
21}
22
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    // Runs the shared cast conformance suite over constant arrays built from an
    // unsigned int, a signed int, a float, a bool, a null i32, and a length-one u8.
    #[rstest]
    #[case(ConstantArray::new(Scalar::from(42u32), 5).into_array())]
    #[case(ConstantArray::new(Scalar::from(-100i32), 10).into_array())]
    #[case(ConstantArray::new(Scalar::from(3.5f32), 3).into_array())]
    #[case(ConstantArray::new(Scalar::from(true), 7).into_array())]
    #[case(ConstantArray::new(Scalar::null_native::<i32>(), 4).into_array())]
    #[case(ConstantArray::new(Scalar::from(255u8), 1).into_array())]
    fn test_cast_constant_conformance(#[case] input: ArrayRef) {
        test_cast_conformance(&input);
    }

    // Casting a constant i64 to a non-nullable Decimal(21, 2) must produce an array of
    // exactly that dtype whose element is 42 at scale 2, i.e. the unscaled value 4200.
    #[test]
    fn test_cast_constant_i64_to_decimal() {
        let decimal_dtype = DType::Decimal(DecimalDType::new(21, 2), Nullability::NonNullable);
        let source = ConstantArray::new(Scalar::from(42i64), 5).into_array();

        let result = source.cast(decimal_dtype.clone()).unwrap();
        assert_eq!(result.dtype(), &decimal_dtype);

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let first = result.execute_scalar(0, &mut ctx).unwrap();
        let expected = Some(DecimalValue::I128(4200));
        assert_eq!(first.as_decimal().decimal_value(), expected);
    }
}