vortex_array/arrays/decimal/compute/
fill_null.rs1use std::cmp::max;
5use std::ops::Not;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use super::cast::upcast_decimal_values;
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::BoolArray;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::dtype::NativeDecimalType;
19use crate::match_each_decimal_value_type;
20use crate::scalar::DecimalValue;
21use crate::scalar::Scalar;
22use crate::scalar_fn::fns::fill_null::FillNullKernel;
23use crate::validity::Validity;
24use crate::vtable::ValidityHelper;
25
26impl FillNullKernel for DecimalVTable {
27 fn fill_null(
28 array: &DecimalArray,
29 fill_value: &Scalar,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let result_validity = Validity::from(fill_value.dtype().nullability());
33
34 Ok(Some(match array.validity() {
35 Validity::Array(is_valid) => {
36 let is_invalid = is_valid
37 .clone()
38 .execute::<BoolArray>(ctx)?
39 .to_bit_buffer()
40 .not();
41 let decimal_scalar = fill_value.as_decimal();
42 let decimal_value = decimal_scalar
43 .decimal_value()
44 .vortex_expect("fill_null requires a non-null fill value");
45 match_each_decimal_value_type!(array.values_type(), |T| {
46 fill_invalid_positions::<T>(
47 array,
48 &is_invalid,
49 &decimal_value,
50 result_validity,
51 )?
52 })
53 }
54 _ => unreachable!("checked in entry point"),
55 }))
56 }
57}
58
59fn fill_invalid_positions<T: NativeDecimalType>(
60 array: &DecimalArray,
61 is_invalid: &BitBuffer,
62 decimal_value: &DecimalValue,
63 result_validity: Validity,
64) -> VortexResult<ArrayRef> {
65 match decimal_value.cast::<T>() {
66 Some(fill_val) => fill_buffer::<T>(array, is_invalid, fill_val, result_validity),
67 None => {
68 let target = max(array.values_type(), decimal_value.decimal_type());
69 let upcasted = upcast_decimal_values(array, target)?;
70 match_each_decimal_value_type!(upcasted.values_type(), |U| {
71 fill_invalid_positions::<U>(&upcasted, is_invalid, decimal_value, result_validity)
72 })
73 }
74 }
75}
76
77fn fill_buffer<T: NativeDecimalType>(
78 array: &DecimalArray,
79 is_invalid: &BitBuffer,
80 fill_val: T,
81 result_validity: Validity,
82) -> VortexResult<ArrayRef> {
83 let mut buffer = array.buffer::<T>().into_mut();
84 for invalid_index in is_invalid.set_indices() {
85 buffer[invalid_index] = fill_val;
86 }
87 Ok(DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity).into_array())
88}
89
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Builds a non-nullable decimal fill scalar holding `value` (stored as `i128`).
    fn non_null_fill(value: i128, dtype: DecimalDType) -> Scalar {
        Scalar::decimal(DecimalValue::I128(value), dtype, Nullability::NonNullable)
    }

    /// Runs `fill_null` on `input` and canonicalizes the result back to a decimal array.
    fn fill_nulls(input: DecimalArray, fill: Scalar) -> DecimalArray {
        input.into_array().fill_null(fill).unwrap().to_decimal()
    }

    #[test]
    fn fill_null_leading_none() {
        let dtype = DecimalDType::new(19, 2);
        let input = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            dtype,
        );

        let filled = fill_nulls(input, non_null_fill(4200, dtype));

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], dtype)
        );
        assert_eq!(
            filled.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        assert!(filled.validity_mask().unwrap().all_true());
    }

    #[test]
    fn fill_null_all_none() {
        let dtype = DecimalDType::new(19, 2);
        let input = DecimalArray::from_option_iter(
            [Option::<i128>::None, None, None, None, None],
            dtype,
        );

        let filled = fill_nulls(input, non_null_fill(25500, dtype));

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], dtype)
        );
    }

    #[test]
    fn fill_null_overflow_upcasts() {
        // 200 does not fit in the array's i8 storage, so the kernel must widen the values.
        let dtype = DecimalDType::new(3, 0);
        let input = DecimalArray::from_option_iter([None, Some(10i8), None], dtype);

        let filled = fill_nulls(input, non_null_fill(200, dtype));

        assert_arrays_eq!(filled, DecimalArray::from_iter([200i16, 10, 200], dtype));
    }

    #[test]
    fn fill_null_non_nullable() {
        let dtype = DecimalDType::new(19, 2);
        let input = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            dtype,
            Validity::NonNullable,
        );

        let filled = fill_nulls(input, non_null_fill(25500, dtype));

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], dtype)
        );
    }
}