vortex_array/arrays/decimal/compute/fill_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Not;
5
6use vortex_dtype::match_each_decimal_value_type;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_scalar::Scalar;
10
11use crate::ArrayRef;
12use crate::IntoArray;
13use crate::ToCanonical;
14use crate::arrays::DecimalVTable;
15use crate::arrays::decimal::DecimalArray;
16use crate::compute::FillNullKernel;
17use crate::compute::FillNullKernelAdapter;
18use crate::register_kernel;
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21
22impl FillNullKernel for DecimalVTable {
23    fn fill_null(&self, array: &DecimalArray, fill_value: &Scalar) -> VortexResult<ArrayRef> {
24        let result_validity = Validity::from(fill_value.dtype().nullability());
25
26        Ok(match array.validity() {
27            Validity::Array(is_valid) => {
28                let is_invalid = is_valid.to_bool().bit_buffer().not();
29                match_each_decimal_value_type!(array.values_type(), |T| {
30                    let mut buffer = array.buffer::<T>().into_mut();
31                    let fill_value = fill_value
32                        .as_decimal()
33                        .decimal_value()
34                        .and_then(|v| v.cast::<T>())
35                        .vortex_expect("top-level fill_null ensure non-null fill value");
36                    for invalid_index in is_invalid.set_indices() {
37                        buffer[invalid_index] = fill_value;
38                    }
39                    DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity)
40                        .into_array()
41                })
42            }
43            _ => unreachable!("checked in entry point"),
44        })
45    }
46}
47
48register_kernel!(FillNullKernelAdapter(DecimalVTable).lift());
49
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;

    use crate::arrays::decimal::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::canonical::ToCanonical;
    use crate::compute::fill_null;
    use crate::validity::Validity;

    /// Decimal type used throughout these tests: precision 19, scale 2.
    fn dtype() -> DecimalDType {
        DecimalDType::new(19, 2)
    }

    /// Non-nullable `Decimal(19, 2)` scalar wrapping `raw` as an `i128` value.
    fn fill_scalar(raw: i128) -> Scalar {
        Scalar::decimal(DecimalValue::I128(raw), dtype(), Nullability::NonNullable)
    }

    /// Runs `fill_null` on `input` with `raw` as the fill value and canonicalizes the
    /// result back into a `DecimalArray`.
    fn run_fill_null(input: &DecimalArray, raw: i128) -> DecimalArray {
        fill_null(input.as_ref(), &fill_scalar(raw))
            .unwrap()
            .to_decimal()
    }

    #[test]
    fn fill_null_leading_none() {
        let input = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            dtype(),
        );

        let filled = run_fill_null(&input, 4200);

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], dtype())
        );
        let raw_values = filled.buffer::<i128>();
        assert_eq!(raw_values.as_slice(), vec![4200, 800, 4200, 1000, 4200]);
        assert!(filled.validity_mask().all_true());
    }

    #[test]
    fn fill_null_all_none() {
        let input = DecimalArray::from_option_iter([Option::<i128>::None; 5], dtype());

        let filled = run_fill_null(&input, 25500);

        assert_arrays_eq!(filled, DecimalArray::from_iter([25500; 5], dtype()));
    }

    #[test]
    fn fill_null_non_nullable() {
        let input = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            dtype(),
            Validity::NonNullable,
        );

        let filled = run_fill_null(&input, 25500);

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], dtype())
        );
    }
}