vegafusion_runtime/transform/collect.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use datafusion_expr::{expr, expr::WindowFunctionParams, Expr, WindowFunctionDefinition};
5use datafusion_functions_window::row_number::RowNumber;
6use sqlparser::ast::NullTreatment;
7
8use std::sync::Arc;
9use vegafusion_core::error::{Result, ResultWithContext};
10use vegafusion_core::proto::gen::transforms::{Collect, SortOrder};
11
12use async_trait::async_trait;
13use datafusion::prelude::DataFrame;
14use datafusion_expr::WindowFrame;
15use vegafusion_common::column::{flat_col, unescaped_col};
16use vegafusion_common::data::ORDER_COL;
17use vegafusion_core::task_graph::task_value::TaskValue;
18
19#[async_trait]
20impl TransformTrait for Collect {
21    async fn eval(
22        &self,
23        dataframe: DataFrame,
24        _config: &CompilationConfig,
25    ) -> Result<(DataFrame, Vec<TaskValue>)> {
26        // Build vector of sort expressions
27        let sort_exprs: Vec<_> = self
28            .fields
29            .clone()
30            .into_iter()
31            .zip(&self.order)
32            .filter_map(|(field, order)| {
33                if dataframe
34                    .schema()
35                    .inner()
36                    .column_with_name(&field)
37                    .is_some()
38                {
39                    let sort_expr = unescaped_col(&field).sort(
40                        *order == SortOrder::Ascending as i32,
41                        *order == SortOrder::Ascending as i32,
42                    );
43                    Some(sort_expr)
44                } else {
45                    None
46                }
47            })
48            .collect();
49
50        // We don't actually sort here, use a row number window function sorted by the sort
51        // criteria. This column becomes the new ORDER_COL, which will be sorted at the end of
52        // the pipeline.
53        let order_col = Expr::WindowFunction(Box::new(expr::WindowFunction {
54            fun: WindowFunctionDefinition::WindowUDF(Arc::new(RowNumber::new().into())),
55            params: WindowFunctionParams {
56                args: vec![],
57                partition_by: vec![],
58                order_by: sort_exprs,
59                window_frame: WindowFrame::new(Some(true)),
60                null_treatment: Some(NullTreatment::IgnoreNulls),
61            },
62        }))
63        .alias(ORDER_COL);
64
65        // Build vector of selections
66        let mut selections = dataframe
67            .schema()
68            .inner()
69            .fields
70            .iter()
71            .filter_map(|field| {
72                if field.name() == ORDER_COL {
73                    None
74                } else {
75                    Some(flat_col(field.name()))
76                }
77            })
78            .collect::<Vec<_>>();
79        selections.insert(0, order_col);
80
81        let result = dataframe
82            .select(selections)
83            .with_context(|| "Collect transform failed".to_string())?;
84        Ok((result, Default::default()))
85    }
86}