vegafusion_runtime/transform/
stack.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3use async_trait::async_trait;
4use datafusion::prelude::DataFrame;
5use datafusion_common::JoinType;
6use datafusion_expr::{
7    expr, expr::AggregateFunctionParams, expr::WindowFunctionParams, lit, when, Expr, WindowFrame,
8    WindowFunctionDefinition,
9};
10use datafusion_functions::expr_fn::abs;
11use datafusion_functions_aggregate::expr_fn::max;
12use datafusion_functions_aggregate::sum::sum_udaf;
13use sqlparser::ast::NullTreatment;
14use std::ops::{Add, Div, Sub};
15use vegafusion_common::column::{flat_col, relation_col, unescaped_col};
16use vegafusion_common::data::ORDER_COL;
17use vegafusion_common::datatypes::to_numeric;
18use vegafusion_common::error::{Result, VegaFusionError};
19use vegafusion_common::escape::unescape_field;
20use vegafusion_core::proto::gen::transforms::{SortOrder, Stack, StackOffset};
21use vegafusion_core::task_graph::task_value::TaskValue;
22
23#[async_trait]
24impl TransformTrait for Stack {
25    async fn eval(
26        &self,
27        dataframe: DataFrame,
28        _config: &CompilationConfig,
29    ) -> Result<(DataFrame, Vec<TaskValue>)> {
30        let start_field = self.alias_0.clone().expect("alias0 expected");
31        let stop_field = self.alias_1.clone().expect("alias1 expected");
32
33        let field = unescape_field(&self.field);
34        let group_by: Vec<_> = self.groupby.iter().map(|f| unescape_field(f)).collect();
35
36        // Save off input columns
37        let input_fields: Vec<_> = dataframe
38            .schema()
39            .fields()
40            .iter()
41            .map(|f| f.name().clone())
42            .collect();
43
44        // Build order by vector
45        let mut order_by: Vec<_> = self
46            .sort_fields
47            .iter()
48            .zip(&self.sort)
49            .map(|(field, order)| expr::Sort {
50                expr: unescaped_col(field),
51                asc: *order == SortOrder::Ascending as i32,
52                nulls_first: *order == SortOrder::Ascending as i32,
53            })
54            .collect();
55
56        // Order by input row ordering last
57        order_by.push(expr::Sort {
58            expr: flat_col(ORDER_COL),
59            asc: true,
60            nulls_first: true,
61        });
62
63        let offset = StackOffset::try_from(self.offset).expect("Failed to convert stack offset");
64
65        // Build partitioning column expressions
66        let partition_by: Vec<_> = group_by.iter().map(|group| flat_col(group)).collect();
67
68        // Convert field to numeric first, then handle nulls with CASE expression
69        let numeric_field_expr = to_numeric(flat_col(&field), dataframe.schema())?;
70        let numeric_field =
71            when(numeric_field_expr.clone().is_null(), lit(0.0)).otherwise(numeric_field_expr)?;
72
73        if let StackOffset::Zero = offset {
74            // Build window function to compute stacked value
75            let window_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
76                fun: WindowFunctionDefinition::AggregateUDF(sum_udaf()),
77                params: WindowFunctionParams {
78                    args: vec![numeric_field.clone()],
79                    partition_by,
80                    order_by,
81                    window_frame: WindowFrame::new(Some(true)),
82                    null_treatment: Some(NullTreatment::IgnoreNulls),
83                },
84            }));
85
86            // Initialize selection with all columns, minus those that conflict with start/stop fields
87            let mut select_exprs = dataframe
88                .schema()
89                .fields()
90                .iter()
91                .filter_map(|f| {
92                    if f.name() == &start_field || f.name() == &stop_field {
93                        // Skip fields to be overwritten
94                        None
95                    } else {
96                        Some(flat_col(f.name()))
97                    }
98                })
99                .collect::<Vec<_>>();
100
101            // Add stop window expr
102            select_exprs.push(window_expr.alias(&stop_field));
103
104            // For offset zero, we need to evaluate positive and negative field values separately,
105            // then union the results. This is required to make sure stacks do not overlap. Negative
106            // values stack in the negative direction and positive values stack in the positive
107            // direction.
108            let pos_df = dataframe
109                .clone()
110                .filter(numeric_field.clone().gt_eq(lit(0)))?
111                .select(select_exprs.clone())?;
112
113            let neg_df = dataframe
114                .clone()
115                .filter(numeric_field.clone().lt(lit(0)))?
116                .select(select_exprs)?;
117
118            // Union
119            let unioned_df = pos_df.union(neg_df)?;
120
121            // Add start window expr
122            let result_df = unioned_df.with_column(
123                &start_field,
124                flat_col(&stop_field).sub(numeric_field.clone()),
125            )?;
126
127            Ok((result_df, Default::default()))
128        } else {
129            // Center or Normalized stack modes
130
131            // take absolute value of numeric field
132            let numeric_field = abs(numeric_field);
133
134            // Create __stack column with numeric field
135            let stack_col_name = "__stack";
136            let dataframe = dataframe.select(vec![
137                datafusion_expr::expr_fn::wildcard(),
138                numeric_field.alias(stack_col_name).into(),
139            ])?;
140
141            // Create aggregate for total of stack value
142            let total_agg = Expr::AggregateFunction(expr::AggregateFunction {
143                func: sum_udaf(),
144                params: AggregateFunctionParams {
145                    args: vec![flat_col(stack_col_name)],
146                    distinct: false,
147                    filter: None,
148                    order_by: None,
149                    null_treatment: Some(NullTreatment::IgnoreNulls),
150                },
151            })
152            .alias("__total");
153
154            // Determine the alias for the main dataframe based on whether we have grouping
155            let (dataframe, main_alias) = if partition_by.is_empty() {
156                // Cross join total aggregation
157                // Add dummy join key for cross join since empty join conditions are not allowed
158                let dataframe_with_key = dataframe.with_column("__join_key", lit(1))?;
159                let agg_df = dataframe_with_key
160                    .clone()
161                    .aggregate(vec![], vec![total_agg])?
162                    .with_column("__join_key", lit(1))?
163                    .alias("agg")?;
164
165                // Join on the dummy key
166                let joined = dataframe_with_key.alias("orig")?.join_on(
167                    agg_df,
168                    JoinType::Inner,
169                    vec![relation_col("__join_key", "orig").eq(relation_col("__join_key", "agg"))],
170                )?;
171                (joined, "orig")
172            } else {
173                // Join back total aggregation
174                let on_exprs = group_by
175                    .iter()
176                    .map(|p| relation_col(p, "lhs").eq(relation_col(p, "rhs")))
177                    .collect::<Vec<_>>();
178
179                let lhs_df = dataframe
180                    .clone()
181                    .aggregate(partition_by.clone(), vec![total_agg])?
182                    .alias("lhs")?;
183                let rhs_df = dataframe.alias("rhs")?;
184                let joined = lhs_df.join_on(rhs_df, JoinType::Inner, on_exprs)?;
185                (joined, "rhs")
186            };
187
188            // Build window function to compute cumulative sum of stack column
189            let cumulative_field = "_cumulative";
190            let fun = WindowFunctionDefinition::AggregateUDF(sum_udaf());
191
192            // Update partition_by and order_by to use qualified column references after join
193            let partition_by_qualified: Vec<_> = group_by
194                .iter()
195                .map(|group| relation_col(group, main_alias))
196                .collect();
197
198            let order_by_qualified: Vec<_> = self
199                .sort_fields
200                .iter()
201                .zip(&self.sort)
202                .map(|(field, order)| expr::Sort {
203                    expr: relation_col(&unescape_field(field), main_alias),
204                    asc: *order == SortOrder::Ascending as i32,
205                    nulls_first: *order == SortOrder::Ascending as i32,
206                })
207                .chain(std::iter::once(expr::Sort {
208                    expr: relation_col(ORDER_COL, main_alias),
209                    asc: true,
210                    nulls_first: true,
211                }))
212                .collect();
213
214            let window_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
215                fun,
216                params: WindowFunctionParams {
217                    args: vec![relation_col(stack_col_name, main_alias)],
218                    partition_by: partition_by_qualified,
219                    order_by: order_by_qualified,
220                    window_frame: WindowFrame::new(Some(true)),
221                    null_treatment: Some(NullTreatment::IgnoreNulls),
222                },
223            }))
224            .alias(cumulative_field);
225
226            // Perform selection to add new field value
227            // After a join, we need to select all columns explicitly to handle aliases properly
228            let dataframe = if partition_by.is_empty() {
229                // For cross join case, select all columns from orig table
230                let mut select_exprs: Vec<Expr> = Vec::new();
231                for field in &input_fields {
232                    select_exprs.push(relation_col(field, "orig").alias(field));
233                }
234                // Also select __stack and __total
235                select_exprs.push(relation_col("__stack", "orig").alias("__stack"));
236                select_exprs.push(relation_col("__total", "agg").alias("__total"));
237                // Add the window expression
238                select_exprs.push(window_expr.into());
239                dataframe.select(select_exprs)?
240            } else {
241                // For grouped case, we also need to select columns explicitly to ensure proper aliases
242                let mut select_exprs: Vec<Expr> = Vec::new();
243                for field in &input_fields {
244                    select_exprs.push(relation_col(field, main_alias).alias(field));
245                }
246                // Also select __stack and __total
247                select_exprs.push(relation_col("__stack", main_alias).alias("__stack"));
248                select_exprs.push(relation_col("__total", "lhs").alias("__total"));
249                // Add the window expression
250                select_exprs.push(window_expr.into());
251                dataframe.select(select_exprs)?
252            };
253
254            // Build final_selection
255            let mut final_selection: Vec<_> = input_fields
256                .iter()
257                .filter_map(|field| {
258                    if field == &start_field || field == &stop_field {
259                        None
260                    } else {
261                        Some(flat_col(field))
262                    }
263                })
264                .collect();
265
266            // Now compute stop_field column by adding numeric field to start_field
267            let dataframe = match offset {
268                StackOffset::Center => {
269                    let max_total = max(flat_col("__total")).alias("__max_total");
270
271                    // Save original field names
272                    let orig_fields: Vec<String> = dataframe
273                        .schema()
274                        .fields()
275                        .iter()
276                        .map(|f| f.name().clone())
277                        .collect();
278
279                    // Create a dummy column for joining
280                    let mut select_exprs: Vec<Expr> =
281                        orig_fields.iter().map(|name| flat_col(name)).collect();
282                    select_exprs.push(lit(1).alias("__join_key"));
283                    let dataframe_with_key = dataframe.select(select_exprs)?;
284
285                    // Aggregate with the same dummy key
286                    let agg_df = dataframe_with_key
287                        .clone()
288                        .aggregate(vec![flat_col("__join_key")], vec![max_total])?
289                        .alias("agg")?;
290
291                    // Join on the dummy key
292                    let joined = dataframe_with_key.alias("orig")?.join_on(
293                        agg_df,
294                        JoinType::Inner,
295                        vec![relation_col("__join_key", "orig")
296                            .eq(relation_col("__join_key", "agg"))],
297                    )?;
298
299                    // Select all original columns plus __max_total (but not __join_key)
300                    // Add aliases to ensure unqualified column names in result
301                    let mut select_cols: Vec<Expr> = orig_fields
302                        .iter()
303                        .map(|name| relation_col(name, "orig").alias(name))
304                        .collect();
305                    select_cols.push(relation_col("__max_total", "agg").alias("__max_total"));
306
307                    let dataframe = joined.select(select_cols)?;
308
309                    // Build the final selection for Center case
310                    let mut center_final_selection: Vec<_> = input_fields
311                        .iter()
312                        .filter_map(|field| {
313                            if field == &start_field
314                                || field == &stop_field
315                                || field.starts_with("__")
316                            {
317                                None
318                            } else {
319                                Some(flat_col(field))
320                            }
321                        })
322                        .collect();
323
324                    // Add start and stop columns
325                    let first = flat_col("__max_total")
326                        .sub(flat_col("__total"))
327                        .div(lit(2.0));
328                    let first_col = flat_col(cumulative_field).add(first);
329                    let stop_col = first_col.clone().alias(stop_field);
330                    let start_col = first_col.sub(flat_col(stack_col_name)).alias(start_field);
331                    center_final_selection.push(start_col);
332                    center_final_selection.push(stop_col);
333
334                    dataframe.select(center_final_selection)?
335                }
336                StackOffset::Normalize => {
337                    let total_zero = flat_col("__total").eq(lit(0.0));
338
339                    let start_col = when(total_zero.clone(), lit(0.0))
340                        .otherwise(
341                            flat_col(cumulative_field)
342                                .sub(flat_col(stack_col_name))
343                                .div(flat_col("__total")),
344                        )?
345                        .alias(start_field);
346
347                    final_selection.push(start_col);
348
349                    let stop_col = when(total_zero, lit(0.0))
350                        .otherwise(flat_col(cumulative_field).div(flat_col("__total")))?
351                        .alias(stop_field);
352
353                    final_selection.push(stop_col);
354
355                    dataframe
356                }
357                _ => return Err(VegaFusionError::internal("Unexpected stack mode")),
358            };
359
360            match offset {
361                StackOffset::Center => Ok((dataframe, Default::default())),
362                _ => Ok((dataframe.select(final_selection)?, Default::default())),
363            }
364        }
365    }
366}