vortex_array/arrays/decimal/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::{VortexResult, vortex_bail, vortex_panic};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{CastKernel, CastKernelAdapter};
9use crate::stats::ArrayStats;
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, register_kernel};
12
13impl CastKernel for DecimalVTable {
14    fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        // Early return if not casting to decimal
16        let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
17            return Ok(None);
18        };
19        let DType::Decimal(from_precision_scale, _) = array.dtype() else {
20            vortex_panic!(
21                "DecimalArray must have decimal dtype, got {:?}",
22                array.dtype()
23            );
24        };
25
26        // We only support casting to the same decimal type with different nullability
27        if from_precision_scale != to_precision_scale {
28            vortex_bail!(
29                "Cannot cast decimal({},{}) to decimal({},{})",
30                from_precision_scale.precision(),
31                from_precision_scale.scale(),
32                to_precision_scale.precision(),
33                to_precision_scale.scale()
34            );
35        }
36
37        // If the dtype is exactly the same, return self
38        if array.dtype() == dtype {
39            return Ok(Some(array.to_array()));
40        }
41
42        // Cast the validity to the new nullability
43        let new_validity = array.validity().clone().cast_nullability(*to_nullability)?;
44
45        // Construct DecimalArray directly since we can't use new() without knowing the concrete type
46        Ok(Some(
47            DecimalArray {
48                dtype: DType::Decimal(*from_precision_scale, *to_nullability),
49                values: array.byte_buffer(),
50                values_type: array.values_type(),
51                validity: new_validity,
52                stats_set: ArrayStats::default(),
53            }
54            .to_array(),
55        ))
56    }
57}
58
59register_kernel!(CastKernelAdapter(DecimalVTable).lift());
60
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Non-nullable -> nullable keeps the values and reports every slot as valid.
    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Cast to nullable
        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let casted = cast(array.as_ref(), &nullable_dtype)
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert_eq!(casted.validity(), &Validity::AllValid);
        assert_eq!(casted.len(), 3);
    }

    /// Nullable (without nulls) -> non-nullable succeeds and drops the validity.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with no nulls
        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        // Cast to non-nullable
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = cast(array.as_ref(), &non_nullable_dtype)
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
    }

    /// Casting to the array's own dtype is an identity: dtype, validity and length are kept.
    #[test]
    fn cast_to_same_dtype_is_identity() {
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let casted = cast(array.as_ref(), array.dtype())
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(casted.dtype(), array.dtype());
        assert_eq!(casted.validity(), &Validity::NonNullable);
        assert_eq!(casted.len(), 3);
    }

    /// Nullable with actual nulls -> non-nullable must be rejected.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with nulls
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        // Attempt to cast to non-nullable should fail
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        cast(array.as_ref(), &non_nullable_dtype).unwrap();
    }

    /// A different precision/scale is an error, not a silent fall-through.
    #[test]
    fn cast_different_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to different precision
        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let result = cast(array.as_ref(), &different_dtype);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot cast decimal(10,2) to decimal(15,3)")
        );
    }

    /// Non-decimal targets are declined by the kernel, so `cast` finds no handler.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to non-decimal type - should fail since no kernel can handle it
        let result = cast(array.as_ref(), &DType::Utf8(Nullability::NonNullable));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No compute kernel to cast")
        );
    }

    /// Shared cast conformance suite over several value widths and nullabilities.
    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(array.as_ref());
    }
}