vegafusion_runtime/transform/
aggregate.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use datafusion_expr::{expr::AggregateFunctionParams, lit, Expr};
5use datafusion_functions_aggregate::median::median_udaf;
6use datafusion_functions_aggregate::variance::{var_pop_udaf, var_samp_udaf};
7use sqlparser::ast::NullTreatment;
8use std::collections::HashMap;
9
10use crate::data::util::DataFrameUtils;
11use crate::datafusion::udafs::percentile::{Q1_UDF, Q3_UDF};
12use async_trait::async_trait;
13use datafusion::prelude::DataFrame;
14use datafusion_expr::expr;
15use datafusion_functions_aggregate::expr_fn::{avg, count, count_distinct, max, min, sum};
16use datafusion_functions_aggregate::stddev::{stddev_pop_udaf, stddev_udaf};
17use std::sync::Arc;
18use vegafusion_common::column::{flat_col, unescaped_col};
19use vegafusion_common::data::ORDER_COL;
20use vegafusion_common::datafusion_common::{DFSchema, ScalarValue};
21use vegafusion_common::datatypes::to_numeric;
22use vegafusion_common::error::ResultWithContext;
23use vegafusion_common::escape::unescape_field;
24use vegafusion_core::arrow::datatypes::DataType;
25use vegafusion_core::error::{Result, VegaFusionError};
26use vegafusion_core::proto::gen::transforms::{Aggregate, AggregateOp};
27use vegafusion_core::task_graph::task_value::TaskValue;
28use vegafusion_core::transform::aggregate::op_name;
29
30#[async_trait]
31impl TransformTrait for Aggregate {
32    async fn eval(
33        &self,
34        dataframe: DataFrame,
35        _config: &CompilationConfig,
36    ) -> Result<(DataFrame, Vec<TaskValue>)> {
37        let group_exprs: Vec<_> = self
38            .groupby
39            .iter()
40            .filter(|c| {
41                dataframe
42                    .schema()
43                    .inner()
44                    .column_with_name(&unescape_field(c))
45                    .is_some()
46            })
47            .map(|c| unescaped_col(c))
48            .collect();
49
50        let (mut agg_exprs, projections) = get_agg_and_proj_exprs(self, dataframe.schema())?;
51
52        // Append ordering column to aggregations
53        agg_exprs.push(min(flat_col(ORDER_COL)).alias(ORDER_COL));
54
55        // Perform aggregation
56        let grouped_dataframe = dataframe.aggregate_mixed(group_exprs, agg_exprs)?;
57
58        // Make final projection
59        let grouped_dataframe = grouped_dataframe.select(projections)?;
60
61        Ok((grouped_dataframe, Vec::new()))
62    }
63}
64
65fn get_agg_and_proj_exprs(tx: &Aggregate, schema: &DFSchema) -> Result<(Vec<Expr>, Vec<Expr>)> {
66    // DataFusion does not allow repeated (field, op) combinations in an aggregate expression,
67    // so if there are duplicates we need to use a projection after the aggregation to alias
68    // the desired column
69    let mut agg_aliases: HashMap<(Option<String>, i32), String> = HashMap::new();
70
71    // Initialize vec of final projections with the grouping fields
72    let mut projections: Vec<_> = tx.groupby.iter().map(|f| unescaped_col(f)).collect();
73
74    // Prepend ORDER_COL
75    projections.insert(0, flat_col(ORDER_COL));
76
77    for (i, (field, op_code)) in tx.fields.iter().zip(tx.ops.iter()).enumerate() {
78        let op = AggregateOp::try_from(*op_code).unwrap();
79
80        let column = if *op_code == AggregateOp::Count as i32 {
81            // In Vega, the provided column is always ignored if op is 'count'.
82            None
83        } else {
84            match field.as_str() {
85                "" => {
86                    return Err(VegaFusionError::specification(format!(
87                        "Null field is not allowed for {op:?} op"
88                    )))
89                }
90                column => Some(column.to_string()),
91            }
92        };
93
94        // Apply alias
95        let alias = if let Some(alias) = tx.aliases.get(i).filter(|a| !a.is_empty()) {
96            // Alias is a non-empty string
97            alias.clone()
98        } else if field.is_empty() {
99            op_name(op).to_string()
100        } else {
101            format!("{}_{}", op_name(op), field,)
102        };
103
104        let key = (column, *op_code);
105        if let Some(agg_alias) = agg_aliases.get(&key) {
106            // We're already going to preform the aggregation, so alias result
107            projections.push(flat_col(agg_alias).alias(&alias));
108        } else {
109            projections.push(flat_col(&alias));
110            agg_aliases.insert(key, alias);
111        }
112    }
113
114    let mut agg_exprs = Vec::new();
115
116    for ((col_name, op_code), alias) in agg_aliases {
117        let op = AggregateOp::try_from(op_code).unwrap();
118
119        let agg_expr = make_aggr_expr_for_named_col(col_name, &op, schema)?;
120
121        // Apply alias
122        let agg_expr = agg_expr.alias(alias);
123
124        agg_exprs.push(agg_expr)
125    }
126    Ok((agg_exprs, projections))
127}
128
129pub fn make_aggr_expr_for_named_col(
130    col_name: Option<String>,
131    op: &AggregateOp,
132    schema: &DFSchema,
133) -> Result<Expr> {
134    let column = if let Some(col_name) = col_name {
135        let col_name = unescape_field(&col_name);
136        if schema.index_of_column_by_name(None, &col_name).is_none() {
137            // No column with specified name, short circuit to return default value
138            return if matches!(op, AggregateOp::Sum | AggregateOp::Count) {
139                // return zero for sum and count
140                Ok(lit(0))
141            } else {
142                // return NULL for all other operators
143                Ok(lit(ScalarValue::Float64(None)))
144            };
145        } else {
146            flat_col(&col_name)
147        }
148    } else {
149        lit(0i32)
150    };
151
152    make_agg_expr_for_col_expr(column, op, schema)
153}
154
155pub fn make_agg_expr_for_col_expr(
156    column: Expr,
157    op: &AggregateOp,
158    schema: &DFSchema,
159) -> Result<Expr> {
160    let numeric_column = || -> Result<Expr> {
161        to_numeric(column.clone(), schema)
162            .with_context(|| format!("Failed to convert column {column:?} to numeric data type"))
163    };
164
165    let agg_expr = match op {
166        AggregateOp::Count => count(column),
167        AggregateOp::Mean | AggregateOp::Average => avg(numeric_column()?),
168        AggregateOp::Min => min(column),
169        AggregateOp::Max => max(column),
170        AggregateOp::Sum => sum(numeric_column()?),
171        AggregateOp::Median => Expr::AggregateFunction(expr::AggregateFunction {
172            func: median_udaf(),
173            params: AggregateFunctionParams {
174                distinct: false,
175                args: vec![numeric_column()?],
176                filter: None,
177                order_by: None,
178                null_treatment: Some(NullTreatment::IgnoreNulls),
179            },
180        }),
181        AggregateOp::Variance => Expr::AggregateFunction(expr::AggregateFunction {
182            func: var_samp_udaf(),
183            params: AggregateFunctionParams {
184                distinct: false,
185                args: vec![numeric_column()?],
186                filter: None,
187                order_by: None,
188                null_treatment: Some(NullTreatment::IgnoreNulls),
189            },
190        }),
191        AggregateOp::Variancep => Expr::AggregateFunction(expr::AggregateFunction {
192            func: var_pop_udaf(),
193            params: AggregateFunctionParams {
194                distinct: false,
195                args: vec![numeric_column()?],
196                filter: None,
197                order_by: None,
198                null_treatment: Some(NullTreatment::IgnoreNulls),
199            },
200        }),
201        AggregateOp::Stdev => Expr::AggregateFunction(expr::AggregateFunction {
202            func: stddev_udaf(),
203            params: AggregateFunctionParams {
204                distinct: false,
205                args: vec![numeric_column()?],
206                filter: None,
207                order_by: None,
208                null_treatment: Some(NullTreatment::IgnoreNulls),
209            },
210        }),
211        AggregateOp::Stdevp => Expr::AggregateFunction(expr::AggregateFunction {
212            func: stddev_pop_udaf(),
213            params: AggregateFunctionParams {
214                distinct: false,
215                args: vec![numeric_column()?],
216                filter: None,
217                order_by: None,
218                null_treatment: Some(NullTreatment::IgnoreNulls),
219            },
220        }),
221        AggregateOp::Valid => {
222            let valid = Expr::Cast(expr::Cast {
223                expr: Box::new(Expr::IsNotNull(Box::new(column))),
224                data_type: DataType::Int64,
225            });
226            sum(valid)
227        }
228        AggregateOp::Missing => {
229            let missing = Expr::Cast(expr::Cast {
230                expr: Box::new(Expr::IsNull(Box::new(column))),
231                data_type: DataType::Int64,
232            });
233            sum(missing)
234        }
235        AggregateOp::Distinct => {
236            // Vega counts null as a distinct category but SQL does not
237            let missing = Expr::Cast(expr::Cast {
238                expr: Box::new(Expr::IsNull(Box::new(column.clone()))),
239                data_type: DataType::Int64,
240            });
241            count_distinct(column) + max(missing)
242        }
243        AggregateOp::Q1 => Expr::AggregateFunction(expr::AggregateFunction {
244            func: Arc::new((*Q1_UDF).clone()),
245            params: AggregateFunctionParams {
246                distinct: false,
247                args: vec![numeric_column()?],
248                filter: None,
249                order_by: None,
250                null_treatment: Some(NullTreatment::IgnoreNulls),
251            },
252        }),
253        AggregateOp::Q3 => Expr::AggregateFunction(expr::AggregateFunction {
254            func: Arc::new((*Q3_UDF).clone()),
255            params: AggregateFunctionParams {
256                distinct: false,
257                args: vec![numeric_column()?],
258                filter: None,
259                order_by: None,
260                null_treatment: Some(NullTreatment::IgnoreNulls),
261            },
262        }),
263        _ => {
264            return Err(VegaFusionError::specification(format!(
265                "Unsupported aggregation op: {op:?}"
266            )))
267        }
268    };
269    Ok(agg_expr)
270}