vortex_array/arrays/decimal/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::{VortexResult, vortex_bail, vortex_panic};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{CastKernel, CastKernelAdapter};
9use crate::stats::ArrayStats;
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, register_kernel};
12
13impl CastKernel for DecimalVTable {
14    fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        // Early return if not casting to decimal
16        let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
17            return Ok(None);
18        };
19        let DType::Decimal(from_precision_scale, _) = array.dtype() else {
20            vortex_panic!(
21                "DecimalArray must have decimal dtype, got {:?}",
22                array.dtype()
23            );
24        };
25
26        // We only support casting to the same decimal type with different nullability
27        if from_precision_scale != to_precision_scale {
28            vortex_bail!(
29                "Cannot cast decimal({},{}) to decimal({},{})",
30                from_precision_scale.precision(),
31                from_precision_scale.scale(),
32                to_precision_scale.precision(),
33                to_precision_scale.scale()
34            );
35        }
36
37        // If the dtype is exactly the same, return self
38        if array.dtype() == dtype {
39            return Ok(Some(array.to_array()));
40        }
41
42        // Cast the validity to the new nullability
43        let new_validity = array.validity().clone().cast_nullability(*to_nullability)?;
44
45        // Construct DecimalArray directly since we can't use new() without knowing the concrete type
46        Ok(Some(
47            DecimalArray {
48                dtype: DType::Decimal(*from_precision_scale, *to_nullability),
49                values: array.byte_buffer(),
50                values_type: array.values_type(),
51                validity: new_validity,
52                stats_set: ArrayStats::default(),
53            }
54            .to_array(),
55        ))
56    }
57}
58
59register_kernel!(CastKernelAdapter(DecimalVTable).lift());
60
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// The decimal(10, 2) precision/scale shared by every test below.
    fn precision_scale() -> DecimalDType {
        DecimalDType::new(10, 2)
    }

    /// A non-nullable array cast to nullable keeps its dtype parameters and length,
    /// and ends up with `Validity::AllValid`.
    #[test]
    fn cast_decimal_to_nullable() {
        let ps = precision_scale();
        let source = DecimalArray::new(buffer![100i32, 200, 300], ps, Validity::NonNullable);
        let target = DType::Decimal(ps, Nullability::Nullable);

        let widened = cast(source.as_ref(), &target)
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(widened.dtype(), &target);
        assert_eq!(widened.validity(), &Validity::AllValid);
        assert_eq!(widened.len(), 3);
    }

    /// A nullable array without nulls can be narrowed to non-nullable.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let ps = precision_scale();
        // Nullable dtype, but every value is valid.
        let source = DecimalArray::new(buffer![100i32, 200, 300], ps, Validity::AllValid);
        let target = DType::Decimal(ps, Nullability::NonNullable);

        let narrowed = cast(source.as_ref(), &target)
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(narrowed.dtype(), &target);
        assert_eq!(narrowed.validity(), &Validity::NonNullable);
    }

    /// Narrowing an array that actually contains a null must fail.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let ps = precision_scale();
        // The middle element is null.
        let source = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], ps);
        let target = DType::Decimal(ps, Nullability::NonNullable);

        // The unwrap turns the cast error into the expected panic.
        cast(source.as_ref(), &target).unwrap();
    }

    /// Changing precision/scale is rejected with a descriptive error.
    #[test]
    fn cast_different_precision_fails() {
        let source = DecimalArray::new(buffer![100i32], precision_scale(), Validity::NonNullable);
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot cast decimal(10,2) to decimal(15,3)"));
    }

    /// Casting to a non-decimal dtype surfaces a "no kernel" error from `cast`.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = DecimalArray::new(buffer![100i32], precision_scale(), Validity::NonNullable);

        // No kernel handles decimal -> utf8, so the overall cast errors out.
        let outcome = cast(source.as_ref(), &DType::Utf8(Nullability::NonNullable));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No compute kernel to cast"));
    }
}