Skip to main content

vortex_array/scalar_fn/fns/
stat.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Scalar function implementation for aggregate-backed stat expressions.
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::aggregate_fn::AggregateFnRef;
16use crate::aggregate_fn::fns::all_nan::AllNan;
17use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
18use crate::aggregate_fn::fns::all_non_null::AllNonNull;
19use crate::aggregate_fn::fns::all_null::AllNull;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::stats::Precision;
24use crate::expr::stats::Stat;
25use crate::expr::stats::StatsProvider;
26use crate::expr::stats::StatsProviderExt;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34
35/// Options for the `stat` scalar function.
36#[derive(Clone, Debug, PartialEq, Eq, Hash)]
37pub struct StatOptions {
38    aggregate_fn: AggregateFnRef,
39}
40
41impl StatOptions {
42    /// Creates options for the provided aggregate statistic.
43    pub fn new(aggregate_fn: AggregateFnRef) -> Self {
44        Self { aggregate_fn }
45    }
46
47    /// Returns the aggregate function backing this statistic lookup.
48    pub fn aggregate_fn(&self) -> &AggregateFnRef {
49        &self.aggregate_fn
50    }
51}
52
53impl Display for StatOptions {
54    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55        Display::fmt(&self.aggregate_fn, f)
56    }
57}
58
/// Scalar function that broadcasts a stored aggregate partial over the input rows.
///
/// The only current consumer is **row-wise pruning**: substituting `stat(col, agg)` into a
/// predicate produces a cheap, row-aligned approximation whose constant runs let downstream
/// filters drop entire stretches at once. For example, `value < 10` can be approximated by:
///
/// - `stat(value, min) < 10` — rows where this is false are guaranteed false (every value is at
///   least the minimum, so none of them is below 10), which makes those rows safe to prune; or
/// - `stat(value, max) < 10` — rows where this is true are guaranteed true (every value is at
///   most the maximum, so all of them are below 10).
///
/// This is the zone-map / min-max-index pattern, expressed as an ordinary expression so the
/// existing scalar machinery can rewrite, fold, and execute it.
///
/// The result is row-aligned with the input, at whatever granularity the input carries the
/// stat at: e.g. a flat array yields a single broadcast `ConstantArray`; a chunked array
/// yields a constant per chunk; a zone-mapped array would yield a run-end-encoded array,
/// one run per zone. The output dtype is the aggregate's state dtype made nullable, and if the
/// requested stat is not available the result is a null constant.
///
/// Pruning only makes sense for aggregates that can prove something about every row in the scope
/// — `min`, `max`, `all_null`, `all_non_null`, bloom filters, etc. Non-idempotent aggregates like
/// `sum`, `count`, `mean`, `null_count`, and `nan_count` still produce a meaningful per-chunk
/// value but do **not** bound any single row.
#[derive(Clone)]
pub struct StatFn;
80
81impl ScalarFnVTable for StatFn {
82    type Options = StatOptions;
83
84    fn id(&self) -> ScalarFnId {
85        ScalarFnId::new("vortex.stat")
86    }
87
88    fn arity(&self, _options: &Self::Options) -> Arity {
89        Arity::Exact(1)
90    }
91
92    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
93        match child_idx {
94            0 => ChildName::from("input"),
95            _ => unreachable!("Invalid child index {} for Stat expression", child_idx),
96        }
97    }
98
99    fn fmt_sql(
100        &self,
101        options: &Self::Options,
102        expr: &Expression,
103        f: &mut Formatter<'_>,
104    ) -> std::fmt::Result {
105        write!(f, "stat(")?;
106        expr.child(0).fmt_sql(f)?;
107        write!(f, ", {})", options.aggregate_fn())
108    }
109
110    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
111        stat_dtype(options.aggregate_fn(), &arg_dtypes[0])
112    }
113
114    fn execute(
115        &self,
116        options: &Self::Options,
117        args: &dyn ExecutionArgs,
118        _ctx: &mut ExecutionCtx,
119    ) -> VortexResult<ArrayRef> {
120        let input = args.get(0)?;
121        let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?;
122        stat_array(&input, options.aggregate_fn(), dtype, args.row_count())
123    }
124}
125
126fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult<DType> {
127    let Some(dtype) = aggregate_fn.state_dtype(input_dtype) else {
128        vortex_bail!(
129            "Aggregate function {} does not support input dtype {}",
130            aggregate_fn,
131            input_dtype
132        );
133    };
134    Ok(dtype.as_nullable())
135}
136
137fn stat_array(
138    array: &ArrayRef,
139    aggregate_fn: &AggregateFnRef,
140    dtype: DType,
141    len: usize,
142) -> VortexResult<ArrayRef> {
143    let value = if aggregate_fn.is::<AllNull>() {
144        let len = u64::try_from(len)?;
145        match array.statistics().get_as::<u64>(Stat::NullCount) {
146            Precision::Exact(count) => Some(count == len),
147            Precision::Inexact(count) => (count < len).then_some(false),
148            Precision::Absent => None,
149        }
150        .map(ScalarValue::Bool)
151    } else if aggregate_fn.is::<AllNonNull>() {
152        match array.statistics().get_as::<u64>(Stat::NullCount) {
153            Precision::Exact(count) => Some(count == 0),
154            Precision::Inexact(0) => Some(true),
155            Precision::Inexact(_) | Precision::Absent => None,
156        }
157        .map(ScalarValue::Bool)
158    } else if aggregate_fn.is::<AllNan>() {
159        let len = u64::try_from(len)?;
160        match array.statistics().get_as::<u64>(Stat::NaNCount) {
161            Precision::Exact(count) => Some(count == len),
162            Precision::Inexact(count) => (count < len).then_some(false),
163            Precision::Absent => None,
164        }
165        .map(ScalarValue::Bool)
166    } else if aggregate_fn.is::<AllNonNan>() {
167        match array.statistics().get_as::<u64>(Stat::NaNCount) {
168            Precision::Exact(count) => Some(count == 0),
169            Precision::Inexact(0) => Some(true),
170            Precision::Inexact(_) | Precision::Absent => None,
171        }
172        .map(ScalarValue::Bool)
173    } else if let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) {
174        array
175            .statistics()
176            .with_typed_stats_set(|stats| stats.get(stat))
177            // We don't mind whether the stat is approxed or not, since these are row-wise bounds.
178            .into_inner()
179            .and_then(Scalar::into_value)
180    } else {
181        tracing::trace!(
182            "No legacy Stat slot for aggregate {}; stat expression will resolve to null",
183            aggregate_fn
184        );
185        None
186    };
187
188    let scalar = Scalar::try_new(dtype, value)?;
189    Ok(ConstantArray::new(scalar, len).into_array())
190}