vegafusion_runtime/data/
util.rs1use async_trait::async_trait;
2use datafusion::datasource::{provider_as_source, MemTable};
3use datafusion::prelude::{DataFrame, SessionContext};
4use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
5use datafusion_common::TableReference;
6use datafusion_expr::{col, Expr, LogicalPlanBuilder, UNNAMED_TABLE};
7use datafusion_functions_window::row_number::row_number;
8use std::sync::Arc;
9use vegafusion_common::arrow::array::RecordBatch;
10use vegafusion_common::arrow::compute::concat_batches;
11use vegafusion_common::data::table::VegaFusionTable;
12use vegafusion_common::error::ResultWithContext;
13
14#[async_trait]
15pub trait SessionContextUtils {
16 async fn vegafusion_table(
17 &self,
18 tbl: VegaFusionTable,
19 ) -> vegafusion_common::error::Result<DataFrame>;
20}
21
22#[async_trait]
23impl SessionContextUtils for SessionContext {
24 async fn vegafusion_table(
25 &self,
26 tbl: VegaFusionTable,
27 ) -> vegafusion_common::error::Result<DataFrame> {
28 let mem_table = MemTable::try_new(tbl.schema.clone(), vec![tbl.batches])?;
29
30 Ok(DataFrame::new(
32 self.state(),
33 LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(Arc::new(mem_table)), None)?
34 .build()?,
35 ))
36 }
37}
38
39#[async_trait]
40pub trait DataFrameUtils {
41 async fn collect_to_table(self) -> vegafusion_common::error::Result<VegaFusionTable>;
42 async fn collect_flat(self) -> vegafusion_common::error::Result<RecordBatch>;
43 async fn with_index(self, index_name: &str) -> vegafusion_common::error::Result<DataFrame>;
44
45 fn aggregate_mixed(
48 self,
49 group_expr: Vec<Expr>,
50 aggr_expr: Vec<Expr>,
51 ) -> vegafusion_common::error::Result<DataFrame>;
52 fn alias(self, name: impl Into<TableReference>) -> vegafusion_common::error::Result<DataFrame>;
53}
54
55#[async_trait]
56impl DataFrameUtils for DataFrame {
57 async fn collect_to_table(self) -> vegafusion_common::error::Result<VegaFusionTable> {
58 let mut arrow_schema = self.schema().inner().clone();
59 let batches = self.collect().await?;
60 if let Some(batch) = batches.first() {
61 arrow_schema = batch.schema()
63 }
64 VegaFusionTable::try_new(arrow_schema, batches)
65 }
66
67 async fn collect_flat(self) -> vegafusion_common::error::Result<RecordBatch> {
68 let mut arrow_schema = self.schema().inner().clone();
69 let batches = self.collect().await?;
70 if let Some(batch) = batches.first() {
71 arrow_schema = batch.schema()
72 }
73 concat_batches(&arrow_schema, batches.as_slice())
74 .with_context(|| String::from("Failed to concatenate RecordBatches"))
75 }
76
77 async fn with_index(self, index_name: &str) -> vegafusion_common::error::Result<DataFrame> {
78 if self.schema().inner().column_with_name(index_name).is_some() {
79 Ok(self.select(vec![datafusion_expr::expr_fn::wildcard()])?)
81 } else {
82 let selections: Vec<datafusion_expr::select_expr::SelectExpr> = vec![
83 row_number().alias(index_name).into(),
84 datafusion_expr::expr_fn::wildcard(),
85 ];
86 Ok(self.select(selections)?)
87 }
88 }
89
90 fn aggregate_mixed(
91 self,
92 group_expr: Vec<Expr>,
93 aggr_expr: Vec<Expr>,
94 ) -> vegafusion_common::error::Result<DataFrame> {
95 let mut select_exprs: Vec<Expr> = Vec::new();
96
97 let mut agg_rewriter = PureAggRewriter::new();
99
100 for agg_expr in aggr_expr {
101 let select_expr = agg_expr.rewrite(&mut agg_rewriter)?;
102 select_exprs.push(select_expr.data)
103 }
104
105 let df = self.aggregate(group_expr.clone(), agg_rewriter.pure_aggs)?;
107
108 select_exprs.extend(group_expr);
110
111 Ok(df.select(select_exprs)?)
113 }
114
115 fn alias(self, name: impl Into<TableReference>) -> vegafusion_common::error::Result<DataFrame> {
116 let (state, plan) = self.into_parts();
117 Ok(DataFrame::new(
118 state,
119 LogicalPlanBuilder::new(plan).alias(name)?.build()?,
120 ))
121 }
122}
123
124pub struct PureAggRewriter {
125 pub pure_aggs: Vec<Expr>,
126 pub next_id: usize,
127}
128
129impl Default for PureAggRewriter {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135impl PureAggRewriter {
136 pub fn new() -> Self {
137 Self {
138 pure_aggs: vec![],
139 next_id: 0,
140 }
141 }
142
143 fn new_agg_name(&mut self) -> String {
144 let name = format!("_agg_{}", self.next_id);
145 self.next_id += 1;
146 name
147 }
148}
149
150impl TreeNodeRewriter for PureAggRewriter {
151 type Node = Expr;
152
153 fn f_down(&mut self, node: Expr) -> datafusion_common::Result<Transformed<Self::Node>> {
154 if let Expr::AggregateFunction(agg) = node {
155 let name = self.new_agg_name();
157 self.pure_aggs
158 .push(Expr::AggregateFunction(agg).alias(&name));
159 Ok(Transformed::new_transformed(col(name), true))
160 } else {
161 Ok(Transformed::no(node))
163 }
164 }
165}