Skip to main content

vortex_array/arrays/decimal/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::CheckedMul;
5use vortex_buffer::Buffer;
6use vortex_buffer::BufferMut;
7use vortex_compute::lane_kernels::IndexedSourceExt;
8use vortex_error::VortexError;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_error::vortex_panic;
14use vortex_mask::Mask;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::arrays::Decimal;
21use crate::arrays::DecimalArray;
22use crate::dtype::BigCast;
23use crate::dtype::DType;
24use crate::dtype::DecimalDType;
25use crate::dtype::DecimalType;
26use crate::dtype::NativeDecimalType;
27use crate::dtype::i256;
28use crate::match_each_decimal_value_type;
29use crate::scalar::DecimalValue;
30use crate::scalar_fn::fns::cast::CastKernel;
31use crate::scalar_fn::fns::cast::CastReduce;
32use crate::validity::Validity;
33
34impl CastReduce for Decimal {
35    fn cast(array: ArrayView<'_, Decimal>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
36        // Only nullability changes within the same decimal dtype are reducible without execution.
37        // Precision/scale changes need the kernel.
38        let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
39            return Ok(None);
40        };
41        let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
42            vortex_panic!(
43                "DecimalArray must have decimal dtype, got {:?}",
44                array.dtype()
45            );
46        };
47
48        if from_decimal_dtype != to_decimal_dtype {
49            return Ok(None);
50        }
51
52        let Some(new_validity) = array
53            .validity()?
54            .trivially_cast_nullability(*to_nullability, array.len())?
55        else {
56            return Ok(None);
57        };
58
59        // SAFETY: validity has the same length, only its nullability tag changes.
60        unsafe {
61            Ok(Some(
62                DecimalArray::new_unchecked_handle(
63                    array.buffer_handle().clone(),
64                    array.values_type(),
65                    *to_decimal_dtype,
66                    new_validity,
67                )
68                .into_array(),
69            ))
70        }
71    }
72}
73
74impl CastKernel for Decimal {
75    fn cast(
76        array: ArrayView<'_, Decimal>,
77        dtype: &DType,
78        ctx: &mut ExecutionCtx,
79    ) -> VortexResult<Option<ArrayRef>> {
80        // Early return if not casting to decimal
81        let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
82            return Ok(None);
83        };
84        let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
85            vortex_panic!(
86                "DecimalArray must have decimal dtype, got {:?}",
87                array.dtype()
88            );
89        };
90
91        // If the dtype is exactly the same, return self
92        if array.dtype() == dtype {
93            return Ok(Some(array.array().clone()));
94        }
95
96        let validity = array.validity()?;
97
98        // Cast the validity to the new nullability
99        let new_validity = validity
100            .clone()
101            .cast_nullability(*to_nullability, array.len(), ctx)?;
102
103        // Reuse the values buffer untouched when no rescale is required, the target precision
104        // only widens (so every value still fits), and the current physical type is already wide
105        // enough to hold the target precision. This keeps the common precision-widening cast
106        // (and pure nullability changes) zero-copy instead of allocating and re-scanning.
107        if from_decimal_dtype.scale() == to_decimal_dtype.scale()
108            && to_decimal_dtype.precision() >= from_decimal_dtype.precision()
109            && array
110                .values_type()
111                .is_compatible_decimal_value_type(*to_decimal_dtype)
112        {
113            // SAFETY: the source values are bit-identical and remain in range for the wider
114            // precision, and new_validity has the same length, only its nullability tag changes.
115            unsafe {
116                return Ok(Some(
117                    DecimalArray::new_unchecked_handle(
118                        array.buffer_handle().clone(),
119                        array.values_type(),
120                        *to_decimal_dtype,
121                        new_validity,
122                    )
123                    .into_array(),
124                ));
125            }
126        }
127
128        let valid_values = validity.execute_mask(array.len(), ctx)?;
129        let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
130
131        match_each_decimal_value_type!(array.values_type(), |F| {
132            match_each_decimal_value_type!(target_values_type, |T| {
133                cast_decimal_values::<F, T>(
134                    array,
135                    *from_decimal_dtype,
136                    *to_decimal_dtype,
137                    new_validity,
138                    &valid_values,
139                )
140                .map(Some)
141            })
142        })
143    }
144}
145
146fn cast_decimal_values<F, T>(
147    array: ArrayView<'_, Decimal>,
148    from_decimal_dtype: DecimalDType,
149    to_decimal_dtype: DecimalDType,
150    validity: Validity,
151    valid_values: &Mask,
152) -> VortexResult<ArrayRef>
153where
154    F: NativeDecimalType,
155    T: NativeDecimalType + CheckedMul,
156    DecimalValue: From<F>,
157{
158    let values = array.buffer::<F>();
159    let values = values.as_slice();
160    let cast_plan = DecimalCastPlan::<T>::new(from_decimal_dtype, to_decimal_dtype);
161
162    let buffer = match valid_values {
163        Mask::AllTrue(_) => {
164            let mut buffer = BufferMut::<T>::with_capacity(values.len());
165            values
166                .try_map_into(&mut buffer.spare_capacity_mut()[..values.len()], |value| {
167                    cast_plan.cast(value)
168                })
169                .map_err(|idx| {
170                    decimal_cast_error::<F, T>(values[idx], from_decimal_dtype, to_decimal_dtype)
171                })?;
172            // SAFETY: try_map_into initializes every lane before returning Ok.
173            unsafe { buffer.set_len(values.len()) };
174            buffer.freeze()
175        }
176        Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
177        Mask::Values(mask) => {
178            let mut buffer = BufferMut::<T>::with_capacity(values.len());
179            values
180                .try_map_masked_into(
181                    mask.bit_buffer(),
182                    &mut buffer.spare_capacity_mut()[..values.len()],
183                    |value| cast_plan.cast(value),
184                )
185                .map_err(|idx| {
186                    decimal_cast_error::<F, T>(values[idx], from_decimal_dtype, to_decimal_dtype)
187                })?;
188            // SAFETY: try_map_masked_into initializes every lane before returning Ok.
189            unsafe { buffer.set_len(values.len()) };
190            buffer.freeze()
191        }
192    };
193
194    Ok(DecimalArray::new(buffer, to_decimal_dtype, validity).into_array())
195}
196
197#[cold]
198fn decimal_cast_error<F, T>(
199    value: F,
200    from_decimal_dtype: DecimalDType,
201    to_decimal_dtype: DecimalDType,
202) -> VortexError
203where
204    F: NativeDecimalType,
205    T: NativeDecimalType,
206    DecimalValue: From<F>,
207{
208    match DecimalValue::from(value)
209        .cast_decimal(from_decimal_dtype, to_decimal_dtype)
210        .and_then(|value| {
211            value.cast::<T>().ok_or_else(|| {
212                vortex_err!(
213                    "decimal value cannot be represented as {} after casting to {}",
214                    T::DECIMAL_TYPE,
215                    to_decimal_dtype
216                )
217            })
218        }) {
219        Ok(_) => {
220            // The fast path only returns `None` for values the slow path also rejects, so this
221            // arm should be unreachable. If it is hit, the fast and slow paths have drifted and
222            // we are erroring on a value that is actually representable.
223            debug_assert!(
224                false,
225                "decimal fast-path cast rejected value {value} that the slow path accepts \
226                 (from {from_decimal_dtype} to {to_decimal_dtype})"
227            );
228            vortex_err!(
229                "decimal value cannot be represented as {} after casting from {} to {}",
230                T::DECIMAL_TYPE,
231                from_decimal_dtype,
232                to_decimal_dtype
233            )
234        }
235        Err(error) => error,
236    }
237}
238
239#[derive(Debug, Clone, Copy)]
240enum DecimalCastPlan<T> {
241    SameScale { min: T, max: T },
242    ScaleUp { factor: T, min: T, max: T },
243    ScaleUpOverflow,
244    ScaleDown { factor: i256, min: i256, max: i256 },
245    ScaleDownOverflow,
246}
247
248impl<T> DecimalCastPlan<T>
249where
250    T: NativeDecimalType + CheckedMul,
251{
252    fn new(from_decimal_dtype: DecimalDType, to_decimal_dtype: DecimalDType) -> Self {
253        let scale_delta = to_decimal_dtype.scale() as i16 - from_decimal_dtype.scale() as i16;
254        if scale_delta == 0 {
255            let (min, max) = decimal_precision_range::<T>(to_decimal_dtype);
256            return Self::SameScale { min, max };
257        }
258
259        if scale_delta > 0 {
260            let Some(factor) = decimal_scale_factor::<T>(scale_delta as u32) else {
261                return Self::ScaleUpOverflow;
262            };
263            let (min, max) = decimal_precision_range::<T>(to_decimal_dtype);
264            return Self::ScaleUp { factor, min, max };
265        }
266
267        let Some(factor) = decimal_scale_factor::<i256>((-scale_delta) as u32) else {
268            return Self::ScaleDownOverflow;
269        };
270        let (min, max) = decimal_precision_range::<i256>(to_decimal_dtype);
271        Self::ScaleDown { factor, min, max }
272    }
273
274    #[inline]
275    fn cast<F>(&self, value: F) -> Option<T>
276    where
277        F: NativeDecimalType,
278    {
279        match *self {
280            DecimalCastPlan::SameScale { min, max } => {
281                let value = <T as BigCast>::from(value)?;
282                (value >= min && value <= max).then_some(value)
283            }
284            DecimalCastPlan::ScaleUp { factor, min, max } => {
285                let value = <T as BigCast>::from(value)?;
286                let value = value.checked_mul(&factor)?;
287                (value >= min && value <= max).then_some(value)
288            }
289            DecimalCastPlan::ScaleUpOverflow | DecimalCastPlan::ScaleDownOverflow => {
290                (value == F::default()).then_some(T::default())
291            }
292            DecimalCastPlan::ScaleDown { factor, min, max } => {
293                let value = <i256 as BigCast>::from(value)?;
294                if value == i256::ZERO {
295                    return Some(T::default());
296                }
297                if value % factor != i256::ZERO {
298                    return None;
299                }
300
301                let value = value / factor;
302                if value < min || value > max {
303                    return None;
304                }
305                <T as BigCast>::from(value)
306            }
307        }
308    }
309}
310
311fn decimal_precision_range<T: NativeDecimalType>(decimal_dtype: DecimalDType) -> (T, T) {
312    let precision = usize::from(decimal_dtype.precision());
313    (
314        T::MIN_BY_PRECISION[precision],
315        T::MAX_BY_PRECISION[precision],
316    )
317}
318
319fn decimal_scale_factor<T>(exp: u32) -> Option<T>
320where
321    T: NativeDecimalType + CheckedMul,
322{
323    let ten = <T as BigCast>::from(10_i8)?;
324    let mut factor = <T as BigCast>::from(1_i8)?;
325    for _ in 0..exp {
326        factor = factor.checked_mul(&ten)?;
327    }
328    Some(factor)
329}
330
331/// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping
332/// the same precision and scale.
333///
334/// This is useful when you need to widen the underlying storage type to accommodate operations
335/// that might overflow the current representation, or to match the physical type expected by
336/// downstream consumers.
337///
338/// # Errors
339///
340/// Returns an error if `to_values_type` is narrower than the array's current values type.
341/// Only upcasting (widening) is supported.
342pub fn upcast_decimal_values(
343    array: ArrayView<'_, Decimal>,
344    to_values_type: DecimalType,
345) -> VortexResult<DecimalArray> {
346    let from_values_type = array.values_type();
347
348    // If already the target type, just clone
349    if from_values_type == to_values_type {
350        return Ok(array.array().as_::<Decimal>().into_owned());
351    }
352
353    // Only allow upcasting (widening)
354    if to_values_type < from_values_type {
355        vortex_bail!(
356            "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
357            from_values_type,
358            to_values_type
359        );
360    }
361
362    let decimal_dtype = array.decimal_dtype();
363    let validity = array.validity()?;
364
365    // Use match_each_decimal_value_type to dispatch based on source and target types
366    match_each_decimal_value_type!(from_values_type, |F| {
367        let from_buffer = array.buffer::<F>();
368        match_each_decimal_value_type!(to_values_type, |T| {
369            let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
370            Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
371        })
372    })
373}
374
375/// Upcast a buffer of decimal values from type F to type T.
376/// Since T is wider than F, this conversion never fails.
377fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
378    from.iter()
379        .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
380        .collect()
381}
382
// Unit tests for the decimal cast reduce/kernel paths and for `upcast_decimal_values`.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::DecimalType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Non-nullable -> nullable with the same decimal dtype yields an all-valid validity.
    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Cast to nullable
        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        #[expect(deprecated)]
        let casted = array
            .into_array()
            .cast(nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert!(matches!(casted.validity(), Ok(Validity::AllValid)));
        assert_eq!(casted.len(), 3);
    }

    /// Nullable (with no nulls) -> non-nullable succeeds and drops the validity.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with no nulls
        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        // Cast to non-nullable
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        #[expect(deprecated)]
        let casted = array
            .into_array()
            .cast(non_nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert!(matches!(casted.validity(), Ok(Validity::NonNullable)));
    }

    /// Nullable with actual nulls -> non-nullable must fail once the result is executed.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with nulls
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        // Attempt to cast to non-nullable should fail
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(non_nullable_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        result.unwrap();
    }

    /// Raising the scale multiplies the stored values and widens storage to fit the precision.
    #[test]
    fn cast_different_scale_rescales() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Cast 1.00 to scale 3, where it is stored as 1000.
        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        #[expect(deprecated)]
        let casted = array
            .into_array()
            .cast(different_dtype)
            .unwrap()
            .to_decimal();

        assert_eq!(casted.precision(), 15);
        assert_eq!(casted.scale(), 3);
        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.buffer::<i64>().as_ref(), &[1000]);
    }

    /// Shrinking precision is allowed when every value still fits.
    #[test]
    fn cast_downcast_precision_succeeds_when_values_fit() {
        let array = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );

        // Downcasting precision is allowed when every value fits.
        let smaller_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        #[expect(deprecated)]
        let casted = array.into_array().cast(smaller_dtype).unwrap().to_decimal();

        assert_eq!(casted.precision(), 10);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.buffer::<i64>().as_ref(), &[100]);
    }

    /// Shrinking precision below a stored value's digit count is an error.
    #[test]
    fn cast_downcast_precision_checks_values() {
        let array = DecimalArray::new(
            buffer![1000i64],
            DecimalDType::new(18, 0),
            Validity::NonNullable,
        );

        // 1000 has four digits, which does not fit in precision 3.
        let smaller_dtype = DType::Decimal(DecimalDType::new(3, 0), Nullability::NonNullable);
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(smaller_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not fit in precision")
        );
    }

    /// Lowering the scale must be exact; dropping non-zero fractional digits is an error.
    #[test]
    fn cast_lower_scale_requires_exact_rescale() {
        let array = DecimalArray::new(
            buffer![123456i64],
            DecimalDType::new(10, 4),
            Validity::NonNullable,
        );

        // 12.3456 cannot be represented at scale 2 without losing the trailing "56".
        let lower_scale_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(lower_scale_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("would lose precision")
        );
    }

    /// Values behind nulls are not validated, so an unrepresentable null lane does not fail the
    /// cast.
    #[test]
    fn cast_lower_scale_ignores_null_lane_failures() {
        // Lane 0 (0.0100) rescales cleanly to 1 at scale 2; lane 1 would lose precision but is null.
        let array = DecimalArray::new(
            buffer![100i64, 123456],
            DecimalDType::new(10, 4),
            Validity::from_iter([true, false]),
        );

        let lower_scale_dtype = DType::Decimal(DecimalDType::new(3, 2), Nullability::Nullable);
        #[expect(deprecated)]
        let casted = array
            .into_array()
            .cast(lower_scale_dtype)
            .unwrap()
            .to_decimal();

        let mask = casted
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                casted.as_ref().len(),
                &mut array_session().create_execution_ctx(),
            )
            .unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        // Precision 3 is stored as i16.
        assert_eq!(casted.buffer::<i16>().as_ref()[0], 1);
    }

    /// Widening precision past the current physical type moves storage to a wider type.
    #[test]
    fn cast_upcast_precision_succeeds() {
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Cast to higher precision with same scale - should succeed
        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        #[expect(deprecated)]
        let casted = array.into_array().cast(wider_dtype).unwrap().to_decimal();

        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        // Should be stored in i128 now (precision 38 requires i128)
        assert_eq!(casted.values_type(), DecimalType::I128);
    }

    /// Widening precision within the same physical type must share the source buffer.
    #[test]
    fn cast_widening_same_physical_type_is_zero_copy() {
        // Decimal(10,2) and Decimal(18,2) are both physically i64 with the same scale, so widening
        // the precision must reuse the values buffer rather than allocate and re-scan it.
        let array = DecimalArray::new(
            buffer![100i64, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let src_ptr = array.buffer::<i64>().as_ptr();

        let wider_dtype = DType::Decimal(DecimalDType::new(18, 2), Nullability::NonNullable);
        #[expect(deprecated)]
        let casted = array.into_array().cast(wider_dtype).unwrap().to_decimal();

        assert_eq!(casted.precision(), 18);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.buffer::<i64>().as_ref(), &[100, 200, 300]);
        // The values buffer must be shared with the source (zero-copy), not reallocated.
        assert_eq!(
            casted.buffer::<i64>().as_ptr(),
            src_ptr,
            "precision-widening cast must reuse the source values buffer"
        );
    }

    /// The decimal kernels decline non-decimal targets, so the cast surfaces an error.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to non-decimal type - should fail since no kernel can handle it
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No CastKernel to cast canonical array")
        );
    }

    /// Runs the shared cast conformance suite over representative decimal arrays.
    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(&array.into_array());
    }

    /// i32 -> i64 storage keeps the decimal dtype and values.
    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        assert_eq!(array.values_type(), DecimalType::I32);

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
        assert_eq!(casted.len(), 3);

        // Verify values are preserved
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer.as_ref(), &[100i64, 200, 300]);
    }

    /// i64 -> i128 storage keeps the decimal dtype and values.
    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I128).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);

        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[10000i128, 20000, 30000]);
    }

    /// Requesting the current physical type returns the array unchanged.
    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I32).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I32);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
    }

    /// Upcasting preserves the validity mask and the non-null values.
    #[test]
    fn upcast_decimal_values_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.len(), 3);

        // Check validity is preserved
        let mask = casted
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                casted.as_ref().len(),
                &mut array_session().create_execution_ctx(),
            )
            .unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Check non-null values
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    /// Narrowing the physical type is rejected with a descriptive error.
    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Attempt to downcast from i64 to i32 should fail
        let array = array.as_view();
        let result = upcast_decimal_values(array, DecimalType::I32);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot downcast decimal values")
        );
    }
}