Skip to main content

vortex_array/scalar_fn/fns/
stat.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Scalar function implementation for aggregate-backed stat expressions.
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::aggregate_fn::AggregateFnRef;
17use crate::aggregate_fn::fns::all_nan::AllNan;
18use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
19use crate::aggregate_fn::fns::all_non_null::AllNonNull;
20use crate::aggregate_fn::fns::all_null::AllNull;
21use crate::arrays::ConstantArray;
22use crate::dtype::DType;
23use crate::expr::Expression;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::expr::stats::StatsProviderExt;
28use crate::scalar::Scalar;
29use crate::scalar::ScalarValue;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35
36/// Options for the `stat` scalar function.
37#[derive(Clone, Debug, PartialEq, Eq, Hash)]
38pub struct StatOptions {
39    aggregate_fn: AggregateFnRef,
40}
41
42impl StatOptions {
43    /// Creates options for the provided aggregate statistic.
44    pub fn new(aggregate_fn: AggregateFnRef) -> Self {
45        Self { aggregate_fn }
46    }
47
48    /// Returns the aggregate function backing this statistic lookup.
49    pub fn aggregate_fn(&self) -> &AggregateFnRef {
50        &self.aggregate_fn
51    }
52}
53
54impl Display for StatOptions {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        Display::fmt(&self.aggregate_fn, f)
57    }
58}
59
/// Scalar function that broadcasts a stored aggregate partial over the input rows.
///
/// The only current consumer is **row-wise pruning**: substituting `stat(col, agg)` into a
/// predicate produces a cheap, row-aligned approximation whose constant runs let downstream
/// filters drop entire stretches at once. For example, `value < 10` can be bounded from both
/// sides:
///
/// - where `stat(value, max) < 10` is true, every row in that scope satisfies the predicate;
/// - where `stat(value, min) >= 10` is true, no row in that scope can satisfy it.
///
/// When neither holds, the stat proves nothing and the rows must be evaluated for real. This
/// is the zone-map / min-max-index pattern, expressed as an ordinary expression so the
/// existing scalar machinery can rewrite, fold, and execute it.
///
/// The result is row-aligned with the input, at whatever granularity the input carries the
/// stat at: e.g. a flat array yields a single broadcast `ConstantArray`; a chunked array
/// yields a constant per chunk; a zone-mapped array would yield a run-end-encoded array,
/// one run per zone. If the requested stat is not available, the result is a null constant.
///
/// Pruning only makes sense for aggregates that can prove something about every row in the
/// scope — `min`, `max`, `all_null`, `all_non_null`, bloom filters, etc. Non-idempotent
/// aggregates like `sum`, `count`, `mean`, `null_count`, and `nan_count` still produce a
/// meaningful per-chunk value but do **not** bound any single row.
#[derive(Clone)]
pub struct StatFn;
81
82impl ScalarFnVTable for StatFn {
83    type Options = StatOptions;
84
85    fn id(&self) -> ScalarFnId {
86        static ID: CachedId = CachedId::new("vortex.stat");
87        *ID
88    }
89
90    fn arity(&self, _options: &Self::Options) -> Arity {
91        Arity::Exact(1)
92    }
93
94    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
95        match child_idx {
96            0 => ChildName::from("input"),
97            _ => unreachable!("Invalid child index {} for Stat expression", child_idx),
98        }
99    }
100
101    fn fmt_sql(
102        &self,
103        options: &Self::Options,
104        expr: &Expression,
105        f: &mut Formatter<'_>,
106    ) -> std::fmt::Result {
107        write!(f, "stat(")?;
108        expr.child(0).fmt_sql(f)?;
109        write!(f, ", {})", options.aggregate_fn())
110    }
111
112    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
113        stat_dtype(options.aggregate_fn(), &arg_dtypes[0])
114    }
115
116    fn execute(
117        &self,
118        options: &Self::Options,
119        args: &dyn ExecutionArgs,
120        _ctx: &mut ExecutionCtx,
121    ) -> VortexResult<ArrayRef> {
122        let input = args.get(0)?;
123        let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?;
124        stat_array(&input, options.aggregate_fn(), dtype, args.row_count())
125    }
126}
127
128fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult<DType> {
129    let Some(dtype) = aggregate_fn.state_dtype(input_dtype) else {
130        vortex_bail!(
131            "Aggregate function {} does not support input dtype {}",
132            aggregate_fn,
133            input_dtype
134        );
135    };
136    Ok(dtype.as_nullable())
137}
138
139fn stat_array(
140    array: &ArrayRef,
141    aggregate_fn: &AggregateFnRef,
142    dtype: DType,
143    len: usize,
144) -> VortexResult<ArrayRef> {
145    let value = if aggregate_fn.is::<AllNull>() {
146        let len = u64::try_from(len)?;
147        match array.statistics().get_as::<u64>(Stat::NullCount) {
148            Precision::Exact(count) => Some(count == len),
149            Precision::Inexact(count) => (count < len).then_some(false),
150            Precision::Absent => None,
151        }
152        .map(ScalarValue::Bool)
153    } else if aggregate_fn.is::<AllNonNull>() {
154        match array.statistics().get_as::<u64>(Stat::NullCount) {
155            Precision::Exact(count) => Some(count == 0),
156            Precision::Inexact(0) => Some(true),
157            Precision::Inexact(_) | Precision::Absent => None,
158        }
159        .map(ScalarValue::Bool)
160    } else if aggregate_fn.is::<AllNan>() {
161        let len = u64::try_from(len)?;
162        match array.statistics().get_as::<u64>(Stat::NaNCount) {
163            Precision::Exact(count) => Some(count == len),
164            Precision::Inexact(count) => (count < len).then_some(false),
165            Precision::Absent => None,
166        }
167        .map(ScalarValue::Bool)
168    } else if aggregate_fn.is::<AllNonNan>() {
169        match array.statistics().get_as::<u64>(Stat::NaNCount) {
170            Precision::Exact(count) => Some(count == 0),
171            Precision::Inexact(0) => Some(true),
172            Precision::Inexact(_) | Precision::Absent => None,
173        }
174        .map(ScalarValue::Bool)
175    } else if let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) {
176        array
177            .statistics()
178            .with_typed_stats_set(|stats| stats.get(stat))
179            // We don't mind whether the stat is approxed or not, since these are row-wise bounds.
180            .into_inner()
181            .and_then(Scalar::into_value)
182    } else {
183        tracing::trace!(
184            "No legacy Stat slot for aggregate {}; stat expression will resolve to null",
185            aggregate_fn
186        );
187        None
188    };
189
190    let scalar = Scalar::try_new(dtype, value)?;
191    Ok(ConstantArray::new(scalar, len).into_array())
192}