vortex_array/arrays/decimal/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::{VortexResult, vortex_bail, vortex_panic};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{CastKernel, CastKernelAdapter};
9use crate::stats::ArrayStats;
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, register_kernel};
12
13impl CastKernel for DecimalVTable {
14 fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15 let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
17 return Ok(None);
18 };
19 let DType::Decimal(from_precision_scale, _) = array.dtype() else {
20 vortex_panic!(
21 "DecimalArray must have decimal dtype, got {:?}",
22 array.dtype()
23 );
24 };
25
26 if from_precision_scale != to_precision_scale {
28 vortex_bail!(
29 "Cannot cast decimal({},{}) to decimal({},{})",
30 from_precision_scale.precision(),
31 from_precision_scale.scale(),
32 to_precision_scale.precision(),
33 to_precision_scale.scale()
34 );
35 }
36
37 if array.dtype() == dtype {
39 return Ok(Some(array.to_array()));
40 }
41
42 let new_validity = array.validity().clone().cast_nullability(*to_nullability)?;
44
45 Ok(Some(
47 DecimalArray {
48 dtype: DType::Decimal(*from_precision_scale, *to_nullability),
49 values: array.byte_buffer(),
50 values_type: array.values_type(),
51 validity: new_validity,
52 stats_set: ArrayStats::default(),
53 }
54 .to_array(),
55 ))
56 }
57}
58
59register_kernel!(CastKernelAdapter(DecimalVTable).lift());
60
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Non-nullable -> nullable yields an all-valid array of the same length.
    #[test]
    fn cast_decimal_to_nullable() {
        let precision_scale = DecimalDType::new(10, 2);
        let target = DType::Decimal(precision_scale, Nullability::Nullable);
        let source = DecimalArray::new(
            buffer![100i32, 200, 300],
            precision_scale,
            Validity::NonNullable,
        );

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::AllValid);
        assert_eq!(result.len(), 3);
    }

    /// Nullable (but all-valid) -> non-nullable succeeds.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let precision_scale = DecimalDType::new(10, 2);
        let target = DType::Decimal(precision_scale, Nullability::NonNullable);
        let source =
            DecimalArray::new(buffer![100i32, 200, 300], precision_scale, Validity::AllValid);

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::NonNullable);
    }

    /// An array that actually contains a null cannot become non-nullable.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let precision_scale = DecimalDType::new(10, 2);
        let target = DType::Decimal(precision_scale, Nullability::NonNullable);
        let source =
            DecimalArray::from_option_iter([Some(100i32), None, Some(300)], precision_scale);

        cast(source.as_ref(), &target).unwrap();
    }

    /// Changing precision/scale is rejected with a descriptive error.
    #[test]
    fn cast_different_precision_fails() {
        let source = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot cast decimal(10,2) to decimal(15,3)"));
    }

    /// A non-decimal target surfaces as a "no kernel" error from `cast`.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let target = DType::Utf8(Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No compute kernel to cast"));
    }

    /// Runs the shared cast conformance suite over a handful of decimal arrays.
    #[rstest]
    #[case(DecimalArray::new(
        buffer![100i32, 200, 300],
        DecimalDType::new(10, 2),
        Validity::NonNullable
    ))]
    #[case(DecimalArray::new(
        buffer![10000i64, 20000, 30000],
        DecimalDType::new(18, 4),
        Validity::NonNullable
    ))]
    #[case(DecimalArray::from_option_iter(
        [Some(100i32), None, Some(300)],
        DecimalDType::new(10, 2)
    ))]
    #[case(DecimalArray::new(
        buffer![42i32],
        DecimalDType::new(5, 1),
        Validity::NonNullable
    ))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(array.as_ref());
    }
}