1use num_traits::CheckedMul;
5use vortex_buffer::Buffer;
6use vortex_buffer::BufferMut;
7use vortex_compute::lane_kernels::IndexedSourceExt;
8use vortex_error::VortexError;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_error::vortex_panic;
14use vortex_mask::Mask;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::arrays::Decimal;
21use crate::arrays::DecimalArray;
22use crate::dtype::BigCast;
23use crate::dtype::DType;
24use crate::dtype::DecimalDType;
25use crate::dtype::DecimalType;
26use crate::dtype::NativeDecimalType;
27use crate::dtype::i256;
28use crate::match_each_decimal_value_type;
29use crate::scalar::DecimalValue;
30use crate::scalar_fn::fns::cast::CastKernel;
31use crate::scalar_fn::fns::cast::CastReduce;
32use crate::validity::Validity;
33
34impl CastReduce for Decimal {
35 fn cast(array: ArrayView<'_, Decimal>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
36 let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
39 return Ok(None);
40 };
41 let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
42 vortex_panic!(
43 "DecimalArray must have decimal dtype, got {:?}",
44 array.dtype()
45 );
46 };
47
48 if from_decimal_dtype != to_decimal_dtype {
49 return Ok(None);
50 }
51
52 let Some(new_validity) = array
53 .validity()?
54 .trivially_cast_nullability(*to_nullability, array.len())?
55 else {
56 return Ok(None);
57 };
58
59 unsafe {
61 Ok(Some(
62 DecimalArray::new_unchecked_handle(
63 array.buffer_handle().clone(),
64 array.values_type(),
65 *to_decimal_dtype,
66 new_validity,
67 )
68 .into_array(),
69 ))
70 }
71 }
72}
73
74impl CastKernel for Decimal {
75 fn cast(
76 array: ArrayView<'_, Decimal>,
77 dtype: &DType,
78 ctx: &mut ExecutionCtx,
79 ) -> VortexResult<Option<ArrayRef>> {
80 let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
82 return Ok(None);
83 };
84 let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
85 vortex_panic!(
86 "DecimalArray must have decimal dtype, got {:?}",
87 array.dtype()
88 );
89 };
90
91 if array.dtype() == dtype {
93 return Ok(Some(array.array().clone()));
94 }
95
96 let validity = array.validity()?;
97
98 let new_validity = validity
100 .clone()
101 .cast_nullability(*to_nullability, array.len(), ctx)?;
102
103 if from_decimal_dtype.scale() == to_decimal_dtype.scale()
108 && to_decimal_dtype.precision() >= from_decimal_dtype.precision()
109 && array
110 .values_type()
111 .is_compatible_decimal_value_type(*to_decimal_dtype)
112 {
113 unsafe {
116 return Ok(Some(
117 DecimalArray::new_unchecked_handle(
118 array.buffer_handle().clone(),
119 array.values_type(),
120 *to_decimal_dtype,
121 new_validity,
122 )
123 .into_array(),
124 ));
125 }
126 }
127
128 let valid_values = validity.execute_mask(array.len(), ctx)?;
129 let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
130
131 match_each_decimal_value_type!(array.values_type(), |F| {
132 match_each_decimal_value_type!(target_values_type, |T| {
133 cast_decimal_values::<F, T>(
134 array,
135 *from_decimal_dtype,
136 *to_decimal_dtype,
137 new_validity,
138 &valid_values,
139 )
140 .map(Some)
141 })
142 })
143 }
144}
145
146fn cast_decimal_values<F, T>(
147 array: ArrayView<'_, Decimal>,
148 from_decimal_dtype: DecimalDType,
149 to_decimal_dtype: DecimalDType,
150 validity: Validity,
151 valid_values: &Mask,
152) -> VortexResult<ArrayRef>
153where
154 F: NativeDecimalType,
155 T: NativeDecimalType + CheckedMul,
156 DecimalValue: From<F>,
157{
158 let values = array.buffer::<F>();
159 let values = values.as_slice();
160 let cast_plan = DecimalCastPlan::<T>::new(from_decimal_dtype, to_decimal_dtype);
161
162 let buffer = match valid_values {
163 Mask::AllTrue(_) => {
164 let mut buffer = BufferMut::<T>::with_capacity(values.len());
165 values
166 .try_map_into(&mut buffer.spare_capacity_mut()[..values.len()], |value| {
167 cast_plan.cast(value)
168 })
169 .map_err(|idx| {
170 decimal_cast_error::<F, T>(values[idx], from_decimal_dtype, to_decimal_dtype)
171 })?;
172 unsafe { buffer.set_len(values.len()) };
174 buffer.freeze()
175 }
176 Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
177 Mask::Values(mask) => {
178 let mut buffer = BufferMut::<T>::with_capacity(values.len());
179 values
180 .try_map_masked_into(
181 mask.bit_buffer(),
182 &mut buffer.spare_capacity_mut()[..values.len()],
183 |value| cast_plan.cast(value),
184 )
185 .map_err(|idx| {
186 decimal_cast_error::<F, T>(values[idx], from_decimal_dtype, to_decimal_dtype)
187 })?;
188 unsafe { buffer.set_len(values.len()) };
190 buffer.freeze()
191 }
192 };
193
194 Ok(DecimalArray::new(buffer, to_decimal_dtype, validity).into_array())
195}
196
197#[cold]
198fn decimal_cast_error<F, T>(
199 value: F,
200 from_decimal_dtype: DecimalDType,
201 to_decimal_dtype: DecimalDType,
202) -> VortexError
203where
204 F: NativeDecimalType,
205 T: NativeDecimalType,
206 DecimalValue: From<F>,
207{
208 match DecimalValue::from(value)
209 .cast_decimal(from_decimal_dtype, to_decimal_dtype)
210 .and_then(|value| {
211 value.cast::<T>().ok_or_else(|| {
212 vortex_err!(
213 "decimal value cannot be represented as {} after casting to {}",
214 T::DECIMAL_TYPE,
215 to_decimal_dtype
216 )
217 })
218 }) {
219 Ok(_) => {
220 debug_assert!(
224 false,
225 "decimal fast-path cast rejected value {value} that the slow path accepts \
226 (from {from_decimal_dtype} to {to_decimal_dtype})"
227 );
228 vortex_err!(
229 "decimal value cannot be represented as {} after casting from {} to {}",
230 T::DECIMAL_TYPE,
231 from_decimal_dtype,
232 to_decimal_dtype
233 )
234 }
235 Err(error) => error,
236 }
237}
238
239#[derive(Debug, Clone, Copy)]
240enum DecimalCastPlan<T> {
241 SameScale { min: T, max: T },
242 ScaleUp { factor: T, min: T, max: T },
243 ScaleUpOverflow,
244 ScaleDown { factor: i256, min: i256, max: i256 },
245 ScaleDownOverflow,
246}
247
248impl<T> DecimalCastPlan<T>
249where
250 T: NativeDecimalType + CheckedMul,
251{
252 fn new(from_decimal_dtype: DecimalDType, to_decimal_dtype: DecimalDType) -> Self {
253 let scale_delta = to_decimal_dtype.scale() as i16 - from_decimal_dtype.scale() as i16;
254 if scale_delta == 0 {
255 let (min, max) = decimal_precision_range::<T>(to_decimal_dtype);
256 return Self::SameScale { min, max };
257 }
258
259 if scale_delta > 0 {
260 let Some(factor) = decimal_scale_factor::<T>(scale_delta as u32) else {
261 return Self::ScaleUpOverflow;
262 };
263 let (min, max) = decimal_precision_range::<T>(to_decimal_dtype);
264 return Self::ScaleUp { factor, min, max };
265 }
266
267 let Some(factor) = decimal_scale_factor::<i256>((-scale_delta) as u32) else {
268 return Self::ScaleDownOverflow;
269 };
270 let (min, max) = decimal_precision_range::<i256>(to_decimal_dtype);
271 Self::ScaleDown { factor, min, max }
272 }
273
274 #[inline]
275 fn cast<F>(&self, value: F) -> Option<T>
276 where
277 F: NativeDecimalType,
278 {
279 match *self {
280 DecimalCastPlan::SameScale { min, max } => {
281 let value = <T as BigCast>::from(value)?;
282 (value >= min && value <= max).then_some(value)
283 }
284 DecimalCastPlan::ScaleUp { factor, min, max } => {
285 let value = <T as BigCast>::from(value)?;
286 let value = value.checked_mul(&factor)?;
287 (value >= min && value <= max).then_some(value)
288 }
289 DecimalCastPlan::ScaleUpOverflow | DecimalCastPlan::ScaleDownOverflow => {
290 (value == F::default()).then_some(T::default())
291 }
292 DecimalCastPlan::ScaleDown { factor, min, max } => {
293 let value = <i256 as BigCast>::from(value)?;
294 if value == i256::ZERO {
295 return Some(T::default());
296 }
297 if value % factor != i256::ZERO {
298 return None;
299 }
300
301 let value = value / factor;
302 if value < min || value > max {
303 return None;
304 }
305 <T as BigCast>::from(value)
306 }
307 }
308 }
309}
310
311fn decimal_precision_range<T: NativeDecimalType>(decimal_dtype: DecimalDType) -> (T, T) {
312 let precision = usize::from(decimal_dtype.precision());
313 (
314 T::MIN_BY_PRECISION[precision],
315 T::MAX_BY_PRECISION[precision],
316 )
317}
318
319fn decimal_scale_factor<T>(exp: u32) -> Option<T>
320where
321 T: NativeDecimalType + CheckedMul,
322{
323 let ten = <T as BigCast>::from(10_i8)?;
324 let mut factor = <T as BigCast>::from(1_i8)?;
325 for _ in 0..exp {
326 factor = factor.checked_mul(&ten)?;
327 }
328 Some(factor)
329}
330
331pub fn upcast_decimal_values(
343 array: ArrayView<'_, Decimal>,
344 to_values_type: DecimalType,
345) -> VortexResult<DecimalArray> {
346 let from_values_type = array.values_type();
347
348 if from_values_type == to_values_type {
350 return Ok(array.array().as_::<Decimal>().into_owned());
351 }
352
353 if to_values_type < from_values_type {
355 vortex_bail!(
356 "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
357 from_values_type,
358 to_values_type
359 );
360 }
361
362 let decimal_dtype = array.decimal_dtype();
363 let validity = array.validity()?;
364
365 match_each_decimal_value_type!(from_values_type, |F| {
367 let from_buffer = array.buffer::<F>();
368 match_each_decimal_value_type!(to_values_type, |T| {
369 let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
370 Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
371 })
372 })
373}
374
375fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
378 from.iter()
379 .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
380 .collect()
381}
382
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::DecimalType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    // Non-nullable -> nullable with the same decimal dtype yields `Validity::AllValid`.
    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let target = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let input = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        #[expect(deprecated)]
        let output = input
            .into_array()
            .cast(target.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(output.len(), 3);
        assert_eq!(output.dtype(), &target);
        assert!(matches!(output.validity(), Ok(Validity::AllValid)));
    }

    // Nullable (all valid) -> non-nullable with the same decimal dtype succeeds.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let target = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let input = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        #[expect(deprecated)]
        let output = input
            .into_array()
            .cast(target.clone())
            .unwrap()
            .to_decimal();

        assert!(matches!(output.validity(), Ok(Validity::NonNullable)));
        assert_eq!(output.dtype(), &target);
    }

    // An actual null cannot be cast away; unwrapping the result must panic.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let target = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let input = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        #[expect(deprecated)]
        let outcome = input
            .into_array()
            .cast(target)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        outcome.unwrap();
    }

    // Scale 2 -> 3 multiplies the stored value by ten and picks i64 storage for precision 15.
    #[test]
    fn cast_different_scale_rescales() {
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let input = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        #[expect(deprecated)]
        let output = input.into_array().cast(target).unwrap().to_decimal();

        assert_eq!(output.scale(), 3);
        assert_eq!(output.precision(), 15);
        assert_eq!(output.values_type(), DecimalType::I64);
        assert_eq!(output.buffer::<i64>().as_ref(), &[1000]);
    }

    // Shrinking the precision is allowed as long as every value still fits.
    #[test]
    fn cast_downcast_precision_succeeds_when_values_fit() {
        let target = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        let input = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );

        #[expect(deprecated)]
        let output = input.into_array().cast(target).unwrap().to_decimal();

        assert_eq!(output.scale(), 2);
        assert_eq!(output.precision(), 10);
        assert_eq!(output.buffer::<i64>().as_ref(), &[100]);
    }

    // 1000 needs four digits, so precision 3 must be rejected.
    #[test]
    fn cast_downcast_precision_checks_values() {
        let target = DType::Decimal(DecimalDType::new(3, 0), Nullability::NonNullable);
        let input = DecimalArray::new(
            buffer![1000i64],
            DecimalDType::new(18, 0),
            Validity::NonNullable,
        );

        #[expect(deprecated)]
        let outcome = input
            .into_array()
            .cast(target)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("does not fit in precision"));
    }

    // 123456 at scale 4 is not a multiple of 100, so lowering to scale 2 would drop digits.
    #[test]
    fn cast_lower_scale_requires_exact_rescale() {
        let target = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        let input = DecimalArray::new(
            buffer![123456i64],
            DecimalDType::new(10, 4),
            Validity::NonNullable,
        );

        #[expect(deprecated)]
        let outcome = input
            .into_array()
            .cast(target)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("would lose precision"));
    }

    // The second lane would fail to rescale, but it is null and must not abort the cast.
    #[test]
    fn cast_lower_scale_ignores_null_lane_failures() {
        let target = DType::Decimal(DecimalDType::new(3, 2), Nullability::Nullable);
        let input = DecimalArray::new(
            buffer![100i64, 123456],
            DecimalDType::new(10, 4),
            Validity::from_iter([true, false]),
        );

        #[expect(deprecated)]
        let output = input.into_array().cast(target).unwrap().to_decimal();

        let validity_mask = output
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                output.as_ref().len(),
                &mut array_session().create_execution_ctx(),
            )
            .unwrap();
        assert!(validity_mask.value(0));
        assert!(!validity_mask.value(1));
        assert_eq!(output.buffer::<i16>().as_ref()[0], 1);
    }

    // Widening to precision 38 switches the physical storage to i128.
    #[test]
    fn cast_upcast_precision_succeeds() {
        let target = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        let input = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        #[expect(deprecated)]
        let output = input.into_array().cast(target).unwrap().to_decimal();

        assert_eq!(output.len(), 3);
        assert_eq!(output.scale(), 2);
        assert_eq!(output.precision(), 38);
        assert_eq!(output.values_type(), DecimalType::I128);
    }

    // Widening the precision while keeping i64 storage must share the source buffer.
    #[test]
    fn cast_widening_same_physical_type_is_zero_copy() {
        let input = DecimalArray::new(
            buffer![100i64, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        // Capture the pointer before the array is consumed by `into_array`.
        let source_ptr = input.buffer::<i64>().as_ptr();

        let target = DType::Decimal(DecimalDType::new(18, 2), Nullability::NonNullable);
        #[expect(deprecated)]
        let output = input.into_array().cast(target).unwrap().to_decimal();

        assert_eq!(output.precision(), 18);
        assert_eq!(output.scale(), 2);
        assert_eq!(output.values_type(), DecimalType::I64);
        assert_eq!(output.buffer::<i64>().as_ref(), &[100, 200, 300]);
        assert_eq!(
            output.buffer::<i64>().as_ptr(),
            source_ptr,
            "precision-widening cast must reuse the source values buffer"
        );
    }

    // Casting to a non-decimal dtype surfaces the "no kernel" error.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let input = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        #[expect(deprecated)]
        let outcome = input
            .into_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No CastKernel to cast canonical array"));
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(&array.into_array());
    }

    // i32 -> i64 storage keeps the decimal dtype and the values.
    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let input = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );
        assert_eq!(input.values_type(), DecimalType::I32);

        let upcast = upcast_decimal_values(input.as_view(), DecimalType::I64).unwrap();

        assert_eq!(upcast.values_type(), DecimalType::I64);
        assert_eq!(upcast.decimal_dtype(), decimal_dtype);
        assert_eq!(upcast.len(), 3);

        let widened = upcast.buffer::<i64>();
        assert_eq!(widened.as_ref(), &[100i64, 200, 300]);
    }

    // i64 -> i128 storage keeps the decimal dtype and the values.
    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let input = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let upcast = upcast_decimal_values(input.as_view(), DecimalType::I128).unwrap();

        assert_eq!(upcast.values_type(), DecimalType::I128);
        assert_eq!(upcast.decimal_dtype(), decimal_dtype);

        let widened = upcast.buffer::<i128>();
        assert_eq!(widened.as_ref(), &[10000i128, 20000, 30000]);
    }

    // Requesting the current values type leaves the values type and dtype unchanged.
    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let input = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let upcast = upcast_decimal_values(input.as_view(), DecimalType::I32).unwrap();

        assert_eq!(upcast.values_type(), DecimalType::I32);
        assert_eq!(upcast.decimal_dtype(), decimal_dtype);
    }

    // Nulls survive the upcast; only the valid lanes' values are checked.
    #[test]
    fn upcast_decimal_values_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let input = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let upcast = upcast_decimal_values(input.as_view(), DecimalType::I64).unwrap();

        assert_eq!(upcast.values_type(), DecimalType::I64);
        assert_eq!(upcast.len(), 3);

        let validity_mask = upcast
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                upcast.as_ref().len(),
                &mut array_session().create_execution_ctx(),
            )
            .unwrap();
        assert!(validity_mask.value(0));
        assert!(!validity_mask.value(1));
        assert!(validity_mask.value(2));

        let widened = upcast.buffer::<i64>();
        assert_eq!(widened[0], 100);
        assert_eq!(widened[2], 300);
    }

    // Narrowing the storage type (i64 -> i32) is refused with a descriptive error.
    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let input = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let outcome = upcast_decimal_values(input.as_view(), DecimalType::I32);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Cannot downcast decimal values"));
    }
}