vortex_array/arrays/decimal/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::{VortexResult, vortex_bail, vortex_panic};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{CastKernel, CastKernelAdapter};
9use crate::stats::ArrayStats;
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, register_kernel};
12
13impl CastKernel for DecimalVTable {
14 fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15 let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
17 return Ok(None);
18 };
19 let DType::Decimal(from_precision_scale, _) = array.dtype() else {
20 vortex_panic!(
21 "DecimalArray must have decimal dtype, got {:?}",
22 array.dtype()
23 );
24 };
25
26 if from_precision_scale != to_precision_scale {
28 vortex_bail!(
29 "Cannot cast decimal({},{}) to decimal({},{})",
30 from_precision_scale.precision(),
31 from_precision_scale.scale(),
32 to_precision_scale.precision(),
33 to_precision_scale.scale()
34 );
35 }
36
37 if array.dtype() == dtype {
39 return Ok(Some(array.to_array()));
40 }
41
42 let new_validity = array.validity().clone().cast_nullability(*to_nullability)?;
44
45 Ok(Some(
47 DecimalArray {
48 dtype: DType::Decimal(*from_precision_scale, *to_nullability),
49 values: array.byte_buffer(),
50 values_type: array.values_type(),
51 validity: new_validity,
52 stats_set: ArrayStats::default(),
53 }
54 .to_array(),
55 ))
56 }
57}
58
59register_kernel!(CastKernelAdapter(DecimalVTable).lift());
60
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Casts `source` to `target` and converts the result to a `DecimalArray`,
    /// panicking if either step fails.
    fn cast_to_decimal(source: &DecimalArray, target: &DType) -> DecimalArray {
        cast(source.as_ref(), target).unwrap().to_decimal().unwrap()
    }

    /// Non-nullable -> nullable: the dtype becomes nullable, validity becomes `AllValid`,
    /// and the length is unchanged.
    #[test]
    fn cast_decimal_to_nullable() {
        let spec = DecimalDType::new(10, 2);
        let source = DecimalArray::new(buffer![100i32, 200, 300], spec, Validity::NonNullable);
        let target = DType::Decimal(spec, Nullability::Nullable);

        let widened = cast_to_decimal(&source, &target);

        assert_eq!(widened.dtype(), &target);
        assert_eq!(widened.validity(), &Validity::AllValid);
        assert_eq!(widened.len(), 3);
    }

    /// Nullable (all valid) -> non-nullable: the dtype and validity both become non-nullable.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let spec = DecimalDType::new(10, 2);
        let source = DecimalArray::new(buffer![100i32, 200, 300], spec, Validity::AllValid);
        let target = DType::Decimal(spec, Nullability::NonNullable);

        let narrowed = cast_to_decimal(&source, &target);

        assert_eq!(narrowed.dtype(), &target);
        assert_eq!(narrowed.validity(), &Validity::NonNullable);
    }

    /// An array that actually contains a null cannot be cast to a non-nullable dtype;
    /// unwrapping the result panics with the expected message.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let spec = DecimalDType::new(10, 2);
        let source = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], spec);
        let target = DType::Decimal(spec, Nullability::NonNullable);

        cast(source.as_ref(), &target).unwrap();
    }

    /// Changing precision/scale is rejected with a message naming both decimal types.
    #[test]
    fn cast_different_precision_fails() {
        let source = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot cast decimal(10,2) to decimal(15,3)"));
    }

    /// Casting to a non-decimal dtype surfaces as a "No compute kernel to cast" error
    /// from the `cast` entry point.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let target = DType::Utf8(Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No compute kernel to cast"));
    }

    /// Runs the shared cast conformance checks over a handful of decimal arrays.
    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] input: DecimalArray) {
        test_cast_conformance(input.as_ref());
    }
}