vortex_array/arrays/primitive/compute/
nan_count.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::NativePType;
5use vortex_dtype::match_each_float_ptype;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use crate::arrays::PrimitiveArray;
10use crate::arrays::PrimitiveVTable;
11use crate::compute::NaNCountKernel;
12use crate::compute::NaNCountKernelAdapter;
13use crate::register_kernel;
14
15impl NaNCountKernel for PrimitiveVTable {
16    fn nan_count(&self, array: &PrimitiveArray) -> VortexResult<usize> {
17        Ok(match_each_float_ptype!(array.ptype(), |F| {
18            compute_nan_count_with_validity(array.as_slice::<F>(), array.validity_mask())
19        }))
20    }
21}
22
// Register the primitive NaN-count kernel, lifted through `NaNCountKernelAdapter`, so that the
// generic `nan_count` compute function can presumably dispatch to it for `PrimitiveVTable` arrays.
register_kernel!(NaNCountKernelAdapter(PrimitiveVTable).lift());
24
25#[inline]
26fn compute_nan_count_with_validity<T: NativePType>(values: &[T], validity: Mask) -> usize {
27    match validity {
28        Mask::AllTrue(_) => values.iter().filter(|v| v.is_nan()).count(),
29        Mask::AllFalse(_) => 0,
30        Mask::Values(v) => values
31            .iter()
32            .zip(v.bit_buffer().iter())
33            .filter_map(|(v, m)| m.then_some(v))
34            .filter(|v| v.is_nan())
35            .count(),
36    }
37}
38
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arrays::PrimitiveArray;
    use crate::compute::nan_count;
    use crate::validity::Validity;

    /// Non-nullable input: every element is inspected, and both signed NaNs are counted.
    #[test]
    fn primitive_nan_count() {
        let p = PrimitiveArray::new(
            buffer![
                -f32::NAN,
                f32::NAN,
                0.1,
                1.1,
                -0.0,
                f32::INFINITY,
                f32::NEG_INFINITY
            ],
            Validity::NonNullable,
        );
        assert_eq!(nan_count(p.as_ref()).unwrap(), 2);
    }

    /// NaNs stored in null slots must not be counted. This is the mixed-validity path.
    #[test]
    fn primitive_nan_count_ignores_null_slots() {
        let p = PrimitiveArray::new(
            buffer![f64::NAN, f64::NAN, 1.0, f64::NAN],
            Validity::from_iter([true, false, true, false]),
        );
        assert_eq!(nan_count(p.as_ref()).unwrap(), 1);
    }

    /// An all-null array has no countable NaNs, even if every backing value is NaN.
    #[test]
    fn primitive_nan_count_all_invalid() {
        let p = PrimitiveArray::new(buffer![f32::NAN, f32::NAN], Validity::AllInvalid);
        assert_eq!(nan_count(p.as_ref()).unwrap(), 0);
    }
}