Skip to main content

vortex_array/scalar_fn/fns/
stat.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Scalar function implementation for aggregate-backed stat expressions.
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::aggregate_fn::AggregateFnRef;
16use crate::arrays::ConstantArray;
17use crate::dtype::DType;
18use crate::expr::Expression;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22use crate::scalar_fn::Arity;
23use crate::scalar_fn::ChildName;
24use crate::scalar_fn::ExecutionArgs;
25use crate::scalar_fn::ScalarFnId;
26use crate::scalar_fn::ScalarFnVTable;
27
28/// Options for the `stat` scalar function.
29#[derive(Clone, Debug, PartialEq, Eq, Hash)]
30pub struct StatOptions {
31    aggregate_fn: AggregateFnRef,
32}
33
34impl StatOptions {
35    /// Creates options for the provided aggregate statistic.
36    pub fn new(aggregate_fn: AggregateFnRef) -> Self {
37        Self { aggregate_fn }
38    }
39
40    /// Returns the aggregate function backing this statistic lookup.
41    pub fn aggregate_fn(&self) -> &AggregateFnRef {
42        &self.aggregate_fn
43    }
44}
45
46impl Display for StatOptions {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        Display::fmt(&self.aggregate_fn, f)
49    }
50}
51
/// Scalar function that broadcasts a stored aggregate statistic over the input rows.
///
/// The only current consumer is **row-wise pruning**: substituting `stat(col, agg)` into a
/// predicate produces a cheap, row-aligned approximation whose constant runs let downstream
/// filters drop entire stretches at once. For example, `value < 10` can be bounded two ways:
///
/// - `stat(value, min) < 10`: rows where this is **false** (`min >= 10`) are guaranteed to
///   be false for `value < 10`, so they can be pruned.
/// - `stat(value, max) < 10`: rows where this is **true** (`max < 10`) are guaranteed to be
///   true for `value < 10`, so the original predicate need not be evaluated there.
///
/// This is the zone-map / min-max-index pattern, expressed as an ordinary expression so the
/// existing scalar machinery can rewrite, fold, and execute it.
///
/// The result is row-aligned with the input. `execute` reads the statistic recorded on the
/// array it receives and broadcasts it as a single `ConstantArray`. Both exact and inexact
/// (approximate) statistics are used as-is. The granularity of the result therefore follows
/// the granularity at which the function is applied:
///
/// - applied to a flat array, it is one broadcast constant;
/// - applied per chunk, it would be a constant per chunk;
/// - applied per zone, it would be one run per zone.
///
/// NOTE(review): any per-chunk / per-zone dispatch is not visible in this module; confirm
/// where it happens. If the requested stat is not available, the result is a null constant.
///
/// Pruning only makes sense for aggregates that bound individual rows: `min`, `max`,
/// `has_nulls`, bloom filters, etc. Aggregates such as `sum`, `count`, `mean`, `null_count`,
/// and `nan_count` still produce a meaningful value for the whole input they are computed
/// over, but they do **not** bound any single row.
#[derive(Clone)]
pub struct StatFn;
73
74impl ScalarFnVTable for StatFn {
75    type Options = StatOptions;
76
77    fn id(&self) -> ScalarFnId {
78        ScalarFnId::new("vortex.stat")
79    }
80
81    fn arity(&self, _options: &Self::Options) -> Arity {
82        Arity::Exact(1)
83    }
84
85    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
86        match child_idx {
87            0 => ChildName::from("input"),
88            _ => unreachable!("Invalid child index {} for Stat expression", child_idx),
89        }
90    }
91
92    fn fmt_sql(
93        &self,
94        options: &Self::Options,
95        expr: &Expression,
96        f: &mut Formatter<'_>,
97    ) -> std::fmt::Result {
98        write!(f, "stat(")?;
99        expr.child(0).fmt_sql(f)?;
100        write!(f, ", {})", options.aggregate_fn())
101    }
102
103    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
104        stat_dtype(options.aggregate_fn(), &arg_dtypes[0])
105    }
106
107    fn execute(
108        &self,
109        options: &Self::Options,
110        args: &dyn ExecutionArgs,
111        _ctx: &mut ExecutionCtx,
112    ) -> VortexResult<ArrayRef> {
113        let input = args.get(0)?;
114        let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?;
115        stat_array(&input, options.aggregate_fn(), dtype, args.row_count())
116    }
117}
118
119fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult<DType> {
120    let Some(dtype) = aggregate_fn.return_dtype(input_dtype) else {
121        vortex_bail!(
122            "Aggregate function {} does not support input dtype {}",
123            aggregate_fn,
124            input_dtype
125        );
126    };
127    Ok(dtype.as_nullable())
128}
129
130fn stat_array(
131    array: &ArrayRef,
132    aggregate_fn: &AggregateFnRef,
133    dtype: DType,
134    len: usize,
135) -> VortexResult<ArrayRef> {
136    let value = if let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) {
137        array
138            .statistics()
139            .with_typed_stats_set(|stats| stats.get(stat))
140            // We don't mind whether the stat is approxed or not, since these are row-wise bounds.
141            .map(|stat| stat.into_inner())
142            .and_then(Scalar::into_value)
143    } else {
144        tracing::trace!(
145            "No legacy Stat slot for aggregate {}; stat expression will resolve to null",
146            aggregate_fn
147        );
148        None
149    };
150
151    let scalar = Scalar::try_new(dtype, value)?;
152    Ok(ConstantArray::new(scalar, len).into_array())
153}