Skip to main content

vortex_array/arrays/decimal/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_bail;
8use vortex_error::vortex_panic;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Decimal;
15use crate::arrays::DecimalArray;
16use crate::dtype::DType;
17use crate::dtype::DecimalType;
18use crate::dtype::NativeDecimalType;
19use crate::match_each_decimal_value_type;
20use crate::scalar_fn::fns::cast::CastKernel;
21use crate::scalar_fn::fns::cast::CastReduce;
22
23impl CastReduce for Decimal {
24    fn cast(array: ArrayView<'_, Decimal>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
25        // Only nullability changes within the same decimal dtype are reducible without execution.
26        // Precision/scale changes need the kernel.
27        let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
28            return Ok(None);
29        };
30        let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
31            vortex_panic!(
32                "DecimalArray must have decimal dtype, got {:?}",
33                array.dtype()
34            );
35        };
36
37        if from_decimal_dtype != to_decimal_dtype {
38            return Ok(None);
39        }
40
41        let Some(new_validity) = array
42            .validity()?
43            .trivial_cast_nullability(*to_nullability, array.len())?
44        else {
45            return Ok(None);
46        };
47
48        // SAFETY: validity has the same length, only its nullability tag changes.
49        unsafe {
50            Ok(Some(
51                DecimalArray::new_unchecked_handle(
52                    array.buffer_handle().clone(),
53                    array.values_type(),
54                    *to_decimal_dtype,
55                    new_validity,
56                )
57                .into_array(),
58            ))
59        }
60    }
61}
62
63impl CastKernel for Decimal {
64    fn cast(
65        array: ArrayView<'_, Decimal>,
66        dtype: &DType,
67        ctx: &mut ExecutionCtx,
68    ) -> VortexResult<Option<ArrayRef>> {
69        // Early return if not casting to decimal
70        let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
71            return Ok(None);
72        };
73        let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
74            vortex_panic!(
75                "DecimalArray must have decimal dtype, got {:?}",
76                array.dtype()
77            );
78        };
79
80        // Scale changes are not yet supported
81        if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
82            vortex_bail!(
83                "Casting decimal with scale {} to scale {} not yet implemented",
84                from_decimal_dtype.scale(),
85                to_decimal_dtype.scale()
86            );
87        }
88
89        // Downcasting precision is not yet supported
90        if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
91            vortex_bail!(
92                "Downcasting decimal from precision {} to {} not yet implemented",
93                from_decimal_dtype.precision(),
94                to_decimal_dtype.precision()
95            );
96        }
97
98        // If the dtype is exactly the same, return self
99        if array.dtype() == dtype {
100            return Ok(Some(array.array().clone()));
101        }
102
103        // Cast the validity to the new nullability
104        let new_validity = array
105            .validity()?
106            .cast_nullability(*to_nullability, array.len(), ctx)?;
107
108        // If the target needs a wider physical type, upcast the values
109        let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
110        let array = if target_values_type > array.values_type() {
111            upcast_decimal_values(array, target_values_type)?
112        } else {
113            array.array().as_::<Decimal>().into_owned()
114        };
115
116        // SAFETY: new_validity same length as previous validity, just cast
117        unsafe {
118            Ok(Some(
119                DecimalArray::new_unchecked_handle(
120                    array.buffer_handle().clone(),
121                    array.values_type(),
122                    *to_decimal_dtype,
123                    new_validity,
124                )
125                .into_array(),
126            ))
127        }
128    }
129}
130
131/// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping
132/// the same precision and scale.
133///
134/// This is useful when you need to widen the underlying storage type to accommodate operations
135/// that might overflow the current representation, or to match the physical type expected by
136/// downstream consumers.
137///
138/// # Errors
139///
140/// Returns an error if `to_values_type` is narrower than the array's current values type.
141/// Only upcasting (widening) is supported.
142pub fn upcast_decimal_values(
143    array: ArrayView<'_, Decimal>,
144    to_values_type: DecimalType,
145) -> VortexResult<DecimalArray> {
146    let from_values_type = array.values_type();
147
148    // If already the target type, just clone
149    if from_values_type == to_values_type {
150        return Ok(array.array().as_::<Decimal>().into_owned());
151    }
152
153    // Only allow upcasting (widening)
154    if to_values_type < from_values_type {
155        vortex_bail!(
156            "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
157            from_values_type,
158            to_values_type
159        );
160    }
161
162    let decimal_dtype = array.decimal_dtype();
163    let validity = array.validity()?;
164
165    // Use match_each_decimal_value_type to dispatch based on source and target types
166    match_each_decimal_value_type!(from_values_type, |F| {
167        let from_buffer = array.buffer::<F>();
168        match_each_decimal_value_type!(to_values_type, |T| {
169            let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
170            Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
171        })
172    })
173}
174
175/// Upcast a buffer of decimal values from type F to type T.
176/// Since T is wider than F, this conversion never fails.
177fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
178    from.iter()
179        .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
180        .collect()
181}
182
#[cfg(test)]
mod tests {
    //! Tests for decimal-to-decimal casting and the `upcast_decimal_values` helper.

    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::DecimalType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Cast to nullable
        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        #[expect(deprecated)]
        let casted = array
            .into_array()
            .cast(nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert!(matches!(casted.validity(), Ok(Validity::AllValid)));
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with no nulls
        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        // Cast to non-nullable
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        #[expect(deprecated)]
        let casted = array
            .into_array()
            .cast(non_nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert!(matches!(casted.validity(), Ok(Validity::NonNullable)));
    }

    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with nulls
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        // Attempt to cast to non-nullable should fail
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(non_nullable_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        result.unwrap();
    }

    #[test]
    fn cast_different_scale_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to different scale - not supported
        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(different_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Casting decimal with scale 2 to scale 3 not yet implemented")
        );
    }

