vegafusion_runtime/transform/identifier.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use async_trait::async_trait;
5use datafusion::prelude::DataFrame;
6use datafusion_expr::{
7    expr, expr::WindowFunctionParams, Expr, WindowFrame, WindowFunctionDefinition,
8};
9use datafusion_functions_window::row_number::RowNumber;
10use sqlparser::ast::NullTreatment;
11use std::sync::Arc;
12use vegafusion_common::column::flat_col;
13use vegafusion_common::data::ORDER_COL;
14use vegafusion_common::error::Result;
15use vegafusion_core::proto::gen::transforms::Identifier;
16use vegafusion_core::task_graph::task_value::TaskValue;
17
18#[async_trait]
19impl TransformTrait for Identifier {
20    async fn eval(
21        &self,
22        dataframe: DataFrame,
23        _config: &CompilationConfig,
24    ) -> Result<(DataFrame, Vec<TaskValue>)> {
25        // Add row number column with the desired name, sorted by the input order column
26        let row_number_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
27            fun: WindowFunctionDefinition::WindowUDF(Arc::new(RowNumber::new().into())),
28            params: WindowFunctionParams {
29                args: Vec::new(),
30                partition_by: Vec::new(),
31                order_by: vec![expr::Sort {
32                    expr: flat_col(ORDER_COL),
33                    asc: true,
34                    nulls_first: false,
35                }],
36                window_frame: WindowFrame::new(Some(true)),
37                null_treatment: Some(NullTreatment::IgnoreNulls),
38            },
39        }))
40        .alias(&self.r#as);
41
42        // Select all original columns plus the new identifier column
43        let mut select_exprs: Vec<Expr> = dataframe
44            .schema()
45            .fields()
46            .iter()
47            .map(|f| flat_col(f.name()))
48            .collect();
49
50        select_exprs.push(row_number_expr.into());
51
52        let result = dataframe.select(select_exprs)?;
53
54        Ok((result, Default::default()))
55    }
56}