vegafusion_runtime/transform/
window.rs1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3use async_trait::async_trait;
4
5use datafusion::prelude::DataFrame;
6use datafusion_common::ScalarValue;
7use datafusion_expr::{
8 expr, expr::WindowFunctionParams, lit, Expr, WindowFrame, WindowFunctionDefinition,
9};
10use datafusion_functions_aggregate::variance::{var_pop_udaf, var_samp_udaf};
11use std::sync::Arc;
12use vegafusion_core::error::Result;
13use vegafusion_core::proto::gen::transforms::{
14 window_transform_op, AggregateOp, SortOrder, Window, WindowOp,
15};
16use vegafusion_core::task_graph::task_value::TaskValue;
17
18use datafusion_expr::{WindowFrameBound, WindowFrameUnits};
19use datafusion_functions_aggregate::average::avg_udaf;
20use datafusion_functions_aggregate::count::count_udaf;
21use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
22use datafusion_functions_aggregate::stddev::{stddev_pop_udaf, stddev_udaf};
23use datafusion_functions_aggregate::sum::sum_udaf;
24
25use datafusion_functions_window::{
26 cume_dist::CumeDist,
27 nth_value::{first_value_udwf, last_value_udwf},
28 rank::Rank,
29 row_number::RowNumber,
30};
31
32use vegafusion_common::column::{flat_col, unescaped_col};
33use vegafusion_common::data::ORDER_COL;
34use vegafusion_common::datatypes::to_numeric;
35use vegafusion_common::error::{ResultWithContext, VegaFusionError};
36use vegafusion_common::escape::unescape_field;
37
38#[async_trait]
39impl TransformTrait for Window {
40 async fn eval(
41 &self,
42 dataframe: DataFrame,
43 _config: &CompilationConfig,
44 ) -> Result<(DataFrame, Vec<TaskValue>)> {
45 let mut order_by: Vec<_> = self
46 .sort_fields
47 .iter()
48 .zip(&self.sort)
49 .map(|(field, order)| expr::Sort {
50 expr: unescaped_col(field),
51 asc: *order == SortOrder::Ascending as i32,
52 nulls_first: *order == SortOrder::Ascending as i32,
53 })
54 .collect();
55
56 let mut selections: Vec<_> = dataframe
57 .schema()
58 .fields()
59 .iter()
60 .map(|f| flat_col(f.name()))
61 .collect();
62
63 if order_by.is_empty() {
64 order_by.push(expr::Sort {
66 expr: flat_col(ORDER_COL),
67 asc: true,
68 nulls_first: true,
69 });
70 };
71
72 let partition_by: Vec<_> = self
73 .groupby
74 .iter()
75 .filter(|c| {
76 dataframe
77 .schema()
78 .inner()
79 .column_with_name(&unescape_field(c))
80 .is_some()
81 })
82 .map(|group| unescaped_col(group))
83 .collect();
84
85 let (start_bound, end_bound) = match &self.frame {
86 None => (
87 WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
89 WindowFrameBound::CurrentRow,
91 ),
92 Some(frame) => (
93 WindowFrameBound::Preceding(ScalarValue::UInt64(
94 frame.start.map(|v| v.unsigned_abs()),
95 )),
96 WindowFrameBound::Following(ScalarValue::UInt64(frame.end.map(|v| v as u64))),
97 ),
98 };
99
100 let ignore_peers = self.ignore_peers.unwrap_or(false);
101
102 let units = if ignore_peers {
103 WindowFrameUnits::Rows
104 } else {
105 WindowFrameUnits::Groups
106 };
107 let window_frame = WindowFrame::new_bounds(units, start_bound, end_bound);
108
109 let schema_df = dataframe.schema();
110 let window_exprs = self
111 .ops
112 .iter()
113 .zip(&self.fields)
114 .enumerate()
115 .map(|(i, (op, field))| -> Result<Expr> {
116 let (window_fn, args) = match op.op.as_ref().unwrap() {
117 window_transform_op::Op::AggregateOp(op) => {
118 let op = AggregateOp::try_from(*op).unwrap();
119
120 let numeric_field = || -> Result<Expr> {
121 to_numeric(unescaped_col(field), schema_df).with_context(|| {
122 format!("Failed to convert field {field} to numeric data type")
123 })
124 };
125
126 use AggregateOp::*;
127 match op {
128 Count => (
129 WindowFunctionDefinition::AggregateUDF(count_udaf()),
130 vec![lit(true)],
131 ),
132 Sum => (
133 WindowFunctionDefinition::AggregateUDF(sum_udaf()),
134 vec![numeric_field()?],
135 ),
136 Mean | Average => (
137 WindowFunctionDefinition::AggregateUDF(avg_udaf()),
138 vec![numeric_field()?],
139 ),
140 Min => (
141 WindowFunctionDefinition::AggregateUDF(min_udaf()),
142 vec![numeric_field()?],
143 ),
144 Max => (
145 WindowFunctionDefinition::AggregateUDF(max_udaf()),
146 vec![numeric_field()?],
147 ),
148 Variance => (
149 WindowFunctionDefinition::AggregateUDF(var_samp_udaf()),
150 vec![numeric_field()?],
151 ),
152 Variancep => (
153 WindowFunctionDefinition::AggregateUDF(var_pop_udaf()),
154 vec![numeric_field()?],
155 ),
156 Stdev => (
157 WindowFunctionDefinition::AggregateUDF(stddev_udaf()),
158 vec![numeric_field()?],
159 ),
160 Stdevp => (
161 WindowFunctionDefinition::AggregateUDF(stddev_pop_udaf()),
162 vec![numeric_field()?],
163 ),
164 _ => {
167 return Err(VegaFusionError::compilation(format!(
168 "Unsupported window aggregate: {op:?}"
169 )))
170 }
171 }
172 }
173 window_transform_op::Op::WindowOp(op) => {
174 let op = WindowOp::try_from(*op).unwrap();
175 let _param = self.params.get(i);
176
177 let (window_fn, args) = match op {
178 WindowOp::RowNumber => (
179 WindowFunctionDefinition::WindowUDF(Arc::new(
180 RowNumber::new().into(),
181 )),
182 Vec::new(),
183 ),
184 WindowOp::Rank => (
185 WindowFunctionDefinition::WindowUDF(Arc::new(Rank::basic().into())),
186 Vec::new(),
187 ),
188 WindowOp::DenseRank => (
189 WindowFunctionDefinition::WindowUDF(Arc::new(
190 Rank::dense_rank().into(),
191 )),
192 Vec::new(),
193 ),
194 WindowOp::PercentileRank => (
195 WindowFunctionDefinition::WindowUDF(Arc::new(
196 Rank::percent_rank().into(),
197 )),
198 Vec::new(),
199 ),
200 WindowOp::CumeDist => (
201 WindowFunctionDefinition::WindowUDF(Arc::new(
202 CumeDist::new().into(),
203 )),
204 Vec::new(),
205 ),
206 WindowOp::FirstValue => (
207 WindowFunctionDefinition::WindowUDF(first_value_udwf()),
208 vec![unescaped_col(field)],
209 ),
210 WindowOp::LastValue => (
211 WindowFunctionDefinition::WindowUDF(last_value_udwf()),
212 vec![unescaped_col(field)],
213 ),
214 _ => {
215 return Err(VegaFusionError::compilation(format!(
216 "Unsupported window function: {op:?}"
217 )))
218 }
219 };
220 (window_fn, args)
221 }
222 };
223
224 let window_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
225 fun: window_fn,
226 params: WindowFunctionParams {
227 args,
228 partition_by: partition_by.clone(),
229 order_by: order_by.clone(),
230 window_frame: window_frame.clone(),
231 null_treatment: None,
232 },
233 }));
234
235 if let Some(alias) = self.aliases.get(i) {
236 Ok(window_expr.alias(alias))
237 } else {
238 Ok(window_expr)
239 }
240 })
241 .collect::<Result<Vec<_>>>()?;
242
243 selections.extend(window_exprs);
245
246 let dataframe = dataframe.select(selections)?;
247
248 Ok((dataframe, Vec::new()))
249 }
250}