Skip to main content

vortex_array/arrays/decimal/compute/fill_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::max;
5use std::ops::Not;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10
11use super::cast::upcast_decimal_values;
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::BoolArray;
17use crate::arrays::Decimal;
18use crate::arrays::DecimalArray;
19use crate::dtype::NativeDecimalType;
20use crate::match_each_decimal_value_type;
21use crate::scalar::DecimalValue;
22use crate::scalar::Scalar;
23use crate::scalar_fn::fns::fill_null::FillNullKernel;
24use crate::validity::Validity;
25
26impl FillNullKernel for Decimal {
27    fn fill_null(
28        array: ArrayView<'_, Decimal>,
29        fill_value: &Scalar,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let result_validity = Validity::from(fill_value.dtype().nullability());
33
34        Ok(Some(match array.validity()? {
35            Validity::Array(is_valid) => {
36                let is_invalid = is_valid.execute::<BoolArray>(ctx)?.into_bit_buffer().not();
37                let decimal_scalar = fill_value.as_decimal();
38                let decimal_value = decimal_scalar
39                    .decimal_value()
40                    .vortex_expect("fill_null requires a non-null fill value");
41                match_each_decimal_value_type!(array.values_type(), |T| {
42                    fill_invalid_positions::<T>(
43                        array,
44                        &is_invalid,
45                        &decimal_value,
46                        result_validity,
47                    )?
48                })
49            }
50            _ => unreachable!("checked in entry point"),
51        }))
52    }
53}
54
55fn fill_invalid_positions<T: NativeDecimalType>(
56    array: ArrayView<'_, Decimal>,
57    is_invalid: &BitBuffer,
58    decimal_value: &DecimalValue,
59    result_validity: Validity,
60) -> VortexResult<ArrayRef> {
61    match decimal_value.cast::<T>() {
62        Some(fill_val) => fill_buffer::<T>(array, is_invalid, fill_val, result_validity),
63        None => {
64            let target = max(array.values_type(), decimal_value.decimal_type());
65            let upcasted = upcast_decimal_values(array, target)?;
66            match_each_decimal_value_type!(upcasted.values_type(), |U| {
67                let upcasted = upcasted.as_view();
68                fill_invalid_positions::<U>(upcasted, is_invalid, decimal_value, result_validity)
69            })
70        }
71    }
72}
73
74fn fill_buffer<T: NativeDecimalType>(
75    array: ArrayView<'_, Decimal>,
76    is_invalid: &BitBuffer,
77    fill_val: T,
78    result_validity: Validity,
79) -> VortexResult<ArrayRef> {
80    let mut buffer = array.buffer::<T>().into_mut();
81    for invalid_index in is_invalid.set_indices() {
82        buffer[invalid_index] = fill_val;
83    }
84    Ok(DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity).into_array())
85}
86
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Nulls at the start, middle and end are replaced, valid values are untouched, and the
    /// result's validity mask is all-true.
    #[test]
    fn fill_null_leading_none() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(19, 2);
        let arr = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            decimal_dtype,
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(4200i128),
                DecimalDType::new(19, 2),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], decimal_dtype),
            &mut ctx
        );
        assert_eq!(
            p.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        // Reuse the test's execution context instead of building a second one.
        assert!(
            p.as_ref()
                .validity()
                .unwrap()
                .execute_mask(p.as_ref().len(), &mut ctx)
                .unwrap()
                .all_true()
        );
    }

    /// An array whose slots are all null becomes a constant-valued array of the fill value.
    #[test]
    fn fill_null_all_none() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::from_option_iter(
            [Option::<i128>::None, None, None, None, None],
            decimal_dtype,
        );

        #[expect(deprecated)]
        let p = arr
            .into_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(25500i128),
                DecimalDType::new(19, 2),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], decimal_dtype),
            &mut ctx
        );
    }

    /// fill_null with a value that overflows the array's storage type should upcast the array.
    #[test]
    fn fill_null_overflow_upcasts() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(3, 0);
        let arr = DecimalArray::from_option_iter([None, Some(10i8), None], decimal_dtype);
        // i8 max is 127, so 200 doesn't fit in the array's i8 storage. The kernel widens to the
        // larger of the array's storage type and the fill value's storage type. That is
        // presumably i128 for the `DecimalValue::I128` fill used here, not the minimal i16.
        // The expectation below relies on `assert_arrays_eq!` comparing logical values rather
        // than the physical storage width.
        #[expect(deprecated)]
        let result = arr
            .into_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(200i128),
                DecimalDType::new(3, 0),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            result,
            DecimalArray::from_iter([200i16, 10, 200], decimal_dtype),
            &mut ctx
        );
    }

    /// A non-nullable input has nothing to fill and must come back unchanged.
    #[test]
    fn fill_null_non_nullable() {
        let mut ctx = array_session().create_execution_ctx();
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            decimal_dtype,
            Validity::NonNullable,
        );
        #[expect(deprecated)]
        let p = arr
            .into_array()
            .fill_null(Scalar::decimal(
                DecimalValue::I128(25500i128),
                DecimalDType::new(19, 2),
                Nullability::NonNullable,
            ))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], decimal_dtype),
            &mut ctx
        );
    }
}