Skip to main content

vortex_array/arrays/decimal/compute/
fill_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use std::cmp::max;
use std::ops::Not;

use vortex_buffer::BitBuffer;
use vortex_dtype::NativeDecimalType;
use vortex_dtype::match_each_decimal_value_type;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;

use super::cast::upcast_decimal_values;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::DecimalVTable;
use crate::arrays::decimal::DecimalArray;
use crate::expr::FillNullKernel;
use crate::scalar::DecimalValue;
use crate::scalar::Scalar;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
25
26impl FillNullKernel for DecimalVTable {
27    fn fill_null(
28        array: &DecimalArray,
29        fill_value: &Scalar,
30        _ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let result_validity = Validity::from(fill_value.dtype().nullability());
33
34        Ok(Some(match array.validity() {
35            Validity::Array(is_valid) => {
36                let is_invalid = is_valid.to_bool().to_bit_buffer().not();
37                let decimal_scalar = fill_value.as_decimal();
38                let decimal_value = decimal_scalar
39                    .decimal_value()
40                    .vortex_expect("fill_null requires a non-null fill value");
41                match_each_decimal_value_type!(array.values_type(), |T| {
42                    fill_invalid_positions::<T>(
43                        array,
44                        &is_invalid,
45                        &decimal_value,
46                        result_validity,
47                    )?
48                })
49            }
50            _ => unreachable!("checked in entry point"),
51        }))
52    }
53}
54
55fn fill_invalid_positions<T: NativeDecimalType>(
56    array: &DecimalArray,
57    is_invalid: &BitBuffer,
58    decimal_value: &DecimalValue,
59    result_validity: Validity,
60) -> VortexResult<ArrayRef> {
61    match decimal_value.cast::<T>() {
62        Some(fill_val) => fill_buffer::<T>(array, is_invalid, fill_val, result_validity),
63        None => {
64            let target = max(array.values_type(), decimal_value.decimal_type());
65            let upcasted = upcast_decimal_values(array, target)?;
66            match_each_decimal_value_type!(upcasted.values_type(), |U| {
67                fill_invalid_positions::<U>(&upcasted, is_invalid, decimal_value, result_validity)
68            })
69        }
70    }
71}
72
73fn fill_buffer<T: NativeDecimalType>(
74    array: &DecimalArray,
75    is_invalid: &BitBuffer,
76    fill_val: T,
77    result_validity: Validity,
78) -> VortexResult<ArrayRef> {
79    let mut buffer = array.buffer::<T>().into_mut();
80    for invalid_index in is_invalid.set_indices() {
81        buffer[invalid_index] = fill_val;
82    }
83    Ok(DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity).into_array())
84}
85
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;

    use crate::arrays::decimal::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn fill_null_leading_none() {
        let decimal_dtype = DecimalDType::new(19, 2);
        let arr = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            decimal_dtype,
        );
        let p = arr
            .to_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(4200i128),
                DecimalDType::new(19, 2),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], decimal_dtype)
        );
        assert_eq!(
            p.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        assert!(p.validity_mask().unwrap().all_true());
    }

    #[test]
    fn fill_null_all_none() {
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::from_option_iter(
            [Option::<i128>::None, None, None, None, None],
            decimal_dtype,
        );

        let p = arr
            .to_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(25500i128),
                DecimalDType::new(19, 2),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], decimal_dtype)
        );
    }

    /// fill_null with a value that overflows the array's storage type should upcast the array to
    /// a storage type that can hold the fill value.
    #[test]
    fn fill_null_overflow_upcasts() {
        let decimal_dtype = DecimalDType::new(3, 0);
        let arr = DecimalArray::from_option_iter([None, Some(10i8), None], decimal_dtype);
        // i8 max is 127, so 200 doesn't fit: the array must be widened before filling.
        // NOTE(review): the kernel widens to `max(i8, fill value's storage type)`. For an `I128`
        // fill value that may be wider than i16, so only the resulting values are asserted
        // here. Confirm which storage type is intended.
        let result = arr
            .to_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(200i128),
                DecimalDType::new(3, 0),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            result,
            DecimalArray::from_iter([200i16, 10, 200], decimal_dtype)
        );
    }

    /// fill_null with a value that fits the array's storage type should fill in place without
    /// widening the storage.
    #[test]
    fn fill_null_fitting_value_keeps_storage_type() {
        let decimal_dtype = DecimalDType::new(3, 0);
        let arr = DecimalArray::from_option_iter([None, Some(10i8), None], decimal_dtype);
        let result = arr
            .to_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I8(5i8),
                DecimalDType::new(3, 0),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            result,
            DecimalArray::from_iter([5i8, 10, 5], decimal_dtype)
        );
        assert_eq!(result.buffer::<i8>().as_slice(), vec![5i8, 10, 5]);
    }

    #[test]
    fn fill_null_non_nullable() {
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            decimal_dtype,
            Validity::NonNullable,
        );
        let p = arr
            .to_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(25500i128),
                DecimalDType::new(19, 2),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], decimal_dtype)
        );
    }
}