vegafusion_runtime/transform/
fold.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use async_trait::async_trait;
5use datafusion::prelude::DataFrame;
6use datafusion_common::ScalarValue;
7use datafusion_expr::{
8    expr, expr::WindowFunctionParams, lit, Expr, WindowFrame, WindowFunctionDefinition,
9};
10use datafusion_functions_window::row_number::RowNumber;
11use sqlparser::ast::NullTreatment;
12use std::sync::Arc;
13use vegafusion_common::column::flat_col;
14use vegafusion_common::data::ORDER_COL;
15use vegafusion_common::error::Result;
16use vegafusion_common::escape::unescape_field;
17use vegafusion_core::proto::gen::transforms::Fold;
18use vegafusion_core::task_graph::task_value::TaskValue;
19
20#[async_trait]
21impl TransformTrait for Fold {
22    async fn eval(
23        &self,
24        dataframe: DataFrame,
25        _config: &CompilationConfig,
26    ) -> Result<(DataFrame, Vec<TaskValue>)> {
27        let field_cols: Vec<_> = self.fields.iter().map(|f| unescape_field(f)).collect();
28        let key_col = unescape_field(
29            &self
30                .r#as
31                .first()
32                .cloned()
33                .unwrap_or_else(|| "key".to_string()),
34        );
35        let value_col = unescape_field(
36            &self
37                .r#as
38                .get(1)
39                .cloned()
40                .unwrap_or_else(|| "value".to_string()),
41        );
42
43        // Build selection that includes all input fields that
44        // aren't shadowed by key/value cols
45        let input_selection = dataframe
46            .schema()
47            .fields()
48            .iter()
49            .filter_map(|f| {
50                if f.name() == &key_col || f.name() == &value_col {
51                    None
52                } else {
53                    Some(flat_col(f.name()))
54                }
55            })
56            .collect::<Vec<_>>();
57
58        // Build union of subqueries that select and rename each field
59        let mut subquery_union: Option<DataFrame> = None;
60
61        let field_order_col = format!("{ORDER_COL}_field");
62        for (i, field) in field_cols.iter().enumerate() {
63            // Clone input selection and add key/val cols to it
64            let mut subquery_selection = input_selection.clone();
65            subquery_selection.push(lit(field).alias(key_col.clone()));
66            if dataframe.schema().inner().column_with_name(field).is_some() {
67                // Field exists as a column in the parent table
68                subquery_selection.push(flat_col(field).alias(value_col.clone()));
69            } else {
70                // Field does not exist in parent table, fill in NULL instead
71                subquery_selection.push(lit(ScalarValue::Null).alias(value_col.clone()));
72            }
73
74            // Add order column
75            subquery_selection.push(lit(i as u32).alias(&field_order_col));
76
77            let subquery_df = dataframe.clone().select(subquery_selection)?;
78            if let Some(union) = subquery_union {
79                subquery_union = Some(union.union(subquery_df)?);
80            } else {
81                subquery_union = Some(subquery_df);
82            }
83        }
84
85        // Unwrap
86        let Some(subquery_union) = subquery_union else {
87            // Return input dataframe as-is
88            return Ok((dataframe, Default::default()));
89        };
90
91        // Compute final selection, start with all the non-order input columns
92        let mut final_selections = dataframe
93            .schema()
94            .fields()
95            .iter()
96            .filter_map(|f| {
97                if f.name() == ORDER_COL {
98                    None
99                } else {
100                    Some(flat_col(f.name()))
101                }
102            })
103            .collect::<Vec<_>>();
104
105        // Add key and value columns
106        final_selections.push(flat_col(&key_col));
107        final_selections.push(flat_col(&value_col));
108
109        // Add new order column
110        let final_order_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
111            fun: WindowFunctionDefinition::WindowUDF(Arc::new(RowNumber::new().into())),
112            params: WindowFunctionParams {
113                args: vec![],
114                partition_by: vec![],
115                order_by: vec![
116                    expr::Sort {
117                        expr: flat_col(ORDER_COL),
118                        asc: true,
119                        nulls_first: true,
120                    },
121                    expr::Sort {
122                        expr: flat_col(&field_order_col),
123                        asc: true,
124                        nulls_first: true,
125                    },
126                ],
127                window_frame: WindowFrame::new(Some(true)),
128                null_treatment: Some(NullTreatment::IgnoreNulls),
129            },
130        }))
131        .alias(ORDER_COL);
132        final_selections.push(final_order_expr);
133
134        Ok((subquery_union.select(final_selections)?, Default::default()))
135    }
136}