vortex_array/arrays/decimal/compute/fill_null.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::ops::Not;

use vortex_dtype::match_each_decimal_value_type;
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
use vortex_scalar::Scalar;

use crate::arrays::DecimalVTable;
use crate::arrays::decimal::DecimalArray;
use crate::compute::{FillNullKernel, FillNullKernelAdapter};
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
use crate::{ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl FillNullKernel for DecimalVTable {
18    fn fill_null(&self, array: &DecimalArray, fill_value: &Scalar) -> VortexResult<ArrayRef> {
19        let result_validity = Validity::from(fill_value.dtype().nullability());
20
21        Ok(match array.validity() {
22            Validity::Array(is_valid) => {
23                let is_invalid = is_valid.to_bool().bit_buffer().not();
24                match_each_decimal_value_type!(array.values_type(), |T| {
25                    let mut buffer = array.buffer::<T>().into_mut();
26                    let fill_value = fill_value
27                        .as_decimal()
28                        .decimal_value()
29                        .and_then(|v| v.cast::<T>())
30                        .vortex_expect("top-level fill_null ensure non-null fill value");
31                    for invalid_index in is_invalid.set_indices() {
32                        buffer[invalid_index] = fill_value;
33                    }
34                    DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity)
35                        .into_array()
36                })
37            }
38            _ => unreachable!("checked in entry point"),
39        })
40    }
41}
42
43register_kernel!(FillNullKernelAdapter(DecimalVTable).lift());
44
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DecimalDType, Nullability};
    use vortex_scalar::{DecimalValue, Scalar};

    use crate::arrays::decimal::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::canonical::ToCanonical;
    use crate::compute::fill_null;
    use crate::validity::Validity;

    /// Builds a `Decimal(19, 2)` fill scalar backed by an `i128` payload.
    fn decimal_fill(value: i128, nullability: Nullability) -> Scalar {
        Scalar::decimal(
            DecimalValue::I128(value),
            DecimalDType::new(19, 2),
            nullability,
        )
    }

    #[test]
    fn fill_null_leading_none() {
        let decimal_dtype = DecimalDType::new(19, 2);
        let arr = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            decimal_dtype,
        );
        let p = fill_null(arr.as_ref(), &decimal_fill(4200, Nullability::NonNullable))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], decimal_dtype)
        );
        assert_eq!(
            p.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        assert!(p.validity_mask().all_true());
    }

    #[test]
    fn fill_null_all_none() {
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::from_option_iter(
            [Option::<i128>::None, None, None, None, None],
            decimal_dtype,
        );

        let p = fill_null(arr.as_ref(), &decimal_fill(25500, Nullability::NonNullable))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], decimal_dtype)
        );
    }

    #[test]
    fn fill_null_non_nullable() {
        let decimal_dtype = DecimalDType::new(19, 2);

        let arr = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            decimal_dtype,
            Validity::NonNullable,
        );
        let p = fill_null(arr.as_ref(), &decimal_fill(25500, Nullability::NonNullable))
            .unwrap()
            .to_decimal();
        assert_arrays_eq!(
            p,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], decimal_dtype)
        );
    }

    #[test]
    fn fill_null_nullable_fill_value() {
        let decimal_dtype = DecimalDType::new(19, 2);
        let arr = DecimalArray::from_option_iter(
            [Some(800i128), None, Some(1000i128), None],
            decimal_dtype,
        );
        let p = fill_null(arr.as_ref(), &decimal_fill(4200, Nullability::Nullable))
            .unwrap()
            .to_decimal();
        // A nullable fill value should yield a nullable result in which every slot is valid.
        assert_arrays_eq!(
            p,
            DecimalArray::new(
                buffer![800i128, 4200, 1000, 4200],
                decimal_dtype,
                Validity::AllValid
            )
        );
        assert!(p.validity_mask().all_true());
    }
}