Skip to main content

vortex_array/compute/
is_constant.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_dtype::Nullability;
10use vortex_error::VortexError;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14
15use crate::Array;
16use crate::arrays::ConstantVTable;
17use crate::arrays::NullVTable;
18use crate::compute::ComputeFn;
19use crate::compute::ComputeFnVTable;
20use crate::compute::InvocationArgs;
21use crate::compute::Kernel;
22use crate::compute::Options;
23use crate::compute::Output;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::expr::stats::StatsProviderExt;
28use crate::scalar::Scalar;
29use crate::vtable::VTable;
30
31static IS_CONSTANT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
32    let compute = ComputeFn::new("is_constant".into(), ArcRef::new_ref(&IsConstant));
33    for kernel in inventory::iter::<IsConstantKernelRef> {
34        compute.register_kernel(kernel.0.clone());
35    }
36    compute
37});
38
39pub(crate) fn warm_up_vtable() -> usize {
40    IS_CONSTANT_FN.kernels().len()
41}
42
43/// Computes whether an array has constant values. If the array's encoding doesn't implement the
44/// relevant VTable, it'll try and canonicalize in order to make a determination.
45///
46/// An array is constant IFF at least one of the following conditions apply:
47/// 1. It has at least one element (**Note** - an empty array isn't constant).
48/// 1. It's encoded as a [`crate::arrays::ConstantArray`] or [`crate::arrays::NullArray`]
49/// 1. Has an exact statistic attached to it, saying its constant.
50/// 1. Is all invalid.
51/// 1. Is all valid AND has minimum and maximum statistics that are equal.
52///
53/// If the array has some null values but is not all null, it'll never be constant.
54///
55/// Returns `Ok(None)` if we could not determine whether the array is constant, e.g. if
56/// canonicalization is disabled and the no kernel exists for the array's encoding.
57pub fn is_constant(array: &dyn Array) -> VortexResult<Option<bool>> {
58    let opts = IsConstantOpts::default();
59    is_constant_opts(array, &opts)
60}
61
62/// Computes whether an array has constant values. Configurable by [`IsConstantOpts`].
63///
64/// Please see [`is_constant`] for a more detailed explanation of its behavior.
65pub fn is_constant_opts(array: &dyn Array, options: &IsConstantOpts) -> VortexResult<Option<bool>> {
66    Ok(IS_CONSTANT_FN
67        .invoke(&InvocationArgs {
68            inputs: &[array.into()],
69            options,
70        })?
71        .unwrap_scalar()?
72        .as_bool()
73        .value())
74}
75
76struct IsConstant;
77
78impl ComputeFnVTable for IsConstant {
79    fn invoke(
80        &self,
81        args: &InvocationArgs,
82        kernels: &[ArcRef<dyn Kernel>],
83    ) -> VortexResult<Output> {
84        let IsConstantArgs { array, options } = IsConstantArgs::try_from(args)?;
85
86        // We try and rely on some easy-to-get stats
87        if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsConstant) {
88            let scalar: Scalar = Some(value).into();
89            return Ok(scalar.into());
90        }
91
92        let value = is_constant_impl(array, options, kernels)?;
93
94        if options.cost == Cost::Canonicalize {
95            // When we run linear canonicalize, there we must always return an exact answer.
96            assert!(
97                value.is_some(),
98                "is constant in array {array} canonicalize returned None"
99            );
100        }
101
102        // Only if we made a determination do we update the stats.
103        if let Some(value) = value {
104            array
105                .statistics()
106                .set(Stat::IsConstant, Precision::Exact(value.into()));
107        }
108
109        let scalar: Scalar = value.into();
110        Ok(scalar.into())
111    }
112
113    fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult<DType> {
114        // We always return a nullable boolean where `null` indicates we couldn't determine
115        // whether the array is constant.
116        Ok(DType::Bool(Nullability::Nullable))
117    }
118
119    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
120        Ok(1)
121    }
122
123    fn is_elementwise(&self) -> bool {
124        false
125    }
126}
127
128fn is_constant_impl(
129    array: &dyn Array,
130    options: &IsConstantOpts,
131    kernels: &[ArcRef<dyn Kernel>],
132) -> VortexResult<Option<bool>> {
133    match array.len() {
134        // Our current semantics are that we can always get a value out of a constant array. We might want to change that in the future.
135        0 => return Ok(Some(false)),
136        // Array of length 1 is always constant.
137        1 => return Ok(Some(true)),
138        _ => {}
139    }
140
141    // Constant and null arrays are always constant
142    if array.is::<ConstantVTable>() || array.is::<NullVTable>() {
143        return Ok(Some(true));
144    }
145
146    let all_invalid = array.all_invalid()?;
147    if all_invalid {
148        return Ok(Some(true));
149    }
150
151    let all_valid = array.all_valid()?;
152
153    // If we have some nulls, array can't be constant
154    if !all_valid && !all_invalid {
155        return Ok(Some(false));
156    }
157
158    // We already know here that the array is all valid, so we check for min/max stats.
159    let min = array.statistics().get(Stat::Min);
160    let max = array.statistics().get(Stat::Max);
161
162    if let Some((min, max)) = min.zip(max) {
163        // min/max are equal and exact and there are no NaNs
164        if min.is_exact()
165            && min == max
166            && (Stat::NaNCount.dtype(array.dtype()).is_none()
167                || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
168        {
169            return Ok(Some(true));
170        }
171    }
172
173    assert!(
174        all_valid,
175        "All values must be valid as an invariant of the VTable."
176    );
177    let args = InvocationArgs {
178        inputs: &[array.into()],
179        options,
180    };
181    for kernel in kernels {
182        if let Some(output) = kernel.invoke(&args)? {
183            return Ok(output.unwrap_scalar()?.as_bool().value());
184        }
185    }
186
187    tracing::debug!(
188        "No is_constant implementation found for {}",
189        array.encoding_id()
190    );
191
192    if options.cost == Cost::Canonicalize && !array.is_canonical() {
193        let array = array.to_canonical()?;
194        let is_constant = is_constant_opts(array.as_ref(), options)?;
195        return Ok(is_constant);
196    }
197
198    // Otherwise, we cannot determine if the array is constant.
199    Ok(None)
200}
201
/// Handle to an `is_constant` kernel. Values of this type are collected through `inventory`
/// and registered on the `is_constant` compute function when it is first initialized.
pub struct IsConstantKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(IsConstantKernelRef);
204
/// Encoding-specific implementation of the `is_constant` check.
pub trait IsConstantKernel: VTable {
    /// Determines whether `array` is constant, honoring the cost budget in `opts`.
    ///
    /// # Preconditions
    ///
    /// * All values are valid
    /// * array.len() > 1
    ///
    /// Returns `Ok(None)` to signal we couldn't make an exact determination.
    fn is_constant(&self, array: &Self::Array, opts: &IsConstantOpts)
    -> VortexResult<Option<bool>>;
}
215
/// Adapter that exposes an encoding's [`IsConstantKernel`] implementation as a generic
/// [`Kernel`].
#[derive(Debug)]
pub struct IsConstantKernelAdapter<V: VTable>(pub V);

impl<V: VTable + IsConstantKernel> IsConstantKernelAdapter<V> {
    /// Wraps a `'static` reference to this adapter in an [`IsConstantKernelRef`].
    pub const fn lift(&'static self) -> IsConstantKernelRef {
        IsConstantKernelRef(ArcRef::new_ref(self))
    }
}
224
225impl<V: VTable + IsConstantKernel> Kernel for IsConstantKernelAdapter<V> {
226    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
227        let args = IsConstantArgs::try_from(args)?;
228        let Some(array) = args.array.as_opt::<V>() else {
229            return Ok(None);
230        };
231        let is_constant = V::is_constant(&self.0, array, args.options)?;
232        let scalar: Scalar = is_constant.into();
233        Ok(Some(scalar.into()))
234    }
235}
236
/// Typed view over the [`InvocationArgs`] passed to the `is_constant` compute function.
struct IsConstantArgs<'a> {
    /// The single input array to check.
    array: &'a dyn Array,
    /// The options controlling how much work the check may perform.
    options: &'a IsConstantOpts,
}
241
242impl<'a> TryFrom<&InvocationArgs<'a>> for IsConstantArgs<'a> {
243    type Error = VortexError;
244
245    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
246        if value.inputs.len() != 1 {
247            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
248        }
249        let array = value.inputs[0]
250            .array()
251            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
252        let options = value
253            .options
254            .as_any()
255            .downcast_ref::<IsConstantOpts>()
256            .ok_or_else(|| vortex_err!("Expected options to be of type IsConstantOpts"))?;
257        Ok(Self { array, options })
258    }
259}
260
/// When calling `is_constant`, the children are all checked for constantness.
/// This enum decides which precision/cost level the constant check should run at.
/// The cost increases as we move down the list.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Cost {
    /// Only apply constant-time computation to estimate constantness.
    Negligible,
    /// Allow the encoding to do a linear amount of work to determine whether it is constant.
    /// Each encoding should implement short-circuiting to make the common-case runtime well
    /// below a linear scan.
    Specialized,
    /// Same as [`Cost::Specialized`], but when necessary canonicalize the array and check
    /// whether that is constant.
    /// This *must* always return a known answer.
    Canonicalize,
}
276
/// Configuration for [`is_constant_opts`] operations.
#[derive(Clone, Debug)]
pub struct IsConstantOpts {
    /// Which precision/cost trade-off should be used; see [`Cost`].
    pub cost: Cost,
}
283
284impl Default for IsConstantOpts {
285    fn default() -> Self {
286        Self {
287            cost: Cost::Canonicalize,
288        }
289    }
290}
291
impl Options for IsConstantOpts {
    /// Exposes the options as [`Any`] so they can be downcast back to [`IsConstantOpts`].
    fn as_any(&self) -> &dyn Any {
        self
    }
}
297
298impl IsConstantOpts {
299    pub fn is_negligible_cost(&self) -> bool {
300        self.cost == Cost::Negligible
301    }
302}
303
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::arrays::PrimitiveArray;
    use crate::compute::is_constant;
    use crate::expr::stats::Stat;

    /// Statistics pre-computed in these tests before asking `is_constant`.
    const MIN_MAX: [Stat; 2] = [Stat::Min, Stat::Max];

    #[test]
    fn is_constant_min_max_no_nan() {
        // Two distinct values with min/max stats present: must not be reported constant.
        let distinct = buffer![0, 1].into_array();
        distinct.statistics().compute_all(&MIN_MAX).unwrap();
        assert_ne!(is_constant(&distinct).unwrap(), Some(true));

        // Two equal values with min/max stats present: constant.
        let repeated = buffer![0, 0].into_array();
        repeated.statistics().compute_all(&MIN_MAX).unwrap();
        assert_eq!(is_constant(&repeated).unwrap(), Some(true));

        // Equal values without any pre-computed stats: still constant.
        let repeated_no_stats = PrimitiveArray::from_option_iter([Some(0), Some(0)]);
        assert_eq!(is_constant(repeated_no_stats.as_ref()).unwrap(), Some(true));
    }

    #[test]
    fn is_constant_min_max_with_nan() {
        // A NaN next to equal finite values must not be reported constant.
        let with_nan = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]);
        with_nan.statistics().compute_all(&MIN_MAX).unwrap();
        assert_ne!(is_constant(with_nan.as_ref()).unwrap(), Some(true));

        // Repeated negative infinity is a constant float array.
        let neg_infinity =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)]);
        neg_infinity.statistics().compute_all(&MIN_MAX).unwrap();
        assert_eq!(is_constant(neg_infinity.as_ref()).unwrap(), Some(true));
    }
}