vegafusion_runtime/data/
util.rs

1use async_trait::async_trait;
2use datafusion::datasource::{provider_as_source, MemTable};
3use datafusion::prelude::{DataFrame, SessionContext};
4use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
5use datafusion_common::TableReference;
6use datafusion_expr::{col, Expr, LogicalPlanBuilder, UNNAMED_TABLE};
7use datafusion_functions_window::row_number::row_number;
8use std::sync::Arc;
9use vegafusion_common::arrow::array::RecordBatch;
10use vegafusion_common::arrow::compute::concat_batches;
11use vegafusion_common::data::table::VegaFusionTable;
12use vegafusion_common::error::ResultWithContext;
13
14#[async_trait]
15pub trait SessionContextUtils {
16    async fn vegafusion_table(
17        &self,
18        tbl: VegaFusionTable,
19    ) -> vegafusion_common::error::Result<DataFrame>;
20}
21
22#[async_trait]
23impl SessionContextUtils for SessionContext {
24    async fn vegafusion_table(
25        &self,
26        tbl: VegaFusionTable,
27    ) -> vegafusion_common::error::Result<DataFrame> {
28        let mem_table = MemTable::try_new(tbl.schema.clone(), vec![tbl.batches])?;
29
30        // Based on self.read_batch()
31        Ok(DataFrame::new(
32            self.state(),
33            LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(Arc::new(mem_table)), None)?
34                .build()?,
35        ))
36    }
37}
38
39#[async_trait]
40pub trait DataFrameUtils {
41    async fn collect_to_table(self) -> vegafusion_common::error::Result<VegaFusionTable>;
42    async fn collect_flat(self) -> vegafusion_common::error::Result<RecordBatch>;
43    async fn with_index(self, index_name: &str) -> vegafusion_common::error::Result<DataFrame>;
44
45    /// Variant of aggregate that can handle agg expressions that include projections on top
46    /// of aggregations. Also includes groupby expressions in the final result
47    fn aggregate_mixed(
48        self,
49        group_expr: Vec<Expr>,
50        aggr_expr: Vec<Expr>,
51    ) -> vegafusion_common::error::Result<DataFrame>;
52    fn alias(self, name: impl Into<TableReference>) -> vegafusion_common::error::Result<DataFrame>;
53}
54
55#[async_trait]
56impl DataFrameUtils for DataFrame {
57    async fn collect_to_table(self) -> vegafusion_common::error::Result<VegaFusionTable> {
58        let mut arrow_schema = self.schema().inner().clone();
59        let batches = self.collect().await?;
60        if let Some(batch) = batches.first() {
61            // use first batch schema if present
62            arrow_schema = batch.schema()
63        }
64        VegaFusionTable::try_new(arrow_schema, batches)
65    }
66
67    async fn collect_flat(self) -> vegafusion_common::error::Result<RecordBatch> {
68        let mut arrow_schema = self.schema().inner().clone();
69        let batches = self.collect().await?;
70        if let Some(batch) = batches.first() {
71            arrow_schema = batch.schema()
72        }
73        concat_batches(&arrow_schema, batches.as_slice())
74            .with_context(|| String::from("Failed to concatenate RecordBatches"))
75    }
76
77    async fn with_index(self, index_name: &str) -> vegafusion_common::error::Result<DataFrame> {
78        if self.schema().inner().column_with_name(index_name).is_some() {
79            // Column is already present, don't overwrite
80            Ok(self.select(vec![datafusion_expr::expr_fn::wildcard()])?)
81        } else {
82            let selections: Vec<datafusion_expr::select_expr::SelectExpr> = vec![
83                row_number().alias(index_name).into(),
84                datafusion_expr::expr_fn::wildcard(),
85            ];
86            Ok(self.select(selections)?)
87        }
88    }
89
90    fn aggregate_mixed(
91        self,
92        group_expr: Vec<Expr>,
93        aggr_expr: Vec<Expr>,
94    ) -> vegafusion_common::error::Result<DataFrame> {
95        let mut select_exprs: Vec<Expr> = Vec::new();
96
97        // Extract pure agg expressions
98        let mut agg_rewriter = PureAggRewriter::new();
99
100        for agg_expr in aggr_expr {
101            let select_expr = agg_expr.rewrite(&mut agg_rewriter)?;
102            select_exprs.push(select_expr.data)
103        }
104
105        // Apply pure agg functions
106        let df = self.aggregate(group_expr.clone(), agg_rewriter.pure_aggs)?;
107
108        // Add groupby exprs to selection
109        select_exprs.extend(group_expr);
110
111        // Apply projection on top of aggs
112        Ok(df.select(select_exprs)?)
113    }
114
115    fn alias(self, name: impl Into<TableReference>) -> vegafusion_common::error::Result<DataFrame> {
116        let (state, plan) = self.into_parts();
117        Ok(DataFrame::new(
118            state,
119            LogicalPlanBuilder::new(plan).alias(name)?.build()?,
120        ))
121    }
122}
123
124pub struct PureAggRewriter {
125    pub pure_aggs: Vec<Expr>,
126    pub next_id: usize,
127}
128
129impl Default for PureAggRewriter {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl PureAggRewriter {
136    pub fn new() -> Self {
137        Self {
138            pure_aggs: vec![],
139            next_id: 0,
140        }
141    }
142
143    fn new_agg_name(&mut self) -> String {
144        let name = format!("_agg_{}", self.next_id);
145        self.next_id += 1;
146        name
147    }
148}
149
150impl TreeNodeRewriter for PureAggRewriter {
151    type Node = Expr;
152
153    fn f_down(&mut self, node: Expr) -> datafusion_common::Result<Transformed<Self::Node>> {
154        if let Expr::AggregateFunction(agg) = node {
155            // extract agg and replace with column
156            let name = self.new_agg_name();
157            self.pure_aggs
158                .push(Expr::AggregateFunction(agg).alias(&name));
159            Ok(Transformed::new_transformed(col(name), true))
160        } else {
161            // Return expr node unchanged
162            Ok(Transformed::no(node))
163        }
164    }
165}