vegafusion_runtime/transform/
joinaggregate.rs1use crate::data::util::DataFrameUtils;
2use crate::expression::compiler::config::CompilationConfig;
3use crate::transform::aggregate::make_aggr_expr_for_named_col;
4use crate::transform::TransformTrait;
5use async_trait::async_trait;
6use datafusion::prelude::DataFrame;
7use datafusion_common::JoinType;
8use datafusion_expr::lit;
9use vegafusion_common::column::{relation_col, unescaped_col};
10use vegafusion_common::escape::escape_field;
11use vegafusion_core::error::Result;
12use vegafusion_core::proto::gen::transforms::{AggregateOp, JoinAggregate};
13use vegafusion_core::task_graph::task_value::TaskValue;
14use vegafusion_core::transform::aggregate::op_name;
15
16#[async_trait]
17impl TransformTrait for JoinAggregate {
18 async fn eval(
19 &self,
20 dataframe: DataFrame,
21 _config: &CompilationConfig,
22 ) -> Result<(DataFrame, Vec<TaskValue>)> {
23 let group_exprs: Vec<_> = self.groupby.iter().map(|c| unescaped_col(c)).collect();
24 let schema = dataframe.schema();
25
26 let mut agg_exprs = Vec::new();
27 let mut new_col_names = Vec::new();
28 for (i, (field, op)) in self.fields.iter().zip(&self.ops).enumerate() {
29 let op = AggregateOp::try_from(*op).unwrap();
30 let alias = if let Some(alias) = self.aliases.get(i).filter(|a| !a.is_empty()) {
31 alias.clone()
33 } else if field.is_empty() {
34 op_name(op).to_string()
35 } else {
36 format!("{}_{}", op_name(op), field)
37 };
38
39 let agg_expr = if matches!(op, AggregateOp::Count) {
40 make_aggr_expr_for_named_col(None, &op, schema)?
42 } else {
43 make_aggr_expr_for_named_col(Some(field.clone()), &op, schema)?
44 };
45
46 let agg_expr = agg_expr.alias(&alias);
48
49 new_col_names.push(alias);
51
52 agg_exprs.push(agg_expr);
53 }
54 let agged_df = dataframe
56 .clone()
57 .aggregate_mixed(group_exprs, agg_exprs)?
58 .alias("rhs")?;
59
60 let mut on = self
62 .groupby
63 .iter()
64 .map(|g| {
65 relation_col(&escape_field(g), "lhs").eq(relation_col(&escape_field(g), "rhs"))
66 })
67 .collect::<Vec<_>>();
68
69 if on.is_empty() {
72 on.push(lit(true));
73 }
74
75 let mut final_selections = dataframe
76 .schema()
77 .fields()
78 .iter()
79 .filter_map(|f| {
80 if new_col_names.contains(f.name()) {
81 None
82 } else {
83 Some(relation_col(f.name(), "lhs").alias(f.name()))
85 }
86 })
87 .collect::<Vec<_>>();
88 for col in &new_col_names {
89 final_selections.push(relation_col(col, "rhs").alias(col));
91 }
92
93 let result = dataframe
94 .clone()
95 .alias("lhs")?
96 .join_on(agged_df, JoinType::Left, on)?
97 .select(final_selections)?;
98
99 Ok((result, Vec::new()))
100 }
101}