1use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_bail;
8use vortex_error::vortex_panic;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Decimal;
15use crate::arrays::DecimalArray;
16use crate::dtype::DType;
17use crate::dtype::DecimalType;
18use crate::dtype::NativeDecimalType;
19use crate::match_each_decimal_value_type;
20use crate::scalar_fn::fns::cast::CastKernel;
21use crate::scalar_fn::fns::cast::CastReduce;
22
23impl CastReduce for Decimal {
24 fn cast(array: ArrayView<'_, Decimal>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
25 let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
28 return Ok(None);
29 };
30 let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
31 vortex_panic!(
32 "DecimalArray must have decimal dtype, got {:?}",
33 array.dtype()
34 );
35 };
36
37 if from_decimal_dtype != to_decimal_dtype {
38 return Ok(None);
39 }
40
41 let Some(new_validity) = array
42 .validity()?
43 .trivial_cast_nullability(*to_nullability, array.len())?
44 else {
45 return Ok(None);
46 };
47
48 unsafe {
50 Ok(Some(
51 DecimalArray::new_unchecked_handle(
52 array.buffer_handle().clone(),
53 array.values_type(),
54 *to_decimal_dtype,
55 new_validity,
56 )
57 .into_array(),
58 ))
59 }
60 }
61}
62
63impl CastKernel for Decimal {
64 fn cast(
65 array: ArrayView<'_, Decimal>,
66 dtype: &DType,
67 ctx: &mut ExecutionCtx,
68 ) -> VortexResult<Option<ArrayRef>> {
69 let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
71 return Ok(None);
72 };
73 let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
74 vortex_panic!(
75 "DecimalArray must have decimal dtype, got {:?}",
76 array.dtype()
77 );
78 };
79
80 if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
82 vortex_bail!(
83 "Casting decimal with scale {} to scale {} not yet implemented",
84 from_decimal_dtype.scale(),
85 to_decimal_dtype.scale()
86 );
87 }
88
89 if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
91 vortex_bail!(
92 "Downcasting decimal from precision {} to {} not yet implemented",
93 from_decimal_dtype.precision(),
94 to_decimal_dtype.precision()
95 );
96 }
97
98 if array.dtype() == dtype {
100 return Ok(Some(array.array().clone()));
101 }
102
103 let new_validity = array
105 .validity()?
106 .cast_nullability(*to_nullability, array.len(), ctx)?;
107
108 let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
110 let array = if target_values_type > array.values_type() {
111 upcast_decimal_values(array, target_values_type)?
112 } else {
113 array.array().as_::<Decimal>().into_owned()
114 };
115
116 unsafe {
118 Ok(Some(
119 DecimalArray::new_unchecked_handle(
120 array.buffer_handle().clone(),
121 array.values_type(),
122 *to_decimal_dtype,
123 new_validity,
124 )
125 .into_array(),
126 ))
127 }
128 }
129}
130
131pub fn upcast_decimal_values(
143 array: ArrayView<'_, Decimal>,
144 to_values_type: DecimalType,
145) -> VortexResult<DecimalArray> {
146 let from_values_type = array.values_type();
147
148 if from_values_type == to_values_type {
150 return Ok(array.array().as_::<Decimal>().into_owned());
151 }
152
153 if to_values_type < from_values_type {
155 vortex_bail!(
156 "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
157 from_values_type,
158 to_values_type
159 );
160 }
161
162 let decimal_dtype = array.decimal_dtype();
163 let validity = array.validity()?;
164
165 match_each_decimal_value_type!(from_values_type, |F| {
167 let from_buffer = array.buffer::<F>();
168 match_each_decimal_value_type!(to_values_type, |T| {
169 let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
170 Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
171 })
172 })
173}
174
175fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
178 from.iter()
179 .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
180 .collect()
181}
182
183#[cfg(test)]
184mod tests {
185 use rstest::rstest;
186 use vortex_buffer::buffer;
187
188 use super::upcast_decimal_values;
189 use crate::IntoArray;
190 use crate::LEGACY_SESSION;
191 use crate::VortexSessionExecute;
192 use crate::arrays::DecimalArray;
193 use crate::builtins::ArrayBuiltins;
194 #[expect(deprecated)]
195 use crate::canonical::ToCanonical as _;
196 use crate::compute::conformance::cast::test_cast_conformance;
197 use crate::dtype::DType;
198 use crate::dtype::DecimalDType;
199 use crate::dtype::DecimalType;
200 use crate::dtype::Nullability;
201 use crate::validity::Validity;
202
203 #[test]
204 fn cast_decimal_to_nullable() {
205 let decimal_dtype = DecimalDType::new(10, 2);
206 let array = DecimalArray::new(
207 buffer![100i32, 200, 300],
208 decimal_dtype,
209 Validity::NonNullable,
210 );
211
212 let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
214 #[expect(deprecated)]
215 let casted = array
216 .into_array()
217 .cast(nullable_dtype.clone())
218 .unwrap()
219 .to_decimal();
220
221 assert_eq!(casted.dtype(), &nullable_dtype);
222 assert!(matches!(casted.validity(), Ok(Validity::AllValid)));
223 assert_eq!(casted.len(), 3);
224 }
225
226 #[test]
227 fn cast_nullable_to_non_nullable() {
228 let decimal_dtype = DecimalDType::new(10, 2);
229
230 let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);
232
233 let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
235 #[expect(deprecated)]
236 let casted = array
237 .into_array()
238 .cast(non_nullable_dtype.clone())
239 .unwrap()
240 .to_decimal();
241
242 assert_eq!(casted.dtype(), &non_nullable_dtype);
243 assert!(matches!(casted.validity(), Ok(Validity::NonNullable)));
244 }
245
246 #[test]
247 #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
248 fn cast_nullable_with_nulls_to_non_nullable_fails() {
249 let decimal_dtype = DecimalDType::new(10, 2);
250
251 let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);
253
254 let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
256 #[expect(deprecated)]
257 let result = array
258 .into_array()
259 .cast(non_nullable_dtype)
260 .and_then(|a| a.to_canonical().map(|c| c.into_array()));
261 result.unwrap();
262 }
263
264 #[test]
265 fn cast_different_scale_fails() {
266 let array = DecimalArray::new(
267 buffer![100i32],
268 DecimalDType::new(10, 2),
269 Validity::NonNullable,
270 );
271
272 let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
274 #[expect(deprecated)]
275 let result = array
276 .into_array()
277 .cast(different_dtype)
278 .and_then(|a| a.to_canonical().map(|c| c.into_array()));
279
280 assert!(result.is_err());
281 assert!(
282 result
283 .unwrap_err()
284 .to_string()
285 .contains("Casting decimal with scale 2 to scale 3 not yet implemented")
286 );
287 }
288
289 #[test]
290 fn cast_downcast_precision_fails() {
291 let array = DecimalArray::new(
292 buffer![100i64],
293 DecimalDType::new(18, 2),
294 Validity::NonNullable,
295 );
296
297 let smaller_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
299 #[expect(deprecated)]
300 let result = array
301 .into_array()
302 .cast(smaller_dtype)
303 .and_then(|a| a.to_canonical().map(|c| c.into_array()));
304
305 assert!(result.is_err());
306 assert!(
307 result
308 .unwrap_err()
309 .to_string()
310 .contains("Downcasting decimal from precision 18 to 10 not yet implemented")
311 );
312 }
313
314 #[test]
315 fn cast_upcast_precision_succeeds() {
316 let array = DecimalArray::new(
317 buffer![100i32, 200, 300],
318 DecimalDType::new(10, 2),
319 Validity::NonNullable,
320 );
321
322 let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
324 #[expect(deprecated)]
325 let casted = array.into_array().cast(wider_dtype).unwrap().to_decimal();
326
327 assert_eq!(casted.precision(), 38);
328 assert_eq!(casted.scale(), 2);
329 assert_eq!(casted.len(), 3);
330 assert_eq!(casted.values_type(), DecimalType::I128);
332 }
333
334 #[test]
335 fn cast_to_non_decimal_returns_err() {
336 let array = DecimalArray::new(
337 buffer![100i32],
338 DecimalDType::new(10, 2),
339 Validity::NonNullable,
340 );
341
342 #[expect(deprecated)]
344 let result = array
345 .into_array()
346 .cast(DType::Utf8(Nullability::NonNullable))
347 .and_then(|a| a.to_canonical().map(|c| c.into_array()));
348
349 assert!(result.is_err());
350 assert!(
351 result
352 .unwrap_err()
353 .to_string()
354 .contains("No CastKernel to cast canonical array")
355 );
356 }
357
358 #[rstest]
359 #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
360 #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
361 #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
362 #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
363 fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
364 test_cast_conformance(&array.into_array());
365 }
366
367 #[test]
368 fn upcast_decimal_values_i32_to_i64() {
369 let decimal_dtype = DecimalDType::new(10, 2);
370 let array = DecimalArray::new(
371 buffer![100i32, 200, 300],
372 decimal_dtype,
373 Validity::NonNullable,
374 );
375
376 assert_eq!(array.values_type(), DecimalType::I32);
377
378 let array = array.as_view();
379 let casted = upcast_decimal_values(array, DecimalType::I64).unwrap();
380
381 assert_eq!(casted.values_type(), DecimalType::I64);
382 assert_eq!(casted.decimal_dtype(), decimal_dtype);
383 assert_eq!(casted.len(), 3);
384
385 let buffer = casted.buffer::<i64>();
387 assert_eq!(buffer.as_ref(), &[100i64, 200, 300]);
388 }
389
390 #[test]
391 fn upcast_decimal_values_i64_to_i128() {
392 let decimal_dtype = DecimalDType::new(18, 4);
393 let array = DecimalArray::new(
394 buffer![10000i64, 20000, 30000],
395 decimal_dtype,
396 Validity::NonNullable,
397 );
398
399 let array = array.as_view();
400 let casted = upcast_decimal_values(array, DecimalType::I128).unwrap();
401
402 assert_eq!(casted.values_type(), DecimalType::I128);
403 assert_eq!(casted.decimal_dtype(), decimal_dtype);
404
405 let buffer = casted.buffer::<i128>();
406 assert_eq!(buffer.as_ref(), &[10000i128, 20000, 30000]);
407 }
408
409 #[test]
410 fn upcast_decimal_values_same_type_returns_clone() {
411 let decimal_dtype = DecimalDType::new(10, 2);
412 let array = DecimalArray::new(
413 buffer![100i32, 200, 300],
414 decimal_dtype,
415 Validity::NonNullable,
416 );
417
418 let array = array.as_view();
419 let casted = upcast_decimal_values(array, DecimalType::I32).unwrap();
420
421 assert_eq!(casted.values_type(), DecimalType::I32);
422 assert_eq!(casted.decimal_dtype(), decimal_dtype);
423 }
424
425 #[test]
426 fn upcast_decimal_values_with_nulls() {
427 let decimal_dtype = DecimalDType::new(10, 2);
428 let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);
429
430 let array = array.as_view();
431 let casted = upcast_decimal_values(array, DecimalType::I64).unwrap();
432
433 assert_eq!(casted.values_type(), DecimalType::I64);
434 assert_eq!(casted.len(), 3);
435
436 let mask = casted
438 .as_ref()
439 .validity()
440 .unwrap()
441 .execute_mask(
442 casted.as_ref().len(),
443 &mut LEGACY_SESSION.create_execution_ctx(),
444 )
445 .unwrap();
446 assert!(mask.value(0));
447 assert!(!mask.value(1));
448 assert!(mask.value(2));
449
450 let buffer = casted.buffer::<i64>();
452 assert_eq!(buffer[0], 100);
453 assert_eq!(buffer[2], 300);
454 }
455
456 #[test]
457 fn upcast_decimal_values_downcast_fails() {
458 let decimal_dtype = DecimalDType::new(18, 4);
459 let array = DecimalArray::new(
460 buffer![10000i64, 20000, 30000],
461 decimal_dtype,
462 Validity::NonNullable,
463 );
464
465 let array = array.as_view();
467 let result = upcast_decimal_values(array, DecimalType::I32);
468 assert!(result.is_err());
469 assert!(
470 result
471 .unwrap_err()
472 .to_string()
473 .contains("Cannot downcast decimal values")
474 );
475 }
476}