vortex_array/arrays/primitive/compute/
is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::half::f16;
5use vortex_dtype::{NativePType, match_each_native_ptype};
6use vortex_error::VortexResult;
7
8use crate::arrays::{PrimitiveArray, PrimitiveVTable};
9use crate::compute::{IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts};
10use crate::register_kernel;
11
/// Number of bytes compared per chunk by [`compute_is_constant`] (the element count per
/// chunk is `IS_CONST_LANE_WIDTH / size_of::<T>()`).
///
/// 32 bytes (256 bits) when compiled with AVX2 enabled.
#[cfg(target_feature = "avx2")]
pub const IS_CONST_LANE_WIDTH: usize = 32;

/// Number of bytes compared per chunk by [`compute_is_constant`] (the element count per
/// chunk is `IS_CONST_LANE_WIDTH / size_of::<T>()`).
///
/// 16 bytes (128 bits) when AVX2 is not enabled at compile time.
#[cfg(not(target_feature = "avx2"))]
pub const IS_CONST_LANE_WIDTH: usize = 16;
19
20impl IsConstantKernel for PrimitiveVTable {
21    fn is_constant(
22        &self,
23        array: &PrimitiveArray,
24        opts: &IsConstantOpts,
25    ) -> VortexResult<Option<bool>> {
26        if opts.is_negligible_cost() {
27            return Ok(None);
28        }
29
30        let is_constant = match_each_native_ptype!(array.ptype(), integral: |P| {
31            compute_is_constant::<_, {IS_CONST_LANE_WIDTH / size_of::<P>()}>(array.as_slice::<P>())
32        }, floating: |P| {
33            compute_is_constant::<_, {IS_CONST_LANE_WIDTH / size_of::<P>()}>(unsafe { std::mem::transmute::<&[P], &[<P as EqFloat>::IntType]>(array.as_slice::<P>()) })
34        });
35
36        Ok(Some(is_constant))
37    }
38}
39
// Register the `PrimitiveVTable` implementation above (wrapped in `IsConstantKernelAdapter`
// and lifted via `lift()`); presumably this lets the generic `is_constant` compute function
// dispatch to it for primitive arrays.
register_kernel!(IsConstantKernelAdapter(PrimitiveVTable).lift());
41
42/// Assumes any floating point has been cast into its bit representation for which != and !is_eq are the same
43/// Assumes there's at least 1 value in the slice, which is an invariant of the entry level function.
44pub fn compute_is_constant<T: NativePType, const WIDTH: usize>(values: &[T]) -> bool {
45    let first_value = values[0];
46    let first_vec = &[first_value; WIDTH];
47
48    let mut chunks = values[1..].chunks_exact(WIDTH);
49    for chunk in &mut chunks {
50        assert_eq!(chunk.len(), WIDTH); // let the compiler know each chunk is WIDTH.
51        if first_vec != chunk {
52            return false;
53        }
54    }
55
56    for value in chunks.remainder() {
57        if !value.is_eq(first_value) {
58            return false;
59        }
60    }
61
62    true
63}
64
/// Associates a floating-point type with an integer type used to compare its values by
/// bit pattern in the `is_constant` kernel (the impls in this module pair each float with
/// the unsigned integer of the same width).
trait EqFloat {
    /// Integer type the float's raw bits are reinterpreted as (u16/u32/u64 in the impls here).
    type IntType;
}
68
69impl EqFloat for f16 {
70    type IntType = u16;
71}
72impl EqFloat for f32 {
73    type IntType = u32;
74}
75impl EqFloat for f64 {
76    type IntType = u64;
77}