Skip to main content

vortex_array/aggregate_fn/fns/is_constant/
primitive.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use crate::arrays::PrimitiveArray;
5use crate::dtype::NativePType;
6use crate::dtype::half::f16;
7use crate::match_each_native_ptype;
8
/// Width, in bytes, of one comparison lane used by the is-constant scan.
///
/// The value is 32 when the crate is compiled with the `avx2` target feature enabled and 16
/// otherwise. Callers divide it by the element size to obtain the number of elements that are
/// compared per chunk.
pub const IS_CONST_LANE_WIDTH: usize = if cfg!(target_feature = "avx2") { 32 } else { 16 };
16
17/// Assumes any floating point has been cast into its bit representation for which != and !is_eq are the same
18/// Assumes there's at least 1 value in the slice, which is an invariant of the entry level function.
19pub fn compute_is_constant<T: NativePType, const WIDTH: usize>(values: &[T]) -> bool {
20    let first_value = values[0];
21    let first_vec = &[first_value; WIDTH];
22
23    let mut chunks = values[1..].chunks_exact(WIDTH);
24    for chunk in &mut chunks {
25        assert_eq!(chunk.len(), WIDTH); // let the compiler know each chunk is WIDTH.
26        if first_vec != chunk {
27            return false;
28        }
29    }
30
31    for value in chunks.remainder() {
32        if !value.is_eq(first_value) {
33            return false;
34        }
35    }
36
37    true
38}
39
40trait EqFloat {
41    type IntType;
42}
43
44impl EqFloat for f16 {
45    type IntType = u16;
46}
47impl EqFloat for f32 {
48    type IntType = u32;
49}
50impl EqFloat for f64 {
51    type IntType = u64;
52}
53
54pub(super) fn check_primitive_constant(array: &PrimitiveArray) -> bool {
55    match_each_native_ptype!(array.ptype(), integral: |P| {
56        compute_is_constant::<_, {IS_CONST_LANE_WIDTH / size_of::<P>()}>(array.as_slice::<P>())
57    }, floating: |P| {
58        compute_is_constant::<_, {IS_CONST_LANE_WIDTH / size_of::<P>()}>(unsafe { std::mem::transmute::<&[P], &[<P as EqFloat>::IntType]>(array.as_slice::<P>()) })
59    })
60}