// vegafusion_runtime/transform/stack.rs
use crate::expression::compiler::config::CompilationConfig;
use crate::transform::TransformTrait;
use async_trait::async_trait;
use datafusion::prelude::DataFrame;
use datafusion_common::JoinType;
use datafusion_expr::{
    expr, expr::AggregateFunctionParams, expr::WindowFunctionParams, lit, when, Expr, WindowFrame,
    WindowFunctionDefinition,
};
use datafusion_functions::expr_fn::abs;
use datafusion_functions_aggregate::expr_fn::max;
use datafusion_functions_aggregate::sum::sum_udaf;
use sqlparser::ast::NullTreatment;
use std::ops::{Add, Div, Sub};
use vegafusion_common::column::{flat_col, relation_col, unescaped_col};
use vegafusion_common::data::ORDER_COL;
use vegafusion_common::datatypes::to_numeric;
use vegafusion_common::error::{Result, VegaFusionError};
use vegafusion_common::escape::unescape_field;
use vegafusion_core::proto::gen::transforms::{SortOrder, Stack, StackOffset};
use vegafusion_core::task_graph::task_value::TaskValue;

/// Implementation of the Vega `stack` transform: computes stacked start/stop
/// positions (written to `alias_0` / `alias_1`) for each row, partitioned by
/// the `groupby` columns and ordered by the `sort` fields.
#[async_trait]
impl TransformTrait for Stack {
    async fn eval(
        &self,
        dataframe: DataFrame,
        _config: &CompilationConfig,
    ) -> Result<(DataFrame, Vec<TaskValue>)> {
        // Output column names for the stack start/stop positions.
        // expect() is intentional: the planner is expected to always populate these aliases.
        let start_field = self.alias_0.clone().expect("alias0 expected");
        let stop_field = self.alias_1.clone().expect("alias1 expected");

        // Unescape the value field and groupby columns (undo Vega field escaping).
        let field = unescape_field(&self.field);
        let group_by: Vec<_> = self.groupby.iter().map(|f| unescape_field(f)).collect();

        // Snapshot of the input schema's column names, used later to build
        // final projections (and to drop any pre-existing start/stop columns).
        let input_fields: Vec<_> = dataframe
            .schema()
            .fields()
            .iter()
            .map(|f| f.name().clone())
            .collect();

        // Build the window ordering from the transform's sort fields.
        // Note: nulls_first is tied to ascending order here.
        let mut order_by: Vec<_> = self
            .sort_fields
            .iter()
            .zip(&self.sort)
            .map(|(field, order)| expr::Sort {
                expr: unescaped_col(field),
                asc: *order == SortOrder::Ascending as i32,
                nulls_first: *order == SortOrder::Ascending as i32,
            })
            .collect();

        // Always break ties with ORDER_COL (the internal row-order column) so
        // the window ordering is total and the result is deterministic.
        order_by.push(expr::Sort {
            expr: flat_col(ORDER_COL),
            asc: true,
            nulls_first: true,
        });

        let offset = StackOffset::try_from(self.offset).expect("Failed to convert stack offset");

        // One partition per groupby combination.
        let partition_by: Vec<_> = group_by.iter().map(|group| flat_col(group)).collect();

        // Cast the stack field to a numeric type and coalesce NULL to 0.0 so
        // null values contribute nothing to the running totals.
        let numeric_field_expr = to_numeric(flat_col(&field), dataframe.schema())?;
        let numeric_field =
            when(numeric_field_expr.clone().is_null(), lit(0.0)).otherwise(numeric_field_expr)?;

        if let StackOffset::Zero = offset {
            // --- Zero offset: cumulative sum from the zero baseline ---
            // Running sum of the (signed) field within each partition gives
            // the stop position of each bar.
            let window_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
                fun: WindowFunctionDefinition::AggregateUDF(sum_udaf()),
                params: WindowFunctionParams {
                    args: vec![numeric_field.clone()],
                    partition_by,
                    order_by,
                    window_frame: WindowFrame::new(Some(true)),
                    null_treatment: Some(NullTreatment::IgnoreNulls),
                },
            }));

            // Keep all input columns except any pre-existing start/stop columns,
            // which this transform overwrites.
            let mut select_exprs = dataframe
                .schema()
                .fields()
                .iter()
                .filter_map(|f| {
                    if f.name() == &start_field || f.name() == &stop_field {
                        None
                    } else {
                        Some(flat_col(f.name()))
                    }
                })
                .collect::<Vec<_>>();

            select_exprs.push(window_expr.alias(&stop_field));

            // Positive and negative values are stacked independently (positive
            // bars stack upward from zero, negative bars downward), so split
            // the frame by sign, apply the cumulative sum to each half, and
            // union the results back together.
            let pos_df = dataframe
                .clone()
                .filter(numeric_field.clone().gt_eq(lit(0)))?
                .select(select_exprs.clone())?;

            let neg_df = dataframe
                .clone()
                .filter(numeric_field.clone().lt(lit(0)))?
                .select(select_exprs)?;

            let unioned_df = pos_df.union(neg_df)?;

            // start = stop - value (each bar begins where it ends minus its height).
            let result_df = unioned_df.with_column(
                &start_field,
                flat_col(&stop_field).sub(numeric_field.clone()),
            )?;

            Ok((result_df, Default::default()))
        } else {
            // --- Center / Normalize offsets ---
            // Both modes ignore sign and stack absolute values.
            let numeric_field = abs(numeric_field);

            // Materialize |value| as an internal "__stack" column alongside
            // all original columns.
            let stack_col_name = "__stack";
            let dataframe = dataframe.select(vec![
                datafusion_expr::expr_fn::wildcard(),
                numeric_field.alias(stack_col_name).into(),
            ])?;

            // Per-partition total of the stacked values ("__total"), needed to
            // center (Center) or scale (Normalize) each stack.
            let total_agg = Expr::AggregateFunction(expr::AggregateFunction {
                func: sum_udaf(),
                params: AggregateFunctionParams {
                    args: vec![flat_col(stack_col_name)],
                    distinct: false,
                    filter: None,
                    order_by: None,
                    null_treatment: Some(NullTreatment::IgnoreNulls),
                },
            })
            .alias("__total");

            // Attach "__total" to every row via a join. `main_alias` names the
            // side of the join that carries the original (row-level) columns.
            let (dataframe, main_alias) = if partition_by.is_empty() {
                // No grouping: compute one global total and cross-join it onto
                // every row using a constant "__join_key".
                let dataframe_with_key = dataframe.with_column("__join_key", lit(1))?;
                let agg_df = dataframe_with_key
                    .clone()
                    .aggregate(vec![], vec![total_agg])?
                    .with_column("__join_key", lit(1))?
                    .alias("agg")?;

                let joined = dataframe_with_key.alias("orig")?.join_on(
                    agg_df,
                    JoinType::Inner,
                    vec![relation_col("__join_key", "orig").eq(relation_col("__join_key", "agg"))],
                )?;
                (joined, "orig")
            } else {
                // Grouped: aggregate per partition (lhs) and join the totals
                // back onto the row-level frame (rhs) on the groupby columns.
                let on_exprs = group_by
                    .iter()
                    .map(|p| relation_col(p, "lhs").eq(relation_col(p, "rhs")))
                    .collect::<Vec<_>>();

                let lhs_df = dataframe
                    .clone()
                    .aggregate(partition_by.clone(), vec![total_agg])?
                    .alias("lhs")?;
                let rhs_df = dataframe.alias("rhs")?;
                let joined = lhs_df.join_on(rhs_df, JoinType::Inner, on_exprs)?;
                (joined, "rhs")
            };

            let cumulative_field = "_cumulative";
            let fun = WindowFunctionDefinition::AggregateUDF(sum_udaf());

            // Re-qualify partition/order columns against `main_alias`, since the
            // join above made plain column references ambiguous.
            let partition_by_qualified: Vec<_> = group_by
                .iter()
                .map(|group| relation_col(group, main_alias))
                .collect();

            let order_by_qualified: Vec<_> = self
                .sort_fields
                .iter()
                .zip(&self.sort)
                .map(|(field, order)| expr::Sort {
                    expr: relation_col(&unescape_field(field), main_alias),
                    asc: *order == SortOrder::Ascending as i32,
                    nulls_first: *order == SortOrder::Ascending as i32,
                })
                // Same ORDER_COL tie-breaker as the Zero branch, qualified.
                .chain(std::iter::once(expr::Sort {
                    expr: relation_col(ORDER_COL, main_alias),
                    asc: true,
                    nulls_first: true,
                }))
                .collect();

            // Running sum of "__stack" within each partition → "_cumulative"
            // (the un-offset stop position of each bar).
            let window_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
                fun,
                params: WindowFunctionParams {
                    args: vec![relation_col(stack_col_name, main_alias)],
                    partition_by: partition_by_qualified,
                    order_by: order_by_qualified,
                    window_frame: WindowFrame::new(Some(true)),
                    null_treatment: Some(NullTreatment::IgnoreNulls),
                },
            }))
            .alias(cumulative_field);

            // Flatten the join back to a single relation: original columns (from
            // `main_alias`), "__stack", "__total" (from the aggregate side), and
            // the cumulative-sum window column.
            let dataframe = if partition_by.is_empty() {
                let mut select_exprs: Vec<Expr> = Vec::new();
                for field in &input_fields {
                    select_exprs.push(relation_col(field, "orig").alias(field));
                }
                select_exprs.push(relation_col("__stack", "orig").alias("__stack"));
                select_exprs.push(relation_col("__total", "agg").alias("__total"));
                select_exprs.push(window_expr.into());
                dataframe.select(select_exprs)?
            } else {
                let mut select_exprs: Vec<Expr> = Vec::new();
                for field in &input_fields {
                    select_exprs.push(relation_col(field, main_alias).alias(field));
                }
                select_exprs.push(relation_col("__stack", main_alias).alias("__stack"));
                select_exprs.push(relation_col("__total", "lhs").alias("__total"));
                select_exprs.push(window_expr.into());
                dataframe.select(select_exprs)?
            };

            // Final projection (used by the Normalize path): original columns
            // minus any pre-existing start/stop columns; start/stop appended below.
            let mut final_selection: Vec<_> = input_fields
                .iter()
                .filter_map(|field| {
                    if field == &start_field || field == &stop_field {
                        None
                    } else {
                        Some(flat_col(field))
                    }
                })
                .collect();

            let dataframe = match offset {
                StackOffset::Center => {
                    // Center mode shifts each stack so all stacks share a common
                    // vertical center: every bar is offset by (max_total - total) / 2,
                    // where max_total is the largest stack total across partitions.
                    let max_total = max(flat_col("__total")).alias("__max_total");

                    let orig_fields: Vec<String> = dataframe
                        .schema()
                        .fields()
                        .iter()
                        .map(|f| f.name().clone())
                        .collect();

                    // Broadcast the global max via the same constant-key join
                    // trick used above for the ungrouped total.
                    let mut select_exprs: Vec<Expr> =
                        orig_fields.iter().map(|name| flat_col(name)).collect();
                    select_exprs.push(lit(1).alias("__join_key"));
                    let dataframe_with_key = dataframe.select(select_exprs)?;

                    let agg_df = dataframe_with_key
                        .clone()
                        .aggregate(vec![flat_col("__join_key")], vec![max_total])?
                        .alias("agg")?;

                    let joined = dataframe_with_key.alias("orig")?.join_on(
                        agg_df,
                        JoinType::Inner,
                        vec![relation_col("__join_key", "orig")
                            .eq(relation_col("__join_key", "agg"))],
                    )?;

                    let mut select_cols: Vec<Expr> = orig_fields
                        .iter()
                        .map(|name| relation_col(name, "orig").alias(name))
                        .collect();
                    select_cols.push(relation_col("__max_total", "agg").alias("__max_total"));

                    let dataframe = joined.select(select_cols)?;

                    // Center's own final projection additionally drops all "__"
                    // internal helper columns introduced in this branch.
                    let mut center_final_selection: Vec<_> = input_fields
                        .iter()
                        .filter_map(|field| {
                            if field == &start_field
                                || field == &stop_field
                                || field.starts_with("__")
                            {
                                None
                            } else {
                                Some(flat_col(field))
                            }
                        })
                        .collect();

                    // stop = cumulative + (max_total - total) / 2
                    // start = stop - value
                    let first = flat_col("__max_total")
                        .sub(flat_col("__total"))
                        .div(lit(2.0));
                    let first_col = flat_col(cumulative_field).add(first);
                    let stop_col = first_col.clone().alias(stop_field);
                    let start_col = first_col.sub(flat_col(stack_col_name)).alias(start_field);
                    center_final_selection.push(start_col);
                    center_final_selection.push(stop_col);

                    dataframe.select(center_final_selection)?
                }
                StackOffset::Normalize => {
                    // Normalize mode scales each stack into [0, 1] by dividing by
                    // the partition total; guard against division by a zero total.
                    let total_zero = flat_col("__total").eq(lit(0.0));

                    // start = (cumulative - value) / total, or 0 when total == 0
                    let start_col = when(total_zero.clone(), lit(0.0))
                        .otherwise(
                            flat_col(cumulative_field)
                                .sub(flat_col(stack_col_name))
                                .div(flat_col("__total")),
                        )?
                        .alias(start_field);

                    final_selection.push(start_col);

                    // stop = cumulative / total, or 0 when total == 0
                    let stop_col = when(total_zero, lit(0.0))
                        .otherwise(flat_col(cumulative_field).div(flat_col("__total")))?
                        .alias(stop_field);

                    final_selection.push(stop_col);

                    dataframe
                }
                // Zero was handled by the early branch above; any other value is a bug.
                _ => return Err(VegaFusionError::internal("Unexpected stack mode")),
            };

            // Center already applied its own final projection; Normalize still
            // needs `final_selection` applied here.
            match offset {
                StackOffset::Center => Ok((dataframe, Default::default())),
                _ => Ok((dataframe.select(final_selection)?, Default::default())),
            }
        }
    }
}