vegafusion_runtime/transform/
fold.rs1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use async_trait::async_trait;
5use datafusion::prelude::DataFrame;
6use datafusion_common::ScalarValue;
7use datafusion_expr::{
8 expr, expr::WindowFunctionParams, lit, Expr, WindowFrame, WindowFunctionDefinition,
9};
10use datafusion_functions_window::row_number::RowNumber;
11use sqlparser::ast::NullTreatment;
12use std::sync::Arc;
13use vegafusion_common::column::flat_col;
14use vegafusion_common::data::ORDER_COL;
15use vegafusion_common::error::Result;
16use vegafusion_common::escape::unescape_field;
17use vegafusion_core::proto::gen::transforms::Fold;
18use vegafusion_core::task_graph::task_value::TaskValue;
19
20#[async_trait]
21impl TransformTrait for Fold {
22 async fn eval(
23 &self,
24 dataframe: DataFrame,
25 _config: &CompilationConfig,
26 ) -> Result<(DataFrame, Vec<TaskValue>)> {
27 let field_cols: Vec<_> = self.fields.iter().map(|f| unescape_field(f)).collect();
28 let key_col = unescape_field(
29 &self
30 .r#as
31 .first()
32 .cloned()
33 .unwrap_or_else(|| "key".to_string()),
34 );
35 let value_col = unescape_field(
36 &self
37 .r#as
38 .get(1)
39 .cloned()
40 .unwrap_or_else(|| "value".to_string()),
41 );
42
43 let input_selection = dataframe
46 .schema()
47 .fields()
48 .iter()
49 .filter_map(|f| {
50 if f.name() == &key_col || f.name() == &value_col {
51 None
52 } else {
53 Some(flat_col(f.name()))
54 }
55 })
56 .collect::<Vec<_>>();
57
58 let mut subquery_union: Option<DataFrame> = None;
60
61 let field_order_col = format!("{ORDER_COL}_field");
62 for (i, field) in field_cols.iter().enumerate() {
63 let mut subquery_selection = input_selection.clone();
65 subquery_selection.push(lit(field).alias(key_col.clone()));
66 if dataframe.schema().inner().column_with_name(field).is_some() {
67 subquery_selection.push(flat_col(field).alias(value_col.clone()));
69 } else {
70 subquery_selection.push(lit(ScalarValue::Null).alias(value_col.clone()));
72 }
73
74 subquery_selection.push(lit(i as u32).alias(&field_order_col));
76
77 let subquery_df = dataframe.clone().select(subquery_selection)?;
78 if let Some(union) = subquery_union {
79 subquery_union = Some(union.union(subquery_df)?);
80 } else {
81 subquery_union = Some(subquery_df);
82 }
83 }
84
85 let Some(subquery_union) = subquery_union else {
87 return Ok((dataframe, Default::default()));
89 };
90
91 let mut final_selections = dataframe
93 .schema()
94 .fields()
95 .iter()
96 .filter_map(|f| {
97 if f.name() == ORDER_COL {
98 None
99 } else {
100 Some(flat_col(f.name()))
101 }
102 })
103 .collect::<Vec<_>>();
104
105 final_selections.push(flat_col(&key_col));
107 final_selections.push(flat_col(&value_col));
108
109 let final_order_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
111 fun: WindowFunctionDefinition::WindowUDF(Arc::new(RowNumber::new().into())),
112 params: WindowFunctionParams {
113 args: vec![],
114 partition_by: vec![],
115 order_by: vec![
116 expr::Sort {
117 expr: flat_col(ORDER_COL),
118 asc: true,
119 nulls_first: true,
120 },
121 expr::Sort {
122 expr: flat_col(&field_order_col),
123 asc: true,
124 nulls_first: true,
125 },
126 ],
127 window_frame: WindowFrame::new(Some(true)),
128 null_treatment: Some(NullTreatment::IgnoreNulls),
129 },
130 }))
131 .alias(ORDER_COL);
132 final_selections.push(final_order_expr);
133
134 Ok((subquery_union.select(final_selections)?, Default::default()))
135 }
136}