vortex_array/arrays/decimal/compute/
cast.rs1use vortex_dtype::DType;
5use vortex_error::{VortexResult, vortex_bail, vortex_panic};
6
7use crate::arrays::{DecimalArray, DecimalVTable};
8use crate::compute::{CastKernel, CastKernelAdapter};
9use crate::stats::ArrayStats;
10use crate::vtable::ValidityHelper;
11use crate::{ArrayRef, register_kernel};
12
13impl CastKernel for DecimalVTable {
14 fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15 let DType::Decimal(to_precision_scale, to_nullability) = dtype else {
17 return Ok(None);
18 };
19 let DType::Decimal(from_precision_scale, _) = array.dtype() else {
20 vortex_panic!(
21 "DecimalArray must have decimal dtype, got {:?}",
22 array.dtype()
23 );
24 };
25
26 if from_precision_scale != to_precision_scale {
28 vortex_bail!(
29 "Cannot cast decimal({},{}) to decimal({},{})",
30 from_precision_scale.precision(),
31 from_precision_scale.scale(),
32 to_precision_scale.precision(),
33 to_precision_scale.scale()
34 );
35 }
36
37 if array.dtype() == dtype {
39 return Ok(Some(array.to_array()));
40 }
41
42 let new_validity = array.validity().clone().cast_nullability(*to_nullability)?;
44
45 Ok(Some(
47 DecimalArray {
48 dtype: DType::Decimal(*from_precision_scale, *to_nullability),
49 values: array.byte_buffer(),
50 values_type: array.values_type(),
51 validity: new_validity,
52 stats_set: ArrayStats::default(),
53 }
54 .to_array(),
55 ))
56 }
57}
58
59register_kernel!(CastKernelAdapter(DecimalVTable).lift());
60
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    // Widening nullability keeps the values and turns `NonNullable` validity into `AllValid`.
    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let casted = cast(array.as_ref(), &nullable_dtype)
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(casted.dtype(), &nullable_dtype);
        assert_eq!(casted.validity(), &Validity::AllValid);
        assert_eq!(casted.len(), 3);
    }

    // Narrowing nullability succeeds when no value is null.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);

        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);

        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = cast(array.as_ref(), &non_nullable_dtype)
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
    }

    // Narrowing nullability must fail when the array actually contains nulls.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);

        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);

        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        cast(array.as_ref(), &non_nullable_dtype).unwrap();
    }

    // Casting to exactly the array's own dtype is a no-op that must succeed unchanged.
    #[test]
    fn cast_to_same_dtype_is_noop() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );

        let same_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = cast(array.as_ref(), &same_dtype)
            .unwrap()
            .to_decimal()
            .unwrap();

        assert_eq!(casted.dtype(), &same_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
        assert_eq!(casted.len(), 3);
    }

    // Precision and scale both differ: the kernel rejects the cast.
    #[test]
    fn cast_different_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let result = cast(array.as_ref(), &different_dtype);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot cast decimal(10,2) to decimal(15,3)")
        );
    }

    // Only the scale differs: this must be rejected just like a precision change.
    #[test]
    fn cast_different_scale_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let different_dtype = DType::Decimal(DecimalDType::new(10, 3), Nullability::NonNullable);
        let result = cast(array.as_ref(), &different_dtype);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot cast decimal(10,2) to decimal(10,3)")
        );
    }

    // A non-decimal target is declined by this kernel; no other kernel handles it here, so
    // the cast reports that no kernel was found.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );

        let result = cast(array.as_ref(), &DType::Utf8(Nullability::NonNullable));

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No compute kernel to cast")
        );
    }
}