vegafusion_runtime/transform/
aggregate.rs1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use datafusion_expr::{expr::AggregateFunctionParams, lit, Expr};
5use datafusion_functions_aggregate::median::median_udaf;
6use datafusion_functions_aggregate::variance::{var_pop_udaf, var_samp_udaf};
7use sqlparser::ast::NullTreatment;
8use std::collections::HashMap;
9
10use crate::data::util::DataFrameUtils;
11use crate::datafusion::udafs::percentile::{Q1_UDF, Q3_UDF};
12use async_trait::async_trait;
13use datafusion::prelude::DataFrame;
14use datafusion_expr::expr;
15use datafusion_functions_aggregate::expr_fn::{avg, count, count_distinct, max, min, sum};
16use datafusion_functions_aggregate::stddev::{stddev_pop_udaf, stddev_udaf};
17use std::sync::Arc;
18use vegafusion_common::column::{flat_col, unescaped_col};
19use vegafusion_common::data::ORDER_COL;
20use vegafusion_common::datafusion_common::{DFSchema, ScalarValue};
21use vegafusion_common::datatypes::to_numeric;
22use vegafusion_common::error::ResultWithContext;
23use vegafusion_common::escape::unescape_field;
24use vegafusion_core::arrow::datatypes::DataType;
25use vegafusion_core::error::{Result, VegaFusionError};
26use vegafusion_core::proto::gen::transforms::{Aggregate, AggregateOp};
27use vegafusion_core::task_graph::task_value::TaskValue;
28use vegafusion_core::transform::aggregate::op_name;
29
30#[async_trait]
31impl TransformTrait for Aggregate {
32 async fn eval(
33 &self,
34 dataframe: DataFrame,
35 _config: &CompilationConfig,
36 ) -> Result<(DataFrame, Vec<TaskValue>)> {
37 let group_exprs: Vec<_> = self
38 .groupby
39 .iter()
40 .filter(|c| {
41 dataframe
42 .schema()
43 .inner()
44 .column_with_name(&unescape_field(c))
45 .is_some()
46 })
47 .map(|c| unescaped_col(c))
48 .collect();
49
50 let (mut agg_exprs, projections) = get_agg_and_proj_exprs(self, dataframe.schema())?;
51
52 agg_exprs.push(min(flat_col(ORDER_COL)).alias(ORDER_COL));
54
55 let grouped_dataframe = dataframe.aggregate_mixed(group_exprs, agg_exprs)?;
57
58 let grouped_dataframe = grouped_dataframe.select(projections)?;
60
61 Ok((grouped_dataframe, Vec::new()))
62 }
63}
64
65fn get_agg_and_proj_exprs(tx: &Aggregate, schema: &DFSchema) -> Result<(Vec<Expr>, Vec<Expr>)> {
66 let mut agg_aliases: HashMap<(Option<String>, i32), String> = HashMap::new();
70
71 let mut projections: Vec<_> = tx.groupby.iter().map(|f| unescaped_col(f)).collect();
73
74 projections.insert(0, flat_col(ORDER_COL));
76
77 for (i, (field, op_code)) in tx.fields.iter().zip(tx.ops.iter()).enumerate() {
78 let op = AggregateOp::try_from(*op_code).unwrap();
79
80 let column = if *op_code == AggregateOp::Count as i32 {
81 None
83 } else {
84 match field.as_str() {
85 "" => {
86 return Err(VegaFusionError::specification(format!(
87 "Null field is not allowed for {op:?} op"
88 )))
89 }
90 column => Some(column.to_string()),
91 }
92 };
93
94 let alias = if let Some(alias) = tx.aliases.get(i).filter(|a| !a.is_empty()) {
96 alias.clone()
98 } else if field.is_empty() {
99 op_name(op).to_string()
100 } else {
101 format!("{}_{}", op_name(op), field,)
102 };
103
104 let key = (column, *op_code);
105 if let Some(agg_alias) = agg_aliases.get(&key) {
106 projections.push(flat_col(agg_alias).alias(&alias));
108 } else {
109 projections.push(flat_col(&alias));
110 agg_aliases.insert(key, alias);
111 }
112 }
113
114 let mut agg_exprs = Vec::new();
115
116 for ((col_name, op_code), alias) in agg_aliases {
117 let op = AggregateOp::try_from(op_code).unwrap();
118
119 let agg_expr = make_aggr_expr_for_named_col(col_name, &op, schema)?;
120
121 let agg_expr = agg_expr.alias(alias);
123
124 agg_exprs.push(agg_expr)
125 }
126 Ok((agg_exprs, projections))
127}
128
129pub fn make_aggr_expr_for_named_col(
130 col_name: Option<String>,
131 op: &AggregateOp,
132 schema: &DFSchema,
133) -> Result<Expr> {
134 let column = if let Some(col_name) = col_name {
135 let col_name = unescape_field(&col_name);
136 if schema.index_of_column_by_name(None, &col_name).is_none() {
137 return if matches!(op, AggregateOp::Sum | AggregateOp::Count) {
139 Ok(lit(0))
141 } else {
142 Ok(lit(ScalarValue::Float64(None)))
144 };
145 } else {
146 flat_col(&col_name)
147 }
148 } else {
149 lit(0i32)
150 };
151
152 make_agg_expr_for_col_expr(column, op, schema)
153}
154
155pub fn make_agg_expr_for_col_expr(
156 column: Expr,
157 op: &AggregateOp,
158 schema: &DFSchema,
159) -> Result<Expr> {
160 let numeric_column = || -> Result<Expr> {
161 to_numeric(column.clone(), schema)
162 .with_context(|| format!("Failed to convert column {column:?} to numeric data type"))
163 };
164
165 let agg_expr = match op {
166 AggregateOp::Count => count(column),
167 AggregateOp::Mean | AggregateOp::Average => avg(numeric_column()?),
168 AggregateOp::Min => min(column),
169 AggregateOp::Max => max(column),
170 AggregateOp::Sum => sum(numeric_column()?),
171 AggregateOp::Median => Expr::AggregateFunction(expr::AggregateFunction {
172 func: median_udaf(),
173 params: AggregateFunctionParams {
174 distinct: false,
175 args: vec![numeric_column()?],
176 filter: None,
177 order_by: None,
178 null_treatment: Some(NullTreatment::IgnoreNulls),
179 },
180 }),
181 AggregateOp::Variance => Expr::AggregateFunction(expr::AggregateFunction {
182 func: var_samp_udaf(),
183 params: AggregateFunctionParams {
184 distinct: false,
185 args: vec![numeric_column()?],
186 filter: None,
187 order_by: None,
188 null_treatment: Some(NullTreatment::IgnoreNulls),
189 },
190 }),
191 AggregateOp::Variancep => Expr::AggregateFunction(expr::AggregateFunction {
192 func: var_pop_udaf(),
193 params: AggregateFunctionParams {
194 distinct: false,
195 args: vec![numeric_column()?],
196 filter: None,
197 order_by: None,
198 null_treatment: Some(NullTreatment::IgnoreNulls),
199 },
200 }),
201 AggregateOp::Stdev => Expr::AggregateFunction(expr::AggregateFunction {
202 func: stddev_udaf(),
203 params: AggregateFunctionParams {
204 distinct: false,
205 args: vec![numeric_column()?],
206 filter: None,
207 order_by: None,
208 null_treatment: Some(NullTreatment::IgnoreNulls),
209 },
210 }),
211 AggregateOp::Stdevp => Expr::AggregateFunction(expr::AggregateFunction {
212 func: stddev_pop_udaf(),
213 params: AggregateFunctionParams {
214 distinct: false,
215 args: vec![numeric_column()?],
216 filter: None,
217 order_by: None,
218 null_treatment: Some(NullTreatment::IgnoreNulls),
219 },
220 }),
221 AggregateOp::Valid => {
222 let valid = Expr::Cast(expr::Cast {
223 expr: Box::new(Expr::IsNotNull(Box::new(column))),
224 data_type: DataType::Int64,
225 });
226 sum(valid)
227 }
228 AggregateOp::Missing => {
229 let missing = Expr::Cast(expr::Cast {
230 expr: Box::new(Expr::IsNull(Box::new(column))),
231 data_type: DataType::Int64,
232 });
233 sum(missing)
234 }
235 AggregateOp::Distinct => {
236 let missing = Expr::Cast(expr::Cast {
238 expr: Box::new(Expr::IsNull(Box::new(column.clone()))),
239 data_type: DataType::Int64,
240 });
241 count_distinct(column) + max(missing)
242 }
243 AggregateOp::Q1 => Expr::AggregateFunction(expr::AggregateFunction {
244 func: Arc::new((*Q1_UDF).clone()),
245 params: AggregateFunctionParams {
246 distinct: false,
247 args: vec![numeric_column()?],
248 filter: None,
249 order_by: None,
250 null_treatment: Some(NullTreatment::IgnoreNulls),
251 },
252 }),
253 AggregateOp::Q3 => Expr::AggregateFunction(expr::AggregateFunction {
254 func: Arc::new((*Q3_UDF).clone()),
255 params: AggregateFunctionParams {
256 distinct: false,
257 args: vec![numeric_column()?],
258 filter: None,
259 order_by: None,
260 null_treatment: Some(NullTreatment::IgnoreNulls),
261 },
262 }),
263 _ => {
264 return Err(VegaFusionError::specification(format!(
265 "Unsupported aggregation op: {op:?}"
266 )))
267 }
268 };
269 Ok(agg_expr)
270}