// vortex_array/arrays/decimal/compute/fill_null.rs
use std::cmp::max;
use std::ops::Not;

use vortex_buffer::BitBuffer;
use vortex_dtype::NativeDecimalType;
use vortex_dtype::match_each_decimal_value_type;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use super::cast::upcast_decimal_values;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::DecimalVTable;
use crate::arrays::decimal::DecimalArray;
use crate::expr::FillNullKernel;
use crate::scalar::DecimalValue;
use crate::scalar::Scalar;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;

26impl FillNullKernel for DecimalVTable {
27 fn fill_null(
28 array: &DecimalArray,
29 fill_value: &Scalar,
30 _ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let result_validity = Validity::from(fill_value.dtype().nullability());
33
34 Ok(Some(match array.validity() {
35 Validity::Array(is_valid) => {
36 let is_invalid = is_valid.to_bool().to_bit_buffer().not();
37 let decimal_scalar = fill_value.as_decimal();
38 let decimal_value = decimal_scalar
39 .decimal_value()
40 .vortex_expect("fill_null requires a non-null fill value");
41 match_each_decimal_value_type!(array.values_type(), |T| {
42 fill_invalid_positions::<T>(
43 array,
44 &is_invalid,
45 &decimal_value,
46 result_validity,
47 )?
48 })
49 }
50 _ => unreachable!("checked in entry point"),
51 }))
52 }
53}
54
55fn fill_invalid_positions<T: NativeDecimalType>(
56 array: &DecimalArray,
57 is_invalid: &BitBuffer,
58 decimal_value: &DecimalValue,
59 result_validity: Validity,
60) -> VortexResult<ArrayRef> {
61 match decimal_value.cast::<T>() {
62 Some(fill_val) => fill_buffer::<T>(array, is_invalid, fill_val, result_validity),
63 None => {
64 let target = max(array.values_type(), decimal_value.decimal_type());
65 let upcasted = upcast_decimal_values(array, target)?;
66 match_each_decimal_value_type!(upcasted.values_type(), |U| {
67 fill_invalid_positions::<U>(&upcasted, is_invalid, decimal_value, result_validity)
68 })
69 }
70 }
71}
72
73fn fill_buffer<T: NativeDecimalType>(
74 array: &DecimalArray,
75 is_invalid: &BitBuffer,
76 fill_val: T,
77 result_validity: Validity,
78) -> VortexResult<ArrayRef> {
79 let mut buffer = array.buffer::<T>().into_mut();
80 for invalid_index in is_invalid.set_indices() {
81 buffer[invalid_index] = fill_val;
82 }
83 Ok(DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity).into_array())
84}
85
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;

    use crate::arrays::decimal::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Runs `fill_null` on `arr` with a non-nullable decimal fill scalar built from `value`
    /// and `fill_dtype`, then canonicalizes the result back into a decimal array.
    fn fill_with(arr: DecimalArray, value: DecimalValue, fill_dtype: DecimalDType) -> DecimalArray {
        arr.to_array()
            .fill_null(Scalar::decimal(value, fill_dtype, Nullability::NonNullable))
            .unwrap()
            .to_decimal()
    }

    #[test]
    fn fill_null_leading_none() {
        let decimal_dtype = DecimalDType::new(19, 2);
        let input = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            decimal_dtype,
        );

        let filled = fill_with(input, DecimalValue::I128(4200i128), DecimalDType::new(19, 2));

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], decimal_dtype)
        );
        assert_eq!(
            filled.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        assert!(filled.validity_mask().unwrap().all_true());
    }

    #[test]
    fn fill_null_all_none() {
        let decimal_dtype = DecimalDType::new(19, 2);
        let input = DecimalArray::from_option_iter(
            [Option::<i128>::None, None, None, None, None],
            decimal_dtype,
        );

        let filled = fill_with(input, DecimalValue::I128(25500i128), DecimalDType::new(19, 2));

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], decimal_dtype)
        );
    }

    #[test]
    fn fill_null_overflow_upcasts() {
        // 200 does not fit in the i8 storage, so the kernel has to widen the values.
        let decimal_dtype = DecimalDType::new(3, 0);
        let input = DecimalArray::from_option_iter([None, Some(10i8), None], decimal_dtype);

        let filled = fill_with(input, DecimalValue::I128(200i128), DecimalDType::new(3, 0));

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([200i16, 10, 200], decimal_dtype)
        );
    }

    #[test]
    fn fill_null_non_nullable() {
        let decimal_dtype = DecimalDType::new(19, 2);
        let input = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            decimal_dtype,
            Validity::NonNullable,
        );

        let filled = fill_with(input, DecimalValue::I128(25500i128), DecimalDType::new(19, 2));

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], decimal_dtype)
        );
    }
}