vortex_array/arrays/decimal/compute/
cast.rs1use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_bail;
8use vortex_error::vortex_panic;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Decimal;
15use crate::arrays::DecimalArray;
16use crate::dtype::DType;
17use crate::dtype::DecimalType;
18use crate::dtype::NativeDecimalType;
19use crate::match_each_decimal_value_type;
20use crate::scalar_fn::fns::cast::CastKernel;
21
22impl CastKernel for Decimal {
23 fn cast(
24 array: ArrayView<'_, Decimal>,
25 dtype: &DType,
26 _ctx: &mut ExecutionCtx,
27 ) -> VortexResult<Option<ArrayRef>> {
28 let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
30 return Ok(None);
31 };
32 let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
33 vortex_panic!(
34 "DecimalArray must have decimal dtype, got {:?}",
35 array.dtype()
36 );
37 };
38
39 if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
41 vortex_bail!(
42 "Casting decimal with scale {} to scale {} not yet implemented",
43 from_decimal_dtype.scale(),
44 to_decimal_dtype.scale()
45 );
46 }
47
48 if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
50 vortex_bail!(
51 "Downcasting decimal from precision {} to {} not yet implemented",
52 from_decimal_dtype.precision(),
53 to_decimal_dtype.precision()
54 );
55 }
56
57 if array.dtype() == dtype {
59 return Ok(Some(array.array().clone()));
60 }
61
62 let new_validity = array
64 .validity()?
65 .cast_nullability(*to_nullability, array.len())?;
66
67 let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
69 let array = if target_values_type > array.values_type() {
70 upcast_decimal_values(array, target_values_type)?
71 } else {
72 array.array().as_::<Decimal>().into_owned()
73 };
74
75 unsafe {
77 Ok(Some(
78 DecimalArray::new_unchecked_handle(
79 array.buffer_handle().clone(),
80 array.values_type(),
81 *to_decimal_dtype,
82 new_validity,
83 )
84 .into_array(),
85 ))
86 }
87 }
88}
89
90pub fn upcast_decimal_values(
102 array: ArrayView<'_, Decimal>,
103 to_values_type: DecimalType,
104) -> VortexResult<DecimalArray> {
105 let from_values_type = array.values_type();
106
107 if from_values_type == to_values_type {
109 return Ok(array.array().as_::<Decimal>().into_owned());
110 }
111
112 if to_values_type < from_values_type {
114 vortex_bail!(
115 "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
116 from_values_type,
117 to_values_type
118 );
119 }
120
121 let decimal_dtype = array.decimal_dtype();
122 let validity = array.validity()?;
123
124 match_each_decimal_value_type!(from_values_type, |F| {
126 let from_buffer = array.buffer::<F>();
127 match_each_decimal_value_type!(to_values_type, |T| {
128 let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
129 Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
130 })
131 })
132}
133
134fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
137 from.iter()
138 .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
139 .collect()
140}
141
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::DecimalType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;

    /// Non-nullable `i32`-backed array holding `[100, 200, 300]`.
    fn i32_sample(decimal_dtype: DecimalDType) -> DecimalArray {
        DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        )
    }

    /// Non-nullable `i32`-backed array holding the single value `100`.
    fn single_i32(decimal_dtype: DecimalDType) -> DecimalArray {
        DecimalArray::new(buffer![100i32], decimal_dtype, Validity::NonNullable)
    }

    /// Non-nullable `i64`-backed array holding `[10000, 20000, 30000]`.
    fn i64_sample(decimal_dtype: DecimalDType) -> DecimalArray {
        DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        )
    }

    /// `i32`-backed array `[100, null, 300]`.
    fn i32_sample_with_null(decimal_dtype: DecimalDType) -> DecimalArray {
        DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype)
    }

    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let target = DType::Decimal(decimal_dtype, Nullability::Nullable);

        // Non-nullable -> nullable must succeed and report every row as valid.
        let casted = i32_sample(decimal_dtype)
            .into_array()
            .cast(target.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &target);
        assert!(matches!(casted.validity(), Ok(Validity::AllValid)));
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let target = DType::Decimal(decimal_dtype, Nullability::NonNullable);

        // A nullable array without any nulls can drop its nullability.
        let source = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);
        let casted = source
            .into_array()
            .cast(target.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &target);
        assert!(matches!(casted.validity(), Ok(Validity::NonNullable)));
    }

    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let target = DType::Decimal(decimal_dtype, Nullability::NonNullable);

        // An actual null cannot be represented in a non-nullable result.
        i32_sample_with_null(decimal_dtype)
            .into_array()
            .cast(target)
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()))
            .unwrap();
    }

    #[test]
    fn cast_different_scale_fails() {
        let source = single_i32(DecimalDType::new(10, 2));
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let result = source
            .into_array()
            .cast(target)
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Casting decimal with scale 2 to scale 3 not yet implemented"));
    }

    #[test]
    fn cast_downcast_precision_fails() {
        let source = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );
        let target = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);

        let result = source
            .into_array()
            .cast(target)
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Downcasting decimal from precision 18 to 10 not yet implemented"));
    }

    #[test]
    fn cast_upcast_precision_succeeds() {
        let source = i32_sample(DecimalDType::new(10, 2));
        let target = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);

        let casted = source.into_array().cast(target).unwrap().to_decimal();

        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        // Precision 38 is expected to be stored as 128-bit values.
        assert_eq!(casted.values_type(), DecimalType::I128);
    }

    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = single_i32(DecimalDType::new(10, 2));

        // The decimal kernel declines non-decimal targets, so the overall cast must fail.
        let result = source
            .into_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("No CastKernel to cast canonical array"));
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        let array_ref = array.into_array();
        test_cast_conformance(&array_ref);
    }

    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let source = i32_sample(decimal_dtype);

        assert_eq!(source.values_type(), DecimalType::I32);

        let widened = upcast_decimal_values(source.as_view(), DecimalType::I64).unwrap();

        assert_eq!(widened.values_type(), DecimalType::I64);
        assert_eq!(widened.decimal_dtype(), decimal_dtype);
        assert_eq!(widened.len(), 3);

        // The numeric values must survive the widening untouched.
        let values = widened.buffer::<i64>();
        assert_eq!(values.as_ref(), &[100i64, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let source = i64_sample(decimal_dtype);

        let widened = upcast_decimal_values(source.as_view(), DecimalType::I128).unwrap();

        assert_eq!(widened.values_type(), DecimalType::I128);
        assert_eq!(widened.decimal_dtype(), decimal_dtype);

        let values = widened.buffer::<i128>();
        assert_eq!(values.as_ref(), &[10000i128, 20000, 30000]);
    }

    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let source = i32_sample(decimal_dtype);

        let same = upcast_decimal_values(source.as_view(), DecimalType::I32).unwrap();

        assert_eq!(same.values_type(), DecimalType::I32);
        assert_eq!(same.decimal_dtype(), decimal_dtype);
    }

    #[test]
    fn upcast_decimal_values_with_nulls() {
        let source = i32_sample_with_null(DecimalDType::new(10, 2));

        let widened = upcast_decimal_values(source.as_view(), DecimalType::I64).unwrap();

        assert_eq!(widened.values_type(), DecimalType::I64);
        assert_eq!(widened.len(), 3);

        // Validity must be carried over position by position.
        let mask = widened.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Only the valid slots are inspected; the null slot's value is not checked.
        let values = widened.buffer::<i64>();
        assert_eq!(values[0], 100);
        assert_eq!(values[2], 300);
    }

    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let source = i64_sample(DecimalDType::new(18, 4));

        // Asking for a narrower value type than the current one must be rejected.
        let result = upcast_decimal_values(source.as_view(), DecimalType::I32);

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Cannot downcast decimal values"));
    }
}