// vegafusion_runtime/transform/joinaggregate.rs

use crate::data::util::DataFrameUtils;
use crate::expression::compiler::config::CompilationConfig;
use crate::transform::aggregate::make_aggr_expr_for_named_col;
use crate::transform::TransformTrait;
use async_trait::async_trait;
use datafusion::prelude::DataFrame;
use datafusion_common::JoinType;
use datafusion_expr::lit;
use vegafusion_common::column::{relation_col, unescaped_col};
use vegafusion_common::escape::escape_field;
use vegafusion_core::error::{Result, VegaFusionError};
use vegafusion_core::proto::gen::transforms::{AggregateOp, JoinAggregate};
use vegafusion_core::task_graph::task_value::TaskValue;
use vegafusion_core::transform::aggregate::op_name;
15
16#[async_trait]
17impl TransformTrait for JoinAggregate {
18    async fn eval(
19        &self,
20        dataframe: DataFrame,
21        _config: &CompilationConfig,
22    ) -> Result<(DataFrame, Vec<TaskValue>)> {
23        let group_exprs: Vec<_> = self.groupby.iter().map(|c| unescaped_col(c)).collect();
24        let schema = dataframe.schema();
25
26        let mut agg_exprs = Vec::new();
27        let mut new_col_names = Vec::new();
28        for (i, (field, op)) in self.fields.iter().zip(&self.ops).enumerate() {
29            let op = AggregateOp::try_from(*op).unwrap();
30            let alias = if let Some(alias) = self.aliases.get(i).filter(|a| !a.is_empty()) {
31                // Alias is a non-empty string
32                alias.clone()
33            } else if field.is_empty() {
34                op_name(op).to_string()
35            } else {
36                format!("{}_{}", op_name(op), field)
37            };
38
39            let agg_expr = if matches!(op, AggregateOp::Count) {
40                // In Vega, the provided column is always ignored if op is 'count'.
41                make_aggr_expr_for_named_col(None, &op, schema)?
42            } else {
43                make_aggr_expr_for_named_col(Some(field.clone()), &op, schema)?
44            };
45
46            // Apply alias
47            let agg_expr = agg_expr.alias(&alias);
48
49            // Collect new column aliases
50            new_col_names.push(alias);
51
52            agg_exprs.push(agg_expr);
53        }
54        // Perform regular aggregation on clone of input DataFrame
55        let agged_df = dataframe
56            .clone()
57            .aggregate_mixed(group_exprs, agg_exprs)?
58            .alias("rhs")?;
59
60        // Join with the input dataframe on the grouping columns
61        let mut on = self
62            .groupby
63            .iter()
64            .map(|g| {
65                relation_col(&escape_field(g), "lhs").eq(relation_col(&escape_field(g), "rhs"))
66            })
67            .collect::<Vec<_>>();
68
69        // If there are no groupby columns, use a dummy always-true condition
70        // This is needed because empty join conditions are not allowed
71        if on.is_empty() {
72            on.push(lit(true));
73        }
74
75        let mut final_selections = dataframe
76            .schema()
77            .fields()
78            .iter()
79            .filter_map(|f| {
80                if new_col_names.contains(f.name()) {
81                    None
82                } else {
83                    // Add alias to ensure unqualified column name in result
84                    Some(relation_col(f.name(), "lhs").alias(f.name()))
85                }
86            })
87            .collect::<Vec<_>>();
88        for col in &new_col_names {
89            // Add alias to ensure unqualified column name in result
90            final_selections.push(relation_col(col, "rhs").alias(col));
91        }
92
93        let result = dataframe
94            .clone()
95            .alias("lhs")?
96            .join_on(agged_df, JoinType::Left, on)?
97            .select(final_selections)?;
98
99        Ok((result, Vec::new()))
100    }
101}