vortex_array/arrays/decimal/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::{VortexResult, vortex_bail, vortex_panic};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{CastKernel, CastKernelAdapter};
9use crate::stats::ArrayStats;
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, register_kernel};
12
13impl CastKernel for DecimalVTable {
14    fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        // Early return if not casting to decimal
16        let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
17            return Ok(None);
18        };
19        let DType::Decimal(from_precision_scale, _) = array.dtype() else {
20            vortex_panic!(
21                "DecimalArray must have decimal dtype, got {:?}",
22                array.dtype()
23            );
24        };
25
26        // We only support casting to the same decimal type with different nullability
27        if from_precision_scale != to_precision_scale {
28            vortex_bail!(
29                "Cannot cast decimal({},{}) to decimal({},{})",
30                from_precision_scale.precision(),
31                from_precision_scale.scale(),
32                to_precision_scale.precision(),
33                to_precision_scale.scale()
34            );
35        }
36
37        // If the dtype is exactly the same, return self
38        if array.dtype() == dtype {
39            return Ok(Some(array.to_array()));
40        }
41
42        // Cast the validity to the new nullability
43        let new_validity = array
44            .validity()
45            .clone()
46            .cast_nullability(*to_nullability, array.len())?;
47
48        // Construct DecimalArray directly since we can't use new() without knowing the concrete type
49        Ok(Some(
50            DecimalArray {
51                dtype: DType::Decimal(*from_precision_scale, *to_nullability),
52                values: array.byte_buffer(),
53                values_type: array.values_type(),
54                validity: new_validity,
55                stats_set: ArrayStats::default(),
56            }
57            .to_array(),
58        ))
59    }
60}
61
62register_kernel!(CastKernelAdapter(DecimalVTable).lift());
63
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Decimal type with precision 10 and scale 2, shared by the hand-written tests.
    fn decimal_10_2() -> DecimalDType {
        DecimalDType::new(10, 2)
    }

    /// A non-nullable array gains `AllValid` validity when cast to the nullable dtype.
    #[test]
    fn cast_decimal_to_nullable() {
        let decimal = decimal_10_2();
        let target = DType::Decimal(decimal, Nullability::Nullable);
        let source = DecimalArray::new(buffer![100i32, 200, 300], decimal, Validity::NonNullable);

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::AllValid);
        assert_eq!(result.len(), 3);
    }

    /// A nullable array without nulls can be cast to the non-nullable dtype.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal = decimal_10_2();
        let target = DType::Decimal(decimal, Nullability::NonNullable);
        // Nullable dtype, but every value is valid.
        let source = DecimalArray::new(buffer![100i32, 200, 300], decimal, Validity::AllValid);

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::NonNullable);
    }

    /// Dropping nullability while actual nulls are present must be rejected.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal = decimal_10_2();
        let target = DType::Decimal(decimal, Nullability::NonNullable);
        // The middle element is null.
        let with_null = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal);

        cast(with_null.as_ref(), &target).unwrap();
    }

    /// Changing precision or scale is reported as an error naming both decimal types.
    #[test]
    fn cast_different_precision_fails() {
        let source = DecimalArray::new(buffer![100i32], decimal_10_2(), Validity::NonNullable);
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot cast decimal(10,2) to decimal(15,3)"));
    }

    /// A non-decimal target is not handled by any kernel, so the cast as a whole errors.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = DecimalArray::new(buffer![100i32], decimal_10_2(), Validity::NonNullable);
        let target = DType::Utf8(Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No compute kernel to cast"));
    }

    /// Runs the shared cast conformance checks over a few representative decimal arrays.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i32, 200, 300],
        DecimalDType::new(10, 2),
        Validity::NonNullable
    ))]
    #[case(DecimalArray::new(
        buffer![10000i64, 20000, 30000],
        DecimalDType::new(18, 4),
        Validity::NonNullable
    ))]
    #[case(DecimalArray::from_option_iter(
        [Some(100i32), None, Some(300)],
        DecimalDType::new(10, 2)
    ))]
    #[case(DecimalArray::new(
        buffer![42i32],
        DecimalDType::new(5, 1),
        Validity::NonNullable
    ))]
    fn test_cast_decimal_conformance(#[case] input: DecimalArray) {
        test_cast_conformance(input.as_ref());
    }
}