vortex_array/arrays/decimal/compute/
cast.rs1use vortex_buffer::Buffer;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7use vortex_error::vortex_bail;
8use vortex_error::vortex_panic;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::DecimalArray;
14use crate::arrays::DecimalVTable;
15use crate::dtype::DType;
16use crate::dtype::DecimalType;
17use crate::dtype::NativeDecimalType;
18use crate::match_each_decimal_value_type;
19use crate::scalar_fn::fns::cast::CastKernel;
20use crate::vtable::ValidityHelper;
21
22impl CastKernel for DecimalVTable {
23 fn cast(
24 array: &DecimalArray,
25 dtype: &DType,
26 _ctx: &mut ExecutionCtx,
27 ) -> VortexResult<Option<ArrayRef>> {
28 let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
30 return Ok(None);
31 };
32 let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
33 vortex_panic!(
34 "DecimalArray must have decimal dtype, got {:?}",
35 array.dtype()
36 );
37 };
38
39 if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
41 vortex_bail!(
42 "Casting decimal with scale {} to scale {} not yet implemented",
43 from_decimal_dtype.scale(),
44 to_decimal_dtype.scale()
45 );
46 }
47
48 if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
50 vortex_bail!(
51 "Downcasting decimal from precision {} to {} not yet implemented",
52 from_decimal_dtype.precision(),
53 to_decimal_dtype.precision()
54 );
55 }
56
57 if array.dtype() == dtype {
59 return Ok(Some(array.clone().into_array()));
60 }
61
62 let new_validity = array
64 .validity()
65 .clone()
66 .cast_nullability(*to_nullability, array.len())?;
67
68 let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
70 let array = if target_values_type > array.values_type() {
71 upcast_decimal_values(array, target_values_type)?
72 } else {
73 array.clone()
74 };
75
76 unsafe {
78 Ok(Some(
79 DecimalArray::new_unchecked_handle(
80 array.buffer_handle().clone(),
81 array.values_type(),
82 *to_decimal_dtype,
83 new_validity,
84 )
85 .into_array(),
86 ))
87 }
88 }
89}
90
91pub fn upcast_decimal_values(
103 array: &DecimalArray,
104 to_values_type: DecimalType,
105) -> VortexResult<DecimalArray> {
106 let from_values_type = array.values_type();
107
108 if from_values_type == to_values_type {
110 return Ok(array.clone());
111 }
112
113 if to_values_type < from_values_type {
115 vortex_bail!(
116 "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
117 from_values_type,
118 to_values_type
119 );
120 }
121
122 let decimal_dtype = array.decimal_dtype();
123 let validity = array.validity().clone();
124
125 match_each_decimal_value_type!(from_values_type, |F| {
127 let from_buffer = array.buffer::<F>();
128 match_each_decimal_value_type!(to_values_type, |T| {
129 let to_buffer = upcast_decimal_buffer::<F, T>(from_buffer);
130 Ok(DecimalArray::new(to_buffer, decimal_dtype, validity))
131 })
132 })
133}
134
135fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
138 from.iter()
139 .map(|&v| T::from(v).vortex_expect("upcast should never fail"))
140 .collect()
141}
142
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::DecimalType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let casted = array
            .into_array()
            .cast(nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert_eq!(casted.validity(), &Validity::AllValid);
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = array
            .into_array()
            .cast(non_nullable_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
    }

    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        // Canonicalize so that a lazily-applied cast is actually executed.
        array
            .into_array()
            .cast(non_nullable_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap();
    }

    #[test]
    fn cast_different_scale_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let result = array
            .into_array()
            .cast(different_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Casting decimal with scale 2 to scale 3 not yet implemented")
        );
    }

    #[test]
    fn cast_downcast_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );

        let smaller_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        let result = array
            .into_array()
            .cast(smaller_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Downcasting decimal from precision 18 to 10 not yet implemented")
        );
    }

    #[test]
    fn cast_upcast_precision_succeeds() {
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        let casted = array.into_array().cast(wider_dtype).unwrap().to_decimal();

        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        // Precision 38 is stored as i128, so the i32 values must have been widened...
        assert_eq!(casted.values_type(), DecimalType::I128);
        // ...and widening must preserve the unscaled values themselves.
        assert_eq!(casted.buffer::<i128>().as_ref(), &[100i128, 200, 300]);
    }

    #[test]
    fn cast_upcast_precision_to_nullable_preserves_values_and_nulls() {
        let array =
            DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2));

        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::Nullable);
        let casted = array
            .into_array()
            .cast(wider_dtype.clone())
            .unwrap()
            .to_decimal();

        assert_eq!(casted.dtype(), &wider_dtype);
        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.len(), 3);

        // The null must survive the value widening.
        let mask = casted.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        // The decimal kernel declines non-decimal targets, so no kernel handles this cast.
        let result = array
            .into_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No CastKernel to cast canonical array")
        );
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        assert_eq!(array.values_type(), DecimalType::I32);

        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
        assert_eq!(casted.len(), 3);

        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer.as_ref(), &[100i64, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let casted = upcast_decimal_values(&array, DecimalType::I128).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);

        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[10000i128, 20000, 30000]);
    }

    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let casted = upcast_decimal_values(&array, DecimalType::I32).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I32);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
    }

    #[test]
    fn upcast_decimal_values_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();

        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.len(), 3);

        let mask = casted.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));

        // Only the valid positions are checked; the value under a null slot is unspecified.
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );

        let result = upcast_decimal_values(&array, DecimalType::I32);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot downcast decimal values")
        );
    }
}