vortex_array/arrays/decimal/compute/
cast.rs1use vortex_buffer::Buffer;
5use vortex_dtype::DType;
6use vortex_dtype::DecimalType;
7use vortex_dtype::NativeDecimalType;
8use vortex_dtype::match_each_decimal_value_type;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_panic;
13
14use crate::ArrayRef;
15use crate::ExecutionCtx;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::compute::CastKernel;
19use crate::vtable::ValidityHelper;
20
21impl CastKernel for DecimalVTable {
22 fn cast(
23 array: &DecimalArray,
24 dtype: &DType,
25 _ctx: &mut ExecutionCtx,
26 ) -> VortexResult<Option<ArrayRef>> {
27 let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
29 return Ok(None);
30 };
31 let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
32 vortex_panic!(
33 "DecimalArray must have decimal dtype, got {:?}",
34 array.dtype()
35 );
36 };
37
38 if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
40 vortex_bail!(
41 "Casting decimal with scale {} to scale {} not yet implemented",
42 from_decimal_dtype.scale(),
43 to_decimal_dtype.scale()
44 );
45 }
46
47 if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
49 vortex_bail!(
50 "Downcasting decimal from precision {} to {} not yet implemented",
51 from_decimal_dtype.precision(),
52 to_decimal_dtype.precision()
53 );
54 }
55
56 if array.dtype() == dtype {
58 return Ok(Some(array.to_array()));
59 }
60
61 let new_validity = array
63 .validity()
64 .clone()
65 .cast_nullability(*to_nullability, array.len())?;
66
67 let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
69 let array = if target_values_type > array.values_type() {
70 upcast_decimal_values(array, target_values_type)?
71 } else {
72 array.clone()
73 };
74
75 unsafe {
77 Ok(Some(
78 DecimalArray::new_unchecked_handle(
79 array.buffer_handle().clone(),
80 array.values_type(),
81 *to_decimal_dtype,
82 new_validity,
83 )
84 .to_array(),
85 ))
86 }
87 }
88}
89
90pub fn upcast_decimal_values(
102 array: &DecimalArray,
103 to_values_type: DecimalType,
104) -> VortexResult<DecimalArray> {
105 let from_values_type = array.values_type();
106
107 if from_values_type == to_values_type {
109 return Ok(array.clone());
110 }
111
112 if to_values_type < from_values_type {
114 vortex_bail!(
115 "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
116 from_values_type,
117 to_values_type
118 );
119 }
120
121 let decimal_dtype = array.decimal_dtype();
122 let validity = array.validity().clone();
123
124 match_each_decimal_value_type!(from_values_type, |F| {
126 let from_buffer = array.buffer::<F>();
127 match_each_decimal_value_type!(to_values_type, |T| {
128 let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
129 Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
130 })
131 })
132}
133
134fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
137 from.iter()
138 .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
139 .collect()
140}
141
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::DecimalType;
    use vortex_dtype::Nullability;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Casting a non-nullable array to the nullable variant of the same dtype marks all values
    /// valid.
    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let casted = array
            .to_array()
            .cast(nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert_eq!(casted.validity(), &Validity::AllValid);
        assert_eq!(casted.len(), 3);
    }

    /// A nullable array with no nulls can be cast to non-nullable.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = array
            .to_array()
            .cast(non_nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
    }

    /// An array containing nulls cannot be cast to a non-nullable dtype.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        array
            .to_array()
            .cast(non_nullable_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap();
    }

    /// Rescaling is not implemented and must be reported as an error.
    #[test]
    fn cast_different_scale_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let result = array
            .to_array()
            .cast(different_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Casting decimal with scale 2 to scale 3 not yet implemented")
        );
    }

    /// Reducing precision is not implemented and must be reported as an error.
    #[test]
    fn cast_downcast_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );

        let smaller_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        let result = array
            .to_array()
            .cast(smaller_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Downcasting decimal from precision 18 to 10 not yet implemented")
        );
    }

    /// Widening precision to 38 digits widens storage to i128 and preserves every value.
    #[test]
    fn cast_upcast_precision_succeeds() {
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        let casted = array.to_array().cast(wider_dtype).unwrap().to_decimal();

        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        assert_eq!(casted.values_type(), DecimalType::I128);

        // The widened storage must hold the same unscaled values as the input.
        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[100i128, 200, 300]);
    }

    /// Casting to a non-decimal dtype is not handled by the decimal kernel.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let result = array
            .to_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No CastKernel to cast canonical array")
        );
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(array.as_ref());
    }

    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        assert_eq!(array.values_type(), DecimalType::I32);

        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
        assert_eq!(casted.len(), 3);

        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer.as_ref(), &[100i64, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let casted = upcast_decimal_values(&array, DecimalType::I128).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);

        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[10000i128, 20000, 30000]);
    }

    /// Requesting the current storage type leaves both the type and the values untouched.
    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let casted = upcast_decimal_values(&array, DecimalType::I32).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I32);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);

        let buffer = casted.buffer::<i32>();
        assert_eq!(buffer.as_ref(), &[100i32, 200, 300]);
    }

    /// Upcasting preserves the validity mask as well as the valid values.
    #[test]
    fn upcast_decimal_values_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.len(), 3);

        let mask = casted.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Only the valid slots are checked; the value under a null is unspecified.
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let result = upcast_decimal_values(&array, DecimalType::I32);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot downcast decimal values")
        );
    }
}