vortex_array/arrays/decimal/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::VortexResult;
6use vortex_error::vortex_bail;
7use vortex_error::vortex_panic;
8
9use crate::ArrayRef;
10use crate::arrays::DecimalArray;
11use crate::arrays::DecimalVTable;
12use crate::compute::CastKernel;
13use crate::compute::CastKernelAdapter;
14use crate::register_kernel;
15use crate::stats::ArrayStats;
16use crate::vtable::ValidityHelper;
17
18impl CastKernel for DecimalVTable {
19 fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20 let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
22 return Ok(None);
23 };
24 let DType::Decimal(from_precision_scale, _) = array.dtype() else {
25 vortex_panic!(
26 "DecimalArray must have decimal dtype, got {:?}",
27 array.dtype()
28 );
29 };
30
31 if from_precision_scale != to_precision_scale {
33 vortex_bail!(
34 "Cannot cast decimal({},{}) to decimal({},{})",
35 from_precision_scale.precision(),
36 from_precision_scale.scale(),
37 to_precision_scale.precision(),
38 to_precision_scale.scale()
39 );
40 }
41
42 if array.dtype() == dtype {
44 return Ok(Some(array.to_array()));
45 }
46
47 let new_validity = array
49 .validity()
50 .clone()
51 .cast_nullability(*to_nullability, array.len())?;
52
53 Ok(Some(
55 DecimalArray {
56 dtype: DType::Decimal(*from_precision_scale, *to_nullability),
57 values: array.byte_buffer(),
58 values_type: array.values_type(),
59 validity: new_validity,
60 stats_set: ArrayStats::default(),
61 }
62 .to_array(),
63 ))
64 }
65}
66
67register_kernel!(CastKernelAdapter(DecimalVTable).lift());
68
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Precision/scale pair shared by most tests in this module.
    fn decimal_10_2() -> DecimalDType {
        DecimalDType::new(10, 2)
    }

    /// Non-nullable -> nullable keeps the length and reports `Validity::AllValid`.
    #[test]
    fn cast_decimal_to_nullable() {
        let source = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_10_2(),
            Validity::NonNullable,
        );
        let target = DType::Decimal(decimal_10_2(), Nullability::Nullable);

        let widened = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(widened.dtype(), &target);
        assert_eq!(widened.validity(), &Validity::AllValid);
        assert_eq!(widened.len(), 3);
    }

    /// Nullable (but all-valid) -> non-nullable succeeds and reports `Validity::NonNullable`.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let source = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_10_2(),
            Validity::AllValid,
        );
        let target = DType::Decimal(decimal_10_2(), Nullability::NonNullable);

        let narrowed = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(narrowed.dtype(), &target);
        assert_eq!(narrowed.validity(), &Validity::NonNullable);
    }

    /// An array that actually contains a null cannot be cast to a non-nullable dtype.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let source =
            DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_10_2());
        let target = DType::Decimal(decimal_10_2(), Nullability::NonNullable);

        cast(source.as_ref(), &target).unwrap();
    }

    /// Changing precision/scale is rejected with a descriptive error.
    #[test]
    fn cast_different_precision_fails() {
        let source = DecimalArray::new(buffer![100i32], decimal_10_2(), Validity::NonNullable);
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot cast decimal(10,2) to decimal(15,3)"));
    }

    /// A non-decimal target is reported by `cast` as a missing kernel.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = DecimalArray::new(buffer![100i32], decimal_10_2(), Validity::NonNullable);
        let target = DType::Utf8(Nullability::NonNullable);

        let outcome = cast(source.as_ref(), &target);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No compute kernel to cast"));
    }

    /// Runs the shared cast conformance suite over a handful of decimal arrays.
    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(array.as_ref());
    }
}