vortex_array/arrays/decimal/compute/
fill_null.rs1use std::cmp::max;
5use std::ops::Not;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use super::cast::upcast_decimal_values;
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::BoolArray;
17use crate::arrays::Decimal;
18use crate::arrays::DecimalArray;
19use crate::dtype::NativeDecimalType;
20use crate::match_each_decimal_value_type;
21use crate::scalar::DecimalValue;
22use crate::scalar::Scalar;
23use crate::scalar_fn::fns::fill_null::FillNullKernel;
24use crate::validity::Validity;
25
26impl FillNullKernel for Decimal {
27 fn fill_null(
28 array: ArrayView<'_, Decimal>,
29 fill_value: &Scalar,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let result_validity = Validity::from(fill_value.dtype().nullability());
33
34 Ok(Some(match array.validity()? {
35 Validity::Array(is_valid) => {
36 let is_invalid = is_valid.execute::<BoolArray>(ctx)?.into_bit_buffer().not();
37 let decimal_scalar = fill_value.as_decimal();
38 let decimal_value = decimal_scalar
39 .decimal_value()
40 .vortex_expect("fill_null requires a non-null fill value");
41 match_each_decimal_value_type!(array.values_type(), |T| {
42 fill_invalid_positions::<T>(
43 array,
44 &is_invalid,
45 &decimal_value,
46 result_validity,
47 )?
48 })
49 }
50 _ => unreachable!("checked in entry point"),
51 }))
52 }
53}
54
55fn fill_invalid_positions<T: NativeDecimalType>(
56 array: ArrayView<'_, Decimal>,
57 is_invalid: &BitBuffer,
58 decimal_value: &DecimalValue,
59 result_validity: Validity,
60) -> VortexResult<ArrayRef> {
61 match decimal_value.cast::<T>() {
62 Some(fill_val) => fill_buffer::<T>(array, is_invalid, fill_val, result_validity),
63 None => {
64 let target = max(array.values_type(), decimal_value.decimal_type());
65 let upcasted = upcast_decimal_values(array, target)?;
66 match_each_decimal_value_type!(upcasted.values_type(), |U| {
67 let upcasted = upcasted.as_view();
68 fill_invalid_positions::<U>(upcasted, is_invalid, decimal_value, result_validity)
69 })
70 }
71 }
72}
73
74fn fill_buffer<T: NativeDecimalType>(
75 array: ArrayView<'_, Decimal>,
76 is_invalid: &BitBuffer,
77 fill_val: T,
78 result_validity: Validity,
79) -> VortexResult<ArrayRef> {
80 let mut buffer = array.buffer::<T>().into_mut();
81 for invalid_index in is_invalid.set_indices() {
82 buffer[invalid_index] = fill_val;
83 }
84 Ok(DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity).into_array())
85}
86
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Builds a non-nullable decimal fill scalar whose value is stored as `i128`.
    fn i128_fill(value: i128, decimal_dtype: DecimalDType) -> Scalar {
        Scalar::decimal(
            DecimalValue::I128(value),
            decimal_dtype,
            Nullability::NonNullable,
        )
    }

    #[test]
    fn fill_null_leading_none() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(19, 2);
        let arr = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            decimal_dtype,
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .fill_null(i128_fill(4200, decimal_dtype))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], decimal_dtype),
            &mut ctx
        );
        assert_eq!(
            p.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        // Reuse the test's execution context instead of creating a second session.
        let validity_mask = p
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(p.as_ref().len(), &mut ctx)
            .unwrap();
        assert!(validity_mask.all_true());
    }

    #[test]
    fn fill_null_all_none() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::from_option_iter(
            [Option::<i128>::None, None, None, None, None],
            decimal_dtype,
        );

        #[expect(deprecated)]
        let p = arr
            .into_array()
            .fill_null(i128_fill(25500, decimal_dtype))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], decimal_dtype),
            &mut ctx
        );
    }

    #[test]
    fn fill_null_overflow_upcasts() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(3, 0);
        // 200 does not fit in the array's i8 values, so the values must be widened first.
        let arr = DecimalArray::from_option_iter([None, Some(10i8), None], decimal_dtype);
        #[expect(deprecated)]
        let result = arr
            .into_array()
            .fill_null(i128_fill(200, decimal_dtype))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            result,
            DecimalArray::from_iter([200i16, 10, 200], decimal_dtype),
            &mut ctx
        );
    }

    #[test]
    fn fill_null_wide_fill_value_fits_narrow_values() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(3, 0);
        // The fill scalar is stored as i128, but its value (42) is representable as i8.
        let arr = DecimalArray::from_option_iter([None, Some(10i8), None], decimal_dtype);
        #[expect(deprecated)]
        let result = arr
            .into_array()
            .fill_null(i128_fill(42, decimal_dtype))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            result,
            DecimalArray::from_iter([42i8, 10, 42], decimal_dtype),
            &mut ctx
        );
    }

    #[test]
    fn fill_null_non_nullable() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            decimal_dtype,
            Validity::NonNullable,
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .fill_null(i128_fill(25500, decimal_dtype))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], decimal_dtype),
            &mut ctx
        );
    }
}