Skip to main content

vortex_array/arrays/decimal/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_bail;
8use vortex_error::vortex_panic;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Decimal;
15use crate::arrays::DecimalArray;
16use crate::dtype::DType;
17use crate::dtype::DecimalType;
18use crate::dtype::NativeDecimalType;
19use crate::match_each_decimal_value_type;
20use crate::scalar_fn::fns::cast::CastKernel;
21
22impl CastKernel for Decimal {
23    fn cast(
24        array: ArrayView<'_, Decimal>,
25        dtype: &DType,
26        _ctx: &mut ExecutionCtx,
27    ) -> VortexResult<Option<ArrayRef>> {
28        // Early return if not casting to decimal
29        let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
30            return Ok(None);
31        };
32        let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
33            vortex_panic!(
34                "DecimalArray must have decimal dtype, got {:?}",
35                array.dtype()
36            );
37        };
38
39        // Scale changes are not yet supported
40        if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
41            vortex_bail!(
42                "Casting decimal with scale {} to scale {} not yet implemented",
43                from_decimal_dtype.scale(),
44                to_decimal_dtype.scale()
45            );
46        }
47
48        // Downcasting precision is not yet supported
49        if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
50            vortex_bail!(
51                "Downcasting decimal from precision {} to {} not yet implemented",
52                from_decimal_dtype.precision(),
53                to_decimal_dtype.precision()
54            );
55        }
56
57        // If the dtype is exactly the same, return self
58        if array.dtype() == dtype {
59            return Ok(Some(array.array().clone()));
60        }
61
62        // Cast the validity to the new nullability
63        let new_validity = array
64            .validity()?
65            .cast_nullability(*to_nullability, array.len())?;
66
67        // If the target needs a wider physical type, upcast the values
68        let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
69        let array = if target_values_type > array.values_type() {
70            upcast_decimal_values(array, target_values_type)?
71        } else {
72            array.array().as_::<Decimal>().into_owned()
73        };
74
75        // SAFETY: new_validity same length as previous validity, just cast
76        unsafe {
77            Ok(Some(
78                DecimalArray::new_unchecked_handle(
79                    array.buffer_handle().clone(),
80                    array.values_type(),
81                    *to_decimal_dtype,
82                    new_validity,
83                )
84                .into_array(),
85            ))
86        }
87    }
88}
89
90/// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping
91/// the same precision and scale.
92///
93/// This is useful when you need to widen the underlying storage type to accommodate operations
94/// that might overflow the current representation, or to match the physical type expected by
95/// downstream consumers.
96///
97/// # Errors
98///
99/// Returns an error if `to_values_type` is narrower than the array's current values type.
100/// Only upcasting (widening) is supported.
101pub fn upcast_decimal_values(
102    array: ArrayView<'_, Decimal>,
103    to_values_type: DecimalType,
104) -> VortexResult<DecimalArray> {
105    let from_values_type = array.values_type();
106
107    // If already the target type, just clone
108    if from_values_type == to_values_type {
109        return Ok(array.array().as_::<Decimal>().into_owned());
110    }
111
112    // Only allow upcasting (widening)
113    if to_values_type < from_values_type {
114        vortex_bail!(
115            "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
116            from_values_type,
117            to_values_type
118        );
119    }
120
121    let decimal_dtype = array.decimal_dtype();
122    let validity = array.validity()?;
123
124    // Use match_each_decimal_value_type to dispatch based on source and target types
125    match_each_decimal_value_type!(from_values_type, |F| {
126        let from_buffer = array.buffer::<F>();
127        match_each_decimal_value_type!(to_values_type, |T| {
128            let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
129            Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
130        })
131    })
132}
133
134/// Upcast a buffer of decimal values from type F to type T.
135/// Since T is wider than F, this conversion never fails.
136fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
137    from.iter()
138        .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
139        .collect()
140}
141
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::DecimalType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Builds the non-nullable, i32-backed `[100, 200, 300]` fixture shared by several tests.
    fn hundreds_i32(dtype: DecimalDType) -> DecimalArray {
        DecimalArray::new(buffer![100i32, 200, 300], dtype, Validity::NonNullable)
    }

    /// Asserts that `result` is an error whose rendered message contains `needle`.
    fn assert_err_contains<T, E>(result: Result<T, E>, needle: &str)
    where
        T: std::fmt::Debug,
        E: std::fmt::Display,
    {
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains(needle));
    }

    #[test]
    fn cast_decimal_to_nullable() {
        let dtype = DecimalDType::new(10, 2);
        let source = hundreds_i32(dtype);

        // Same decimal dtype, but nullable.
        let target = DType::Decimal(dtype, Nullability::Nullable);
        let casted = source
            .into_array()
            .cast(target.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &target);
        assert!(matches!(casted.validity(), Ok(Validity::AllValid)));
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_nullable_to_non_nullable() {
        let dtype = DecimalDType::new(10, 2);

        // Nullable input that happens to contain no nulls.
        let source = DecimalArray::new(buffer![100i32, 200, 300], dtype, Validity::AllValid);

        // Same decimal dtype, but non-nullable.
        let target = DType::Decimal(dtype, Nullability::NonNullable);
        let casted = source
            .into_array()
            .cast(target.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &target);
        assert!(matches!(casted.validity(), Ok(Validity::NonNullable)));
    }

    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let dtype = DecimalDType::new(10, 2);

        // Nullable input that really does contain a null.
        let source = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], dtype);

        // Dropping nullability must fail once the cast is actually executed.
        let target = DType::Decimal(dtype, Nullability::NonNullable);
        source
            .into_array()
            .cast(target)
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()))
            .unwrap();
    }

    #[test]
    fn cast_different_scale_fails() {
        let source = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // Changing the scale (2 -> 3) is unsupported.
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let outcome = source
            .into_array()
            .cast(target)
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));

        assert_err_contains(
            outcome,
            "Casting decimal with scale 2 to scale 3 not yet implemented",
        );
    }

    #[test]
    fn cast_downcast_precision_fails() {
        let source = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );

        // Reducing the precision (18 -> 10) is unsupported.
        let target = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        let outcome = source
            .into_array()
            .cast(target)
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));

        assert_err_contains(
            outcome,
            "Downcasting decimal from precision 18 to 10 not yet implemented",
        );
    }

    #[test]
    fn cast_upcast_precision_succeeds() {
        let source = hundreds_i32(DecimalDType::new(10, 2));

        // Higher precision with the same scale is supported.
        let target = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        let casted = source.into_array().cast(target).unwrap().to_decimal();

        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        // Precision 38 needs i128 storage.
        assert_eq!(casted.values_type(), DecimalType::I128);
    }

    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // No kernel handles decimal -> utf8, so executing the cast must fail.
        let outcome = source
            .into_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));

        assert_err_contains(outcome, "No CastKernel to cast canonical array");
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let dtype = DecimalDType::new(10, 2);
        let source = hundreds_i32(dtype);

        assert_eq!(source.values_type(), DecimalType::I32);

        let widened = upcast_decimal_values(source.as_view(), DecimalType::I64).unwrap();

        assert_eq!(widened.values_type(), DecimalType::I64);
        assert_eq!(widened.decimal_dtype(), dtype);
        assert_eq!(widened.len(), 3);

        // The stored values must survive the widening unchanged.
        let values = widened.buffer::<i64>();
        assert_eq!(values.as_ref(), &[100i64, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let dtype = DecimalDType::new(18, 4);
        let source = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            dtype,
            Validity::NonNullable,
        );

        let widened = upcast_decimal_values(source.as_view(), DecimalType::I128).unwrap();

        assert_eq!(widened.values_type(), DecimalType::I128);
        assert_eq!(widened.decimal_dtype(), dtype);

        let values = widened.buffer::<i128>();
        assert_eq!(values.as_ref(), &[10000i128, 20000, 30000]);
    }

    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let dtype = DecimalDType::new(10, 2);
        let source = hundreds_i32(dtype);

        let same = upcast_decimal_values(source.as_view(), DecimalType::I32).unwrap();

        assert_eq!(same.values_type(), DecimalType::I32);
        assert_eq!(same.decimal_dtype(), dtype);
    }

    #[test]
    fn upcast_decimal_values_with_nulls() {
        let dtype = DecimalDType::new(10, 2);
        let source = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], dtype);

        let widened = upcast_decimal_values(source.as_view(), DecimalType::I64).unwrap();

        assert_eq!(widened.values_type(), DecimalType::I64);
        assert_eq!(widened.len(), 3);

        // The null in the middle must still be a null afterwards.
        let validity = widened.as_ref().validity().unwrap();
        let len = widened.as_ref().len();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mask = validity.to_mask(len, &mut ctx).unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Only the valid positions carry meaningful values.
        let values = widened.buffer::<i64>();
        assert_eq!(values[0], 100);
        assert_eq!(values[2], 300);
    }

    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let dtype = DecimalDType::new(18, 4);
        let source = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            dtype,
            Validity::NonNullable,
        );

        // i64 -> i32 is a narrowing and must be rejected.
        let outcome = upcast_decimal_values(source.as_view(), DecimalType::I32);
        assert_err_contains(outcome, "Cannot downcast decimal values");
    }
}