Skip to main content

vortex_array/arrays/decimal/compute/
fill_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::max;
5use std::ops::Not;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use super::cast::upcast_decimal_values;
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::BoolArray;
16use crate::arrays::DecimalArray;
17use crate::arrays::DecimalVTable;
18use crate::dtype::NativeDecimalType;
19use crate::match_each_decimal_value_type;
20use crate::scalar::DecimalValue;
21use crate::scalar::Scalar;
22use crate::scalar_fn::fns::fill_null::FillNullKernel;
23use crate::validity::Validity;
24use crate::vtable::ValidityHelper;
25
26impl FillNullKernel for DecimalVTable {
27    fn fill_null(
28        array: &DecimalArray,
29        fill_value: &Scalar,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let result_validity = Validity::from(fill_value.dtype().nullability());
33
34        Ok(Some(match array.validity() {
35            Validity::Array(is_valid) => {
36                let is_invalid = is_valid
37                    .clone()
38                    .execute::<BoolArray>(ctx)?
39                    .to_bit_buffer()
40                    .not();
41                let decimal_scalar = fill_value.as_decimal();
42                let decimal_value = decimal_scalar
43                    .decimal_value()
44                    .vortex_expect("fill_null requires a non-null fill value");
45                match_each_decimal_value_type!(array.values_type(), |T| {
46                    fill_invalid_positions::<T>(
47                        array,
48                        &is_invalid,
49                        &decimal_value,
50                        result_validity,
51                    )?
52                })
53            }
54            _ => unreachable!("checked in entry point"),
55        }))
56    }
57}
58
59fn fill_invalid_positions<T: NativeDecimalType>(
60    array: &DecimalArray,
61    is_invalid: &BitBuffer,
62    decimal_value: &DecimalValue,
63    result_validity: Validity,
64) -> VortexResult<ArrayRef> {
65    match decimal_value.cast::<T>() {
66        Some(fill_val) => fill_buffer::<T>(array, is_invalid, fill_val, result_validity),
67        None => {
68            let target = max(array.values_type(), decimal_value.decimal_type());
69            let upcasted = upcast_decimal_values(array, target)?;
70            match_each_decimal_value_type!(upcasted.values_type(), |U| {
71                fill_invalid_positions::<U>(&upcasted, is_invalid, decimal_value, result_validity)
72            })
73        }
74    }
75}
76
77fn fill_buffer<T: NativeDecimalType>(
78    array: &DecimalArray,
79    is_invalid: &BitBuffer,
80    fill_val: T,
81    result_validity: Validity,
82) -> VortexResult<ArrayRef> {
83    let mut buffer = array.buffer::<T>().into_mut();
84    for invalid_index in is_invalid.set_indices() {
85        buffer[invalid_index] = fill_val;
86    }
87    Ok(DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity).into_array())
88}
89
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Leading, interior and trailing nulls are all replaced, and the result has no nulls.
    #[test]
    fn fill_null_leading_none() {
        let dtype = DecimalDType::new(19, 2);
        let fill = Scalar::decimal(
            DecimalValue::I128(4200i128),
            DecimalDType::new(19, 2),
            Nullability::NonNullable,
        );
        let input = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            dtype,
        );

        let filled = input.into_array().fill_null(fill).unwrap().to_decimal();

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], dtype)
        );
        assert_eq!(
            filled.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        assert!(filled.validity_mask().unwrap().all_true());
    }

    /// An array made entirely of nulls becomes a run of the fill value.
    #[test]
    fn fill_null_all_none() {
        let dtype = DecimalDType::new(19, 2);
        let fill = Scalar::decimal(
            DecimalValue::I128(25500i128),
            DecimalDType::new(19, 2),
            Nullability::NonNullable,
        );
        let all_null = DecimalArray::from_option_iter([Option::<i128>::None; 5], dtype);

        let filled = all_null.into_array().fill_null(fill).unwrap().to_decimal();

        assert_arrays_eq!(filled, DecimalArray::from_iter([25500; 5], dtype));
    }

    /// A fill value that overflows the array's storage type forces the storage to be widened.
    #[test]
    fn fill_null_overflow_upcasts() {
        let dtype = DecimalDType::new(3, 0);
        let narrow = DecimalArray::from_option_iter([None, Some(10i8), None], dtype);
        // 200 exceeds i8::MAX (127), so the filled values cannot remain in i8 storage.
        let fill = Scalar::decimal(
            DecimalValue::I128(200i128),
            DecimalDType::new(3, 0),
            Nullability::NonNullable,
        );

        let widened = narrow.into_array().fill_null(fill).unwrap().to_decimal();

        assert_arrays_eq!(widened, DecimalArray::from_iter([200i16, 10, 200], dtype));
    }

    /// A non-nullable array has no nulls to replace, so its values come back unchanged.
    #[test]
    fn fill_null_non_nullable() {
        let dtype = DecimalDType::new(19, 2);
        let values = buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128];
        let non_nullable = DecimalArray::new(values, dtype, Validity::NonNullable);
        let fill = Scalar::decimal(
            DecimalValue::I128(25500i128),
            DecimalDType::new(19, 2),
            Nullability::NonNullable,
        );

        let result = non_nullable.into_array().fill_null(fill).unwrap().to_decimal();

        assert_arrays_eq!(
            result,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], dtype)
        );
    }
}