vortex_array/arrays/decimal/compute/
fill_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use std::ops::Not;

use vortex_dtype::match_each_decimal_value_type;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_scalar::Scalar;

use crate::ArrayRef;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::DecimalVTable;
use crate::arrays::decimal::DecimalArray;
use crate::compute::FillNullKernel;
use crate::compute::FillNullKernelAdapter;
use crate::register_kernel;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
21
22impl FillNullKernel for DecimalVTable {
23    fn fill_null(&self, array: &DecimalArray, fill_value: &Scalar) -> VortexResult<ArrayRef> {
24        let result_validity = Validity::from(fill_value.dtype().nullability());
25
26        Ok(match array.validity() {
27            Validity::Array(is_valid) => {
28                let is_invalid = is_valid.to_bool().bit_buffer().not();
29                match_each_decimal_value_type!(array.values_type(), |T| {
30                    let mut buffer = array.buffer::<T>().into_mut();
31                    let decimal_scalar = fill_value.as_decimal();
32                    let decimal_value = decimal_scalar
33                        .decimal_value()
34                        .vortex_expect("fill_null requires a non-null fill value");
35                    let fill_value = decimal_value
36                        .cast::<T>()
37                        .vortex_expect("fill value does not fit in array's decimal storage type");
38                    for invalid_index in is_invalid.set_indices() {
39                        buffer[invalid_index] = fill_value;
40                    }
41                    DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity)
42                        .into_array()
43                })
44            }
45            _ => unreachable!("checked in entry point"),
46        })
47    }
48}
49
50register_kernel!(FillNullKernelAdapter(DecimalVTable).lift());
51
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DecimalDType;
    use vortex_dtype::Nullability;
    use vortex_scalar::DecimalValue;
    use vortex_scalar::Scalar;

    use crate::arrays::decimal::DecimalArray;
    use crate::assert_arrays_eq;
    use crate::canonical::ToCanonical;
    use crate::compute::fill_null;
    use crate::validity::Validity;

    /// The `DecimalDType::new(19, 2)` type shared by every array and fill scalar below.
    fn dtype() -> DecimalDType {
        DecimalDType::new(19, 2)
    }

    /// Builds a non-nullable decimal scalar of [`dtype`] wrapping `value` as an `i128`.
    fn fill_scalar(value: i128) -> Scalar {
        Scalar::decimal(DecimalValue::I128(value), dtype(), Nullability::NonNullable)
    }

    /// Nulls interleaved with values (starting with a null) are replaced by the fill value.
    #[test]
    fn fill_null_leading_none() {
        let input = DecimalArray::from_option_iter(
            [None, Some(800i128), None, Some(1000i128), None],
            dtype(),
        );

        let filled = fill_null(input.as_ref(), &fill_scalar(4200i128))
            .unwrap()
            .to_decimal();

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], dtype())
        );
        // The physical buffer itself must carry the fill value at the former null slots.
        assert_eq!(
            filled.buffer::<i128>().as_slice(),
            vec![4200, 800, 4200, 1000, 4200]
        );
        assert!(filled.validity_mask().all_true());
    }

    /// An array consisting solely of nulls becomes the fill value at every position.
    #[test]
    fn fill_null_all_none() {
        let input = DecimalArray::from_option_iter(
            [Option::<i128>::None, None, None, None, None],
            dtype(),
        );

        let filled = fill_null(input.as_ref(), &fill_scalar(25500i128))
            .unwrap()
            .to_decimal();

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], dtype())
        );
    }

    /// A non-nullable array has nothing to fill, so its values come back unchanged.
    #[test]
    fn fill_null_non_nullable() {
        let input = DecimalArray::new(
            buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
            dtype(),
            Validity::NonNullable,
        );

        let filled = fill_null(input.as_ref(), &fill_scalar(25500i128))
            .unwrap()
            .to_decimal();

        assert_arrays_eq!(
            filled,
            DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], dtype())
        );
    }
}
143}