    #[test]
    fn cast_downcast_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );

        // Try to downcast precision - not supported
        let smaller_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(smaller_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Downcasting decimal from precision 18 to 10 not yet implemented")
        );
    }

    #[test]
    fn cast_upcast_precision_succeeds() {
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Cast to higher precision with same scale - should succeed
        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        #[expect(deprecated)]
        let casted = array.into_array().cast(wider_dtype).unwrap().to_decimal();

        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        // Should be stored in i128 now (precision 38 requires i128)
        assert_eq!(casted.values_type(), DecimalType::I128);
    }

    #[test]
    fn cast_upcast_precision_nullable_preserves_nulls() {
        // Exercises the kernel path that both casts validity and widens the physical storage,
        // then rebuilds the array through the unchecked constructor.
        let array = DecimalArray::from_option_iter(
            [Some(100i32), None, Some(300)],
            DecimalDType::new(10, 2),
        );

        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::Nullable);
        #[expect(deprecated)]
        let casted = array
            .into_array()
            .cast(wider_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &wider_dtype);
        assert_eq!(casted.len(), 3);
        assert_eq!(casted.values_type(), DecimalType::I128);

        // Nulls must survive the widening.
        let mask = casted
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                casted.as_ref().len(),
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Non-null values must be carried over unchanged.
        let values = casted.buffer::<i128>();
        assert_eq!(values[0], 100);
        assert_eq!(values[2], 300);
    }

    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to non-decimal type - should fail since no kernel can handle it
        #[expect(deprecated)]
        let result = array
            .into_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No CastKernel to cast canonical array")
        );
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        assert_eq!(array.values_type(), DecimalType::I32);

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
        assert_eq!(casted.len(), 3);

        // Verify values are preserved
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer.as_ref(), &[100i64, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I128).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);

        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[10000i128, 20000, 30000]);
    }

    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I32).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I32);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
    }

    #[test]
    fn upcast_decimal_values_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let array = array.as_view();
        let casted = upcast_decimal_values(array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.len(), 3);

        // Check validity is preserved
        let mask = casted
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                casted.as_ref().len(),
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Check non-null values
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Attempt to downcast from i64 to i32 should fail
        let array = array.as_view();
        let result = upcast_decimal_values(array, DecimalType::I32);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot downcast decimal values")
        );
    }
}