vortex_array/arrays/decimal/compute/cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6use vortex_error::vortex_bail;
7use vortex_error::vortex_panic;
8
9use crate::ArrayRef;
10use crate::arrays::DecimalArray;
11use crate::arrays::DecimalVTable;
12use crate::compute::CastKernel;
13use crate::compute::CastKernelAdapter;
14use crate::register_kernel;
15use crate::stats::ArrayStats;
16use crate::vtable::ValidityHelper;
17
18impl CastKernel for DecimalVTable {
19    fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20        // Early return if not casting to decimal
21        let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
22            return Ok(None);
23        };
24        let DType::Decimal(from_precision_scale, _) = array.dtype() else {
25            vortex_panic!(
26                "DecimalArray must have decimal dtype, got {:?}",
27                array.dtype()
28            );
29        };
30
31        // We only support casting to the same decimal type with different nullability
32        if from_precision_scale != to_precision_scale {
33            vortex_bail!(
34                "Cannot cast decimal({},{}) to decimal({},{})",
35                from_precision_scale.precision(),
36                from_precision_scale.scale(),
37                to_precision_scale.precision(),
38                to_precision_scale.scale()
39            );
40        }
41
42        // If the dtype is exactly the same, return self
43        if array.dtype() == dtype {
44            return Ok(Some(array.to_array()));
45        }
46
47        // Cast the validity to the new nullability
48        let new_validity = array
49            .validity()
50            .clone()
51            .cast_nullability(*to_nullability, array.len())?;
52
53        // Construct DecimalArray directly since we can't use new() without knowing the concrete type
54        Ok(Some(
55            DecimalArray {
56                dtype: DType::Decimal(*from_precision_scale, *to_nullability),
57                values: array.byte_buffer(),
58                values_type: array.values_type(),
59                validity: new_validity,
60                stats_set: ArrayStats::default(),
61            }
62            .to_array(),
63        ))
64    }
65}
66
// Register the `CastKernel` impl for `DecimalVTable` (wrapped in `CastKernelAdapter`) via
// `register_kernel!`. NOTE(review): presumably this is what lets `compute::cast` dispatch
// `DecimalArray` inputs to the kernel above — confirm against the registry implementation.
register_kernel!(CastKernelAdapter(DecimalVTable).lift());
68
#[cfg(test)]
mod tests {
    // Tests for the decimal cast kernel: nullability changes, the identical-dtype fast path,
    // and rejection of precision/scale changes and non-decimal targets.
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Cast to nullable
        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let casted = cast(array.as_ref(), &nullable_dtype).unwrap().to_decimal();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert_eq!(casted.validity(), &Validity::AllValid);
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with no nulls
        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        // Cast to non-nullable
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = cast(array.as_ref(), &non_nullable_dtype)
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
    }

    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with nulls
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        // Attempt to cast to non-nullable should fail
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        cast(array.as_ref(), &non_nullable_dtype).unwrap();
    }

    #[test]
    fn cast_to_identical_dtype_preserves_array() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Casting to the array's own dtype exercises the no-op fast path.
        let casted = cast(array.as_ref(), array.dtype()).unwrap().to_decimal();

        assert_eq!(casted.dtype(), array.dtype());
        assert_eq!(casted.validity(), &Validity::NonNullable);
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_different_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to different precision
        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let result = cast(array.as_ref(), &different_dtype);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot cast decimal(10,2) to decimal(15,3)")
        );
    }

    #[test]
    fn cast_different_scale_same_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Only the scale differs; this must still be rejected rather than reinterpreted.
        let different_scale = DType::Decimal(DecimalDType::new(10, 3), Nullability::NonNullable);
        let err = cast(array.as_ref(), &different_scale).unwrap_err();

        assert!(
            err.to_string()
                .contains("Cannot cast decimal(10,2) to decimal(10,3)")
        );
    }

    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to non-decimal type - should fail since no kernel can handle it
        let result = cast(array.as_ref(), &DType::Utf8(Nullability::NonNullable));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No compute kernel to cast")
        );
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(array.as_ref());
    }
}