Skip to main content

vortex_array/arrays/decimal/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_dtype::DType;
6use vortex_dtype::DecimalType;
7use vortex_dtype::NativeDecimalType;
8use vortex_dtype::match_each_decimal_value_type;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_panic;
13
14use crate::ArrayRef;
15use crate::ExecutionCtx;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::compute::CastKernel;
19use crate::vtable::ValidityHelper;
20
21impl CastKernel for DecimalVTable {
22    fn cast(
23        array: &DecimalArray,
24        dtype: &DType,
25        _ctx: &mut ExecutionCtx,
26    ) -> VortexResult<Option<ArrayRef>> {
27        // Early return if not casting to decimal
28        let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
29            return Ok(None);
30        };
31        let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
32            vortex_panic!(
33                "DecimalArray must have decimal dtype, got {:?}",
34                array.dtype()
35            );
36        };
37
38        // Scale changes are not yet supported
39        if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
40            vortex_bail!(
41                "Casting decimal with scale {} to scale {} not yet implemented",
42                from_decimal_dtype.scale(),
43                to_decimal_dtype.scale()
44            );
45        }
46
47        // Downcasting precision is not yet supported
48        if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
49            vortex_bail!(
50                "Downcasting decimal from precision {} to {} not yet implemented",
51                from_decimal_dtype.precision(),
52                to_decimal_dtype.precision()
53            );
54        }
55
56        // If the dtype is exactly the same, return self
57        if array.dtype() == dtype {
58            return Ok(Some(array.to_array()));
59        }
60
61        // Cast the validity to the new nullability
62        let new_validity = array
63            .validity()
64            .clone()
65            .cast_nullability(*to_nullability, array.len())?;
66
67        // If the target needs a wider physical type, upcast the values
68        let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
69        let array = if target_values_type > array.values_type() {
70            upcast_decimal_values(array, target_values_type)?
71        } else {
72            array.clone()
73        };
74
75        // SAFETY: new_validity same length as previous validity, just cast
76        unsafe {
77            Ok(Some(
78                DecimalArray::new_unchecked_handle(
79                    array.buffer_handle().clone(),
80                    array.values_type(),
81                    *to_decimal_dtype,
82                    new_validity,
83                )
84                .to_array(),
85            ))
86        }
87    }
88}
89
90/// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping
91/// the same precision and scale.
92///
93/// This is useful when you need to widen the underlying storage type to accommodate operations
94/// that might overflow the current representation, or to match the physical type expected by
95/// downstream consumers.
96///
97/// # Errors
98///
99/// Returns an error if `to_values_type` is narrower than the array's current values type.
100/// Only upcasting (widening) is supported.
101pub fn upcast_decimal_values(
102    array: &DecimalArray,
103    to_values_type: DecimalType,
104) -> VortexResult<DecimalArray> {
105    let from_values_type = array.values_type();
106
107    // If already the target type, just clone
108    if from_values_type == to_values_type {
109        return Ok(array.clone());
110    }
111
112    // Only allow upcasting (widening)
113    if to_values_type < from_values_type {
114        vortex_bail!(
115            "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
116            from_values_type,
117            to_values_type
118        );
119    }
120
121    let decimal_dtype = array.decimal_dtype();
122    let validity = array.validity().clone();
123
124    // Use match_each_decimal_value_type to dispatch based on source and target types
125    match_each_decimal_value_type!(from_values_type, |F| {
126        let from_buffer = array.buffer::<F>();
127        match_each_decimal_value_type!(to_values_type, |T| {
128            let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
129            Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
130        })
131    })
132}
133
134/// Upcast a buffer of decimal values from type F to type T.
135/// Since T is wider than F, this conversion never fails.
136fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
137    from.iter()
138        .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
139        .collect()
140}
141
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::DecimalType;
    use vortex_dtype::Nullability;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Cast to nullable
        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let casted = array
            .to_array()
            .cast(nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert_eq!(casted.validity(), &Validity::AllValid);
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with no nulls
        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        // Cast to non-nullable
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = array
            .to_array()
            .cast(non_nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
    }

    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        // Create nullable array with nulls
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        // Attempt to cast to non-nullable should fail
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        array
            .to_array()
            .cast(non_nullable_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap();
    }

    #[test]
    fn cast_different_scale_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to different scale - not supported
        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let result = array
            .to_array()
            .cast(different_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Casting decimal with scale 2 to scale 3 not yet implemented")
        );
    }

    #[test]
    fn cast_downcast_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );

        // Try to downcast precision - not supported
        let smaller_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        let result = array
            .to_array()
            .cast(smaller_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Downcasting decimal from precision 18 to 10 not yet implemented")
        );
    }

    #[test]
    fn cast_upcast_precision_succeeds() {
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Cast to higher precision with same scale - should succeed
        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        let casted = array.to_array().cast(wider_dtype).unwrap().to_decimal();

        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        // Should be stored in i128 now (precision 38 requires i128)
        assert_eq!(casted.values_type(), DecimalType::I128);

        // The widened storage must still hold the original values
        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[100i128, 200, 300]);
    }

    #[test]
    fn cast_upcast_precision_with_nulls_preserves_values_and_validity() {
        // Exercises the kernel path that combines a cast validity with widened storage
        let array =
            DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2));

        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::Nullable);
        let casted = array
            .to_array()
            .cast(wider_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &wider_dtype);
        assert_eq!(casted.len(), 3);
        assert_eq!(casted.values_type(), DecimalType::I128);

        // Null positions must be carried over unchanged
        let mask = casted.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Non-null values must survive the widening
        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Try to cast to non-decimal type - should fail since no kernel can handle it
        let result = array
            .to_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No CastKernel to cast canonical array")
        );
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(array.as_ref());
    }

    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        assert_eq!(array.values_type(), DecimalType::I32);

        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
        assert_eq!(casted.len(), 3);

        // Verify values are preserved
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer.as_ref(), &[100i64, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let casted = upcast_decimal_values(&array, DecimalType::I128).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);

        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[10000i128, 20000, 30000]);
    }

    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let casted = upcast_decimal_values(&array, DecimalType::I32).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I32);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);

        // The no-op path must keep the values intact
        let buffer = casted.buffer::<i32>();
        assert_eq!(buffer.as_ref(), &[100i32, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.len(), 3);

        // Check validity is preserved
        let mask = casted.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Check non-null values
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        // Attempt to downcast from i64 to i32 should fail
        let result = upcast_decimal_values(&array, DecimalType::I32);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot downcast decimal values")
        );
    }
}