Skip to main content

vortex_array/compute/
is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_error::VortexError;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::IntoArray as _;
16use crate::arrays::ConstantVTable;
17use crate::arrays::NullVTable;
18use crate::compute::ComputeFn;
19use crate::compute::ComputeFnVTable;
20use crate::compute::InvocationArgs;
21use crate::compute::Kernel;
22use crate::compute::Options;
23use crate::compute::Output;
24use crate::dtype::DType;
25use crate::dtype::Nullability;
26use crate::expr::stats::Precision;
27use crate::expr::stats::Stat;
28use crate::expr::stats::StatsProvider;
29use crate::expr::stats::StatsProviderExt;
30use crate::scalar::Scalar;
31use crate::vtable::VTable;
32
33static IS_CONSTANT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
34    let compute = ComputeFn::new("is_constant".into(), ArcRef::new_ref(&IsConstant));
35    for kernel in inventory::iter::<IsConstantKernelRef> {
36        compute.register_kernel(kernel.0.clone());
37    }
38    compute
39});
40
41pub(crate) fn warm_up_vtable() -> usize {
42    IS_CONSTANT_FN.kernels().len()
43}
44
45/// Computes whether an array has constant values. If the array's encoding doesn't implement the
46/// relevant VTable, it'll try and canonicalize in order to make a determination.
47///
48/// An array is constant IFF at least one of the following conditions apply:
49/// 1. It has at least one element (**Note** - an empty array isn't constant).
50/// 1. It's encoded as a [`crate::arrays::ConstantArray`] or [`crate::arrays::NullArray`]
51/// 1. Has an exact statistic attached to it, saying its constant.
52/// 1. Is all invalid.
53/// 1. Is all valid AND has minimum and maximum statistics that are equal.
54///
55/// If the array has some null values but is not all null, it'll never be constant.
56///
57/// Returns `Ok(None)` if we could not determine whether the array is constant, e.g. if
58/// canonicalization is disabled and the no kernel exists for the array's encoding.
59pub fn is_constant(array: &ArrayRef) -> VortexResult<Option<bool>> {
60    let opts = IsConstantOpts::default();
61    is_constant_opts(array, &opts)
62}
63
64/// Computes whether an array has constant values. Configurable by [`IsConstantOpts`].
65///
66/// Please see [`is_constant`] for a more detailed explanation of its behavior.
67pub fn is_constant_opts(array: &ArrayRef, options: &IsConstantOpts) -> VortexResult<Option<bool>> {
68    Ok(IS_CONSTANT_FN
69        .invoke(&InvocationArgs {
70            inputs: &[array.into()],
71            options,
72        })?
73        .unwrap_scalar()?
74        .as_bool()
75        .value())
76}
77
78struct IsConstant;
79
80impl ComputeFnVTable for IsConstant {
81    fn invoke(
82        &self,
83        args: &InvocationArgs,
84        kernels: &[ArcRef<dyn Kernel>],
85    ) -> VortexResult<Output> {
86        let IsConstantArgs { array, options } = IsConstantArgs::try_from(args)?;
87        let array = array.to_array();
88
89        // We try and rely on some easy-to-get stats
90        if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsConstant) {
91            let scalar: Scalar = Some(value).into();
92            return Ok(scalar.into());
93        }
94
95        let value = is_constant_impl(&array, options, kernels)?;
96
97        if options.cost == Cost::Canonicalize {
98            // When we run linear canonicalize, there we must always return an exact answer.
99            assert!(
100                value.is_some(),
101                "is constant in array {array} canonicalize returned None"
102            );
103        }
104
105        // Only if we made a determination do we update the stats.
106        if let Some(value) = value {
107            array
108                .statistics()
109                .set(Stat::IsConstant, Precision::Exact(value.into()));
110        }
111
112        let scalar: Scalar = value.into();
113        Ok(scalar.into())
114    }
115
116    fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult<DType> {
117        // We always return a nullable boolean where `null` indicates we couldn't determine
118        // whether the array is constant.
119        Ok(DType::Bool(Nullability::Nullable))
120    }
121
122    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
123        Ok(1)
124    }
125
126    fn is_elementwise(&self) -> bool {
127        false
128    }
129}
130
131fn is_constant_impl(
132    array: &ArrayRef,
133    options: &IsConstantOpts,
134    kernels: &[ArcRef<dyn Kernel>],
135) -> VortexResult<Option<bool>> {
136    match array.len() {
137        // Our current semantics are that we can always get a value out of a constant array. We might want to change that in the future.
138        0 => return Ok(Some(false)),
139        // Array of length 1 is always constant.
140        1 => return Ok(Some(true)),
141        _ => {}
142    }
143
144    // Constant and null arrays are always constant
145    if array.is::<ConstantVTable>() || array.is::<NullVTable>() {
146        return Ok(Some(true));
147    }
148
149    let all_invalid = array.all_invalid()?;
150    if all_invalid {
151        return Ok(Some(true));
152    }
153
154    let all_valid = array.all_valid()?;
155
156    // If we have some nulls, array can't be constant
157    if !all_valid && !all_invalid {
158        return Ok(Some(false));
159    }
160
161    // We already know here that the array is all valid, so we check for min/max stats.
162    let min = array.statistics().get(Stat::Min);
163    let max = array.statistics().get(Stat::Max);
164
165    if let Some((min, max)) = min.zip(max) {
166        // min/max are equal and exact and there are no NaNs
167        if min.is_exact()
168            && min == max
169            && (Stat::NaNCount.dtype(array.dtype()).is_none()
170                || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
171        {
172            return Ok(Some(true));
173        }
174    }
175
176    assert!(
177        all_valid,
178        "All values must be valid as an invariant of the VTable."
179    );
180    let args = InvocationArgs {
181        inputs: &[array.into()],
182        options,
183    };
184    for kernel in kernels {
185        if let Some(output) = kernel.invoke(&args)? {
186            return Ok(output.unwrap_scalar()?.as_bool().value());
187        }
188    }
189
190    tracing::debug!(
191        "No is_constant implementation found for {}",
192        array.encoding_id()
193    );
194
195    if options.cost == Cost::Canonicalize && !array.is_canonical() {
196        let array = array.to_canonical()?.into_array();
197        let is_constant = is_constant_opts(&array, options)?;
198        return Ok(is_constant);
199    }
200
201    // Otherwise, we cannot determine if the array is constant.
202    Ok(None)
203}
204
205pub struct IsConstantKernelRef(ArcRef<dyn Kernel>);
206inventory::collect!(IsConstantKernelRef);
207
208pub trait IsConstantKernel: VTable {
209    /// # Preconditions
210    ///
211    /// * All values are valid
212    /// * array.len() > 1
213    ///
214    /// Returns `Ok(None)` to signal we couldn't make an exact determination.
215    fn is_constant(&self, array: &Self::Array, opts: &IsConstantOpts)
216    -> VortexResult<Option<bool>>;
217}
218
219#[derive(Debug)]
220pub struct IsConstantKernelAdapter<V: VTable>(pub V);
221
222impl<V: VTable + IsConstantKernel> IsConstantKernelAdapter<V> {
223    pub const fn lift(&'static self) -> IsConstantKernelRef {
224        IsConstantKernelRef(ArcRef::new_ref(self))
225    }
226}
227
228impl<V: VTable + IsConstantKernel> Kernel for IsConstantKernelAdapter<V> {
229    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
230        let args = IsConstantArgs::try_from(args)?;
231        let Some(array) = args.array.as_opt::<V>() else {
232            return Ok(None);
233        };
234        let is_constant = V::is_constant(&self.0, array, args.options)?;
235        let scalar: Scalar = is_constant.into();
236        Ok(Some(scalar.into()))
237    }
238}
239
240struct IsConstantArgs<'a> {
241    array: &'a dyn Array,
242    options: &'a IsConstantOpts,
243}
244
245impl<'a> TryFrom<&InvocationArgs<'a>> for IsConstantArgs<'a> {
246    type Error = VortexError;
247
248    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
249        if value.inputs.len() != 1 {
250            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
251        }
252        let array = value.inputs[0]
253            .array()
254            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
255        let options = value
256            .options
257            .as_any()
258            .downcast_ref::<IsConstantOpts>()
259            .ok_or_else(|| vortex_err!("Expected options to be of type IsConstantOpts"))?;
260        Ok(Self { array, options })
261    }
262}
263
/// When calling `is_constant`, the children are all checked for constantness.
/// This enum decides at which precision/cost level the constant check should run.
/// The cost increases as we move down the list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cost {
    /// Only apply constant-time computation to estimate constantness.
    Negligible,
    /// Allow the encoding to do a linear amount of work to determine whether it is constant.
    /// Each encoding should implement short-circuiting to make the common-case runtime well
    /// below a linear scan.
    Specialized,
    /// Same as [`Cost::Specialized`], but when necessary canonicalize the array and check
    /// whether the canonical form is constant.
    /// This *must* always return a known answer.
    Canonicalize,
}
279
280/// Configuration for [`is_constant_opts`] operations.
281#[derive(Clone, Debug)]
282pub struct IsConstantOpts {
283    /// What precision cost trade off should be used
284    pub cost: Cost,
285}
286
287impl Default for IsConstantOpts {
288    fn default() -> Self {
289        Self {
290            cost: Cost::Canonicalize,
291        }
292    }
293}
294
295impl Options for IsConstantOpts {
296    fn as_any(&self) -> &dyn Any {
297        self
298    }
299}
300
301impl IsConstantOpts {
302    pub fn is_negligible_cost(&self) -> bool {
303        self.cost == Cost::Negligible
304    }
305}
306
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::arrays::PrimitiveArray;
    use crate::compute::is_constant;
    use crate::expr::stats::Stat;

    /// The statistics the min/max shortcut of `is_constant` looks at.
    const MIN_MAX: [Stat; 2] = [Stat::Min, Stat::Max];

    #[test]
    fn is_constant_min_max_no_nan() {
        // Two distinct values: min != max, so not constant.
        let distinct = buffer![0, 1].into_array();
        distinct.statistics().compute_all(&MIN_MAX).unwrap();
        assert_ne!(is_constant(&distinct).unwrap(), Some(true));

        // Same value twice with min/max statistics computed up front.
        let repeated = buffer![0, 0].into_array();
        repeated.statistics().compute_all(&MIN_MAX).unwrap();
        assert_eq!(is_constant(&repeated).unwrap(), Some(true));

        // Same value twice, built from options and without precomputed statistics.
        let from_options = PrimitiveArray::from_option_iter([Some(0), Some(0)]).into_array();
        assert_eq!(is_constant(&from_options).unwrap(), Some(true));
    }

    #[test]
    fn is_constant_min_max_with_nan() {
        // A NaN next to equal finite values must not be reported as constant.
        let with_nan = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]).into_array();
        with_nan.statistics().compute_all(&MIN_MAX).unwrap();
        assert_ne!(is_constant(&with_nan).unwrap(), Some(true));

        // Repeated negative infinity is a constant array.
        let neg_infinities =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)])
                .into_array();
        neg_infinities.statistics().compute_all(&MIN_MAX).unwrap();
        assert_eq!(is_constant(&neg_infinities).unwrap(), Some(true));
    }
}