vegafusion_runtime/transform/project.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3
4use async_trait::async_trait;
5use datafusion::prelude::DataFrame;
6use std::collections::HashSet;
7use vegafusion_common::column::flat_col;
8use vegafusion_common::data::ORDER_COL;
9use vegafusion_common::escape::unescape_field;
10use vegafusion_core::error::Result;
11use vegafusion_core::proto::gen::transforms::Project;
12use vegafusion_core::task_graph::task_value::TaskValue;
13
14#[async_trait]
15impl TransformTrait for Project {
16    async fn eval(
17        &self,
18        dataframe: DataFrame,
19        _config: &CompilationConfig,
20    ) -> Result<(DataFrame, Vec<TaskValue>)> {
21        // Collect all dataframe fields into a HashSet for fast membership test
22        let all_fields: HashSet<_> = dataframe
23            .schema()
24            .fields()
25            .iter()
26            .map(|field| field.name().clone())
27            .collect();
28
29        // Keep all of the project columns that are present in the dataframe.
30        // Skip projection fields that are not found
31        let mut select_fields: Vec<_> = self
32            .fields
33            .iter()
34            .filter_map(|field| {
35                let field = unescape_field(field);
36                if all_fields.contains(&field) {
37                    Some(field)
38                } else {
39                    None
40                }
41            })
42            .collect();
43
44        // Always keep ordering column
45        select_fields.insert(0, ORDER_COL.to_string());
46
47        let select_col_exprs: Vec<_> = select_fields.iter().map(|f| flat_col(f)).collect();
48        let result = dataframe.select(select_col_exprs)?;
49        Ok((result, Default::default()))
50    }
51}