Skip to main content

vortex_array/arrays/bool/compute/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use vortex_dtype::Nullability;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_mask::AllOr;
10
11use crate::arrays::BoolArray;
12use crate::arrays::BoolVTable;
13use crate::compute::SumKernel;
14use crate::compute::SumKernelAdapter;
15use crate::register_kernel;
16use crate::scalar::Scalar;
17
18impl SumKernel for BoolVTable {
19    fn sum(&self, array: &BoolArray, accumulator: &Scalar) -> VortexResult<Scalar> {
20        let true_count: Option<u64> = match array.validity_mask()?.bit_buffer() {
21            AllOr::All => {
22                // All-valid
23                Some(array.to_bit_buffer().true_count() as u64)
24            }
25            AllOr::None => {
26                // All-invalid
27                unreachable!("All-invalid boolean array should have been handled by entry-point")
28            }
29            AllOr::Some(validity_mask) => {
30                Some(array.to_bit_buffer().bitand(validity_mask).true_count() as u64)
31            }
32        };
33
34        let acc_value = accumulator
35            .as_primitive()
36            .as_::<u64>()
37            .vortex_expect("cannot be null");
38        let result = true_count.and_then(|tc| acc_value.checked_add(tc));
39        Ok(match result {
40            Some(v) => Scalar::primitive(v, Nullability::Nullable),
41            None => Scalar::null_native::<u64>(),
42        })
43    }
44}
45
// Registers the `SumKernel` implementation for `BoolVTable`, wrapped in `SumKernelAdapter`
// and lifted via `.lift()`.
// NOTE(review): presumably this is what makes the kernel discoverable by the `sum` compute
// entry point — confirm against the `register_kernel!` macro definition.
register_kernel!(SumKernelAdapter(BoolVTable).lift());