vortex_array/arrays/decimal/compute/
fill_null.rs1use std::ops::Not;
5
6use vortex_dtype::match_each_decimal_value_type;
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_scalar::Scalar;
9
10use crate::arrays::DecimalVTable;
11use crate::arrays::decimal::DecimalArray;
12use crate::compute::{FillNullKernel, FillNullKernelAdapter};
13use crate::validity::Validity;
14use crate::vtable::ValidityHelper;
15use crate::{ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl FillNullKernel for DecimalVTable {
18 fn fill_null(&self, array: &DecimalArray, fill_value: &Scalar) -> VortexResult<ArrayRef> {
19 let result_validity = Validity::from(fill_value.dtype().nullability());
20
21 Ok(match array.validity() {
22 Validity::Array(is_valid) => {
23 let is_invalid = is_valid.to_bool().bit_buffer().not();
24 match_each_decimal_value_type!(array.values_type(), |T| {
25 let mut buffer = array.buffer::<T>().into_mut();
26 let fill_value = fill_value
27 .as_decimal()
28 .decimal_value()
29 .and_then(|v| v.cast::<T>())
30 .vortex_expect("top-level fill_null ensure non-null fill value");
31 for invalid_index in is_invalid.set_indices() {
32 buffer[invalid_index] = fill_value;
33 }
34 DecimalArray::new(buffer.freeze(), array.decimal_dtype(), result_validity)
35 .into_array()
36 })
37 }
38 _ => unreachable!("checked in entry point"),
39 })
40 }
41}
42
43register_kernel!(FillNullKernelAdapter(DecimalVTable).lift());
44
45#[cfg(test)]
46mod tests {
47 use vortex_buffer::buffer;
48 use vortex_dtype::{DecimalDType, Nullability};
49 use vortex_scalar::{DecimalValue, Scalar};
50
51 use crate::arrays::decimal::DecimalArray;
52 use crate::assert_arrays_eq;
53 use crate::canonical::ToCanonical;
54 use crate::compute::fill_null;
55 use crate::validity::Validity;
56
57 #[test]
58 fn fill_null_leading_none() {
59 let decimal_dtype = DecimalDType::new(19, 2);
60 let arr = DecimalArray::from_option_iter(
61 [None, Some(800i128), None, Some(1000i128), None],
62 decimal_dtype,
63 );
64 let p = fill_null(
65 arr.as_ref(),
66 &Scalar::decimal(
67 DecimalValue::I128(4200i128),
68 DecimalDType::new(19, 2),
69 Nullability::NonNullable,
70 ),
71 )
72 .unwrap()
73 .to_decimal();
74 assert_arrays_eq!(
75 p,
76 DecimalArray::from_iter([4200, 800, 4200, 1000, 4200], decimal_dtype)
77 );
78 assert_eq!(
79 p.buffer::<i128>().as_slice(),
80 vec![4200, 800, 4200, 1000, 4200]
81 );
82 assert!(p.validity_mask().all_true());
83 }
84
85 #[test]
86 fn fill_null_all_none() {
87 let decimal_dtype = DecimalDType::new(19, 2);
88
89 let arr = DecimalArray::from_option_iter(
90 [Option::<i128>::None, None, None, None, None],
91 decimal_dtype,
92 );
93
94 let p = fill_null(
95 arr.as_ref(),
96 &Scalar::decimal(
97 DecimalValue::I128(25500i128),
98 DecimalDType::new(19, 2),
99 Nullability::NonNullable,
100 ),
101 )
102 .unwrap()
103 .to_decimal();
104 assert_arrays_eq!(
105 p,
106 DecimalArray::from_iter([25500, 25500, 25500, 25500, 25500], decimal_dtype)
107 );
108 }
109
110 #[test]
111 fn fill_null_non_nullable() {
112 let decimal_dtype = DecimalDType::new(19, 2);
113
114 let arr = DecimalArray::new(
115 buffer![800i128, 1000i128, 1200i128, 1400i128, 1600i128],
116 decimal_dtype,
117 Validity::NonNullable,
118 );
119 let p = fill_null(
120 arr.as_ref(),
121 &Scalar::decimal(
122 DecimalValue::I128(25500i128),
123 DecimalDType::new(19, 2),
124 Nullability::NonNullable,
125 ),
126 )
127 .unwrap()
128 .to_decimal();
129 assert_arrays_eq!(
130 p,
131 DecimalArray::from_iter([800i128, 1000, 1200, 1400, 1600], decimal_dtype)
132 );
133 }
134}