vegafusion_runtime/transform/
collect.rs1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use datafusion_expr::{expr, expr::WindowFunctionParams, Expr, WindowFunctionDefinition};
5use datafusion_functions_window::row_number::RowNumber;
6use sqlparser::ast::NullTreatment;
7
8use std::sync::Arc;
9use vegafusion_core::error::{Result, ResultWithContext};
10use vegafusion_core::proto::gen::transforms::{Collect, SortOrder};
11
12use async_trait::async_trait;
13use datafusion::prelude::DataFrame;
14use datafusion_expr::WindowFrame;
15use vegafusion_common::column::{flat_col, unescaped_col};
16use vegafusion_common::data::ORDER_COL;
17use vegafusion_core::task_graph::task_value::TaskValue;
18
19#[async_trait]
20impl TransformTrait for Collect {
21 async fn eval(
22 &self,
23 dataframe: DataFrame,
24 _config: &CompilationConfig,
25 ) -> Result<(DataFrame, Vec<TaskValue>)> {
26 let sort_exprs: Vec<_> = self
28 .fields
29 .clone()
30 .into_iter()
31 .zip(&self.order)
32 .filter_map(|(field, order)| {
33 if dataframe
34 .schema()
35 .inner()
36 .column_with_name(&field)
37 .is_some()
38 {
39 let sort_expr = unescaped_col(&field).sort(
40 *order == SortOrder::Ascending as i32,
41 *order == SortOrder::Ascending as i32,
42 );
43 Some(sort_expr)
44 } else {
45 None
46 }
47 })
48 .collect();
49
50 let order_col = Expr::WindowFunction(Box::new(expr::WindowFunction {
54 fun: WindowFunctionDefinition::WindowUDF(Arc::new(RowNumber::new().into())),
55 params: WindowFunctionParams {
56 args: vec![],
57 partition_by: vec![],
58 order_by: sort_exprs,
59 window_frame: WindowFrame::new(Some(true)),
60 null_treatment: Some(NullTreatment::IgnoreNulls),
61 },
62 }))
63 .alias(ORDER_COL);
64
65 let mut selections = dataframe
67 .schema()
68 .inner()
69 .fields
70 .iter()
71 .filter_map(|field| {
72 if field.name() == ORDER_COL {
73 None
74 } else {
75 Some(flat_col(field.name()))
76 }
77 })
78 .collect::<Vec<_>>();
79 selections.insert(0, order_col);
80
81 let result = dataframe
82 .select(selections)
83 .with_context(|| "Collect transform failed".to_string())?;
84 Ok((result, Default::default()))
85 }
86}