vortex_array/arrays/primitive/compute/
is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::NativePType;
5use vortex_dtype::half::f16;
6use vortex_dtype::match_each_native_ptype;
7use vortex_error::VortexResult;
8
9use crate::arrays::PrimitiveArray;
10use crate::arrays::PrimitiveVTable;
11use crate::compute::IsConstantKernel;
12use crate::compute::IsConstantKernelAdapter;
13use crate::compute::IsConstantOpts;
14use crate::register_kernel;
15
/// Size in bytes of one comparison lane used by the primitive `is_constant` kernel.
///
/// The kernel divides this by the element size to obtain the number of values compared per
/// lane. When the crate is compiled with the `avx2` target feature the lane is 32 bytes
/// (presumably to match 256-bit vector registers).
#[cfg(target_feature = "avx2")]
pub const IS_CONST_LANE_WIDTH: usize = 32;

/// Size in bytes of one comparison lane used by the primitive `is_constant` kernel.
///
/// The kernel divides this by the element size to obtain the number of values compared per
/// lane. Without the `avx2` target feature the lane is 16 bytes (presumably to match 128-bit
/// vector registers).
#[cfg(not(target_feature = "avx2"))]
pub const IS_CONST_LANE_WIDTH: usize = 16;
23
24impl IsConstantKernel for PrimitiveVTable {
25    fn is_constant(
26        &self,
27        array: &PrimitiveArray,
28        opts: &IsConstantOpts,
29    ) -> VortexResult<Option<bool>> {
30        if opts.is_negligible_cost() {
31            return Ok(None);
32        }
33
34        let is_constant = match_each_native_ptype!(array.ptype(), integral: |P| {
35            compute_is_constant::<_, {IS_CONST_LANE_WIDTH / size_of::<P>()}>(array.as_slice::<P>())
36        }, floating: |P| {
37            compute_is_constant::<_, {IS_CONST_LANE_WIDTH / size_of::<P>()}>(unsafe { std::mem::transmute::<&[P], &[<P as EqFloat>::IntType]>(array.as_slice::<P>()) })
38        });
39
40        Ok(Some(is_constant))
41    }
42}
43
44register_kernel!(IsConstantKernelAdapter(PrimitiveVTable).lift());
45
46/// Assumes any floating point has been cast into its bit representation for which != and !is_eq are the same
47/// Assumes there's at least 1 value in the slice, which is an invariant of the entry level function.
48pub fn compute_is_constant<T: NativePType, const WIDTH: usize>(values: &[T]) -> bool {
49    let first_value = values[0];
50    let first_vec = &[first_value; WIDTH];
51
52    let mut chunks = values[1..].chunks_exact(WIDTH);
53    for chunk in &mut chunks {
54        assert_eq!(chunk.len(), WIDTH); // let the compiler know each chunk is WIDTH.
55        if first_vec != chunk {
56            return false;
57        }
58    }
59
60    for value in chunks.remainder() {
61        if !value.is_eq(first_value) {
62            return false;
63        }
64    }
65
66    true
67}
68
69trait EqFloat {
70    type IntType;
71}
72
73impl EqFloat for f16 {
74    type IntType = u16;
75}
76impl EqFloat for f32 {
77    type IntType = u32;
78}
79impl EqFloat for f64 {
80    type IntType = u64;
81}