vortex_array/arrays/decimal/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::{VortexResult, vortex_bail, vortex_panic};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{CastKernel, CastKernelAdapter};
9use crate::stats::ArrayStats;
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, register_kernel};
12
13impl CastKernel for DecimalVTable {
14    fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        // Early return if not casting to decimal
16        let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
17            return Ok(None);
18        };
19        let DType::Decimal(from_precision_scale, _) = array.dtype() else {
20            vortex_panic!(
21                "DecimalArray must have decimal dtype, got {:?}",
22                array.dtype()
23            );
24        };
25
26        // We only support casting to the same decimal type with different nullability
27        if from_precision_scale != to_precision_scale {
28            vortex_bail!(
29                "Cannot cast decimal({},{}) to decimal({},{})",
30                from_precision_scale.precision(),
31                from_precision_scale.scale(),
32                to_precision_scale.precision(),
33                to_precision_scale.scale()
34            );
35        }
36
37        // If the dtype is exactly the same, return self
38        if array.dtype() == dtype {
39            return Ok(Some(array.to_array()));
40        }
41
42        // Cast the validity to the new nullability
43        let new_validity = array.validity().clone().cast_nullability(*to_nullability)?;
44
45        // Construct DecimalArray directly since we can't use new() without knowing the concrete type
46        Ok(Some(
47            DecimalArray {
48                dtype: DType::Decimal(*from_precision_scale, *to_nullability),
49                values: array.byte_buffer(),
50                values_type: array.values_type(),
51                validity: new_validity,
52                stats_set: ArrayStats::default(),
53            }
54            .to_array(),
55        ))
56    }
57}
58
59register_kernel!(CastKernelAdapter(DecimalVTable).lift());
60
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// The `decimal(10, 2)` type used as the source type throughout these tests.
    fn decimal_10_2() -> DecimalDType {
        DecimalDType::new(10, 2)
    }

    /// Three `i32`-backed `decimal(10, 2)` values with the given validity.
    fn three_values(validity: Validity) -> DecimalArray {
        DecimalArray::new(buffer![100i32, 200, 300], decimal_10_2(), validity)
    }

    /// A single non-nullable `decimal(10, 2)` value, used by the failure tests.
    fn single_value() -> DecimalArray {
        DecimalArray::new(buffer![100i32], decimal_10_2(), Validity::NonNullable)
    }

    /// Non-nullable -> nullable keeps the length and yields an all-valid validity.
    #[test]
    fn cast_decimal_to_nullable() {
        let source = three_values(Validity::NonNullable);
        let target = DType::Decimal(decimal_10_2(), Nullability::Nullable);

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::AllValid);
        assert_eq!(result.len(), 3);
    }

    /// Nullable -> non-nullable succeeds when the array holds no nulls.
    #[test]
    fn cast_nullable_to_non_nullable() {
        // Nullable dtype, but every element is valid.
        let source = three_values(Validity::AllValid);
        let target = DType::Decimal(decimal_10_2(), Nullability::NonNullable);

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::NonNullable);
    }

    /// Nullable -> non-nullable must fail when a null is actually present.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let values = [Some(100i32), None, Some(300)];
        let source = DecimalArray::from_option_iter(values, decimal_10_2());
        let target = DType::Decimal(decimal_10_2(), Nullability::NonNullable);

        // Unwrapping the error is what triggers the expected panic.
        cast(source.as_ref(), &target).unwrap();
    }

    /// A target with a different precision/scale is rejected with a descriptive error.
    #[test]
    fn cast_different_precision_fails() {
        let source = single_value();
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot cast decimal(10,2) to decimal(15,3)"));
    }

    /// A non-decimal target is not handled by this kernel, so the cast as a whole errors.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = single_value();
        let target = DType::Utf8(Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No compute kernel to cast"));
    }

    /// Runs the shared cast conformance suite over a few representative decimal arrays.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i32, 200, 300],
        DecimalDType::new(10, 2),
        Validity::NonNullable
    ))]
    #[case(DecimalArray::new(
        buffer![10000i64, 20000, 30000],
        DecimalDType::new(18, 4),
        Validity::NonNullable
    ))]
    #[case(DecimalArray::from_option_iter(
        [Some(100i32), None, Some(300)],
        DecimalDType::new(10, 2)
    ))]
    #[case(DecimalArray::new(
        buffer![42i32],
        DecimalDType::new(5, 1),
        Validity::NonNullable
    ))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(array.as_ref());
    }
}