vegafusion_runtime/transform/window.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3use async_trait::async_trait;
4
5use datafusion::prelude::DataFrame;
6use datafusion_common::ScalarValue;
7use datafusion_expr::{
8    expr, expr::WindowFunctionParams, lit, Expr, WindowFrame, WindowFunctionDefinition,
9};
10use datafusion_functions_aggregate::variance::{var_pop_udaf, var_samp_udaf};
11use std::sync::Arc;
12use vegafusion_core::error::Result;
13use vegafusion_core::proto::gen::transforms::{
14    window_transform_op, AggregateOp, SortOrder, Window, WindowOp,
15};
16use vegafusion_core::task_graph::task_value::TaskValue;
17
18use datafusion_expr::{WindowFrameBound, WindowFrameUnits};
19use datafusion_functions_aggregate::average::avg_udaf;
20use datafusion_functions_aggregate::count::count_udaf;
21use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
22use datafusion_functions_aggregate::stddev::{stddev_pop_udaf, stddev_udaf};
23use datafusion_functions_aggregate::sum::sum_udaf;
24
25use datafusion_functions_window::{
26    cume_dist::CumeDist,
27    nth_value::{first_value_udwf, last_value_udwf},
28    rank::Rank,
29    row_number::RowNumber,
30};
31
32use vegafusion_common::column::{flat_col, unescaped_col};
33use vegafusion_common::data::ORDER_COL;
34use vegafusion_common::datatypes::to_numeric;
35use vegafusion_common::error::{ResultWithContext, VegaFusionError};
36use vegafusion_common::escape::unescape_field;
37
38#[async_trait]
39impl TransformTrait for Window {
40    async fn eval(
41        &self,
42        dataframe: DataFrame,
43        _config: &CompilationConfig,
44    ) -> Result<(DataFrame, Vec<TaskValue>)> {
45        let mut order_by: Vec<_> = self
46            .sort_fields
47            .iter()
48            .zip(&self.sort)
49            .map(|(field, order)| expr::Sort {
50                expr: unescaped_col(field),
51                asc: *order == SortOrder::Ascending as i32,
52                nulls_first: *order == SortOrder::Ascending as i32,
53            })
54            .collect();
55
56        let mut selections: Vec<_> = dataframe
57            .schema()
58            .fields()
59            .iter()
60            .map(|f| flat_col(f.name()))
61            .collect();
62
63        if order_by.is_empty() {
64            // Order by input row if no ordering specified
65            order_by.push(expr::Sort {
66                expr: flat_col(ORDER_COL),
67                asc: true,
68                nulls_first: true,
69            });
70        };
71
72        let partition_by: Vec<_> = self
73            .groupby
74            .iter()
75            .filter(|c| {
76                dataframe
77                    .schema()
78                    .inner()
79                    .column_with_name(&unescape_field(c))
80                    .is_some()
81            })
82            .map(|group| unescaped_col(group))
83            .collect();
84
85        let (start_bound, end_bound) = match &self.frame {
86            None => (
87                // Unbounded preceding
88                WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
89                // Current row
90                WindowFrameBound::CurrentRow,
91            ),
92            Some(frame) => (
93                WindowFrameBound::Preceding(ScalarValue::UInt64(
94                    frame.start.map(|v| v.unsigned_abs()),
95                )),
96                WindowFrameBound::Following(ScalarValue::UInt64(frame.end.map(|v| v as u64))),
97            ),
98        };
99
100        let ignore_peers = self.ignore_peers.unwrap_or(false);
101
102        let units = if ignore_peers {
103            WindowFrameUnits::Rows
104        } else {
105            WindowFrameUnits::Groups
106        };
107        let window_frame = WindowFrame::new_bounds(units, start_bound, end_bound);
108
109        let schema_df = dataframe.schema();
110        let window_exprs = self
111            .ops
112            .iter()
113            .zip(&self.fields)
114            .enumerate()
115            .map(|(i, (op, field))| -> Result<Expr> {
116                let (window_fn, args) = match op.op.as_ref().unwrap() {
117                    window_transform_op::Op::AggregateOp(op) => {
118                        let op = AggregateOp::try_from(*op).unwrap();
119
120                        let numeric_field = || -> Result<Expr> {
121                            to_numeric(unescaped_col(field), schema_df).with_context(|| {
122                                format!("Failed to convert field {field} to numeric data type")
123                            })
124                        };
125
126                        use AggregateOp::*;
127                        match op {
128                            Count => (
129                                WindowFunctionDefinition::AggregateUDF(count_udaf()),
130                                vec![lit(true)],
131                            ),
132                            Sum => (
133                                WindowFunctionDefinition::AggregateUDF(sum_udaf()),
134                                vec![numeric_field()?],
135                            ),
136                            Mean | Average => (
137                                WindowFunctionDefinition::AggregateUDF(avg_udaf()),
138                                vec![numeric_field()?],
139                            ),
140                            Min => (
141                                WindowFunctionDefinition::AggregateUDF(min_udaf()),
142                                vec![numeric_field()?],
143                            ),
144                            Max => (
145                                WindowFunctionDefinition::AggregateUDF(max_udaf()),
146                                vec![numeric_field()?],
147                            ),
148                            Variance => (
149                                WindowFunctionDefinition::AggregateUDF(var_samp_udaf()),
150                                vec![numeric_field()?],
151                            ),
152                            Variancep => (
153                                WindowFunctionDefinition::AggregateUDF(var_pop_udaf()),
154                                vec![numeric_field()?],
155                            ),
156                            Stdev => (
157                                WindowFunctionDefinition::AggregateUDF(stddev_udaf()),
158                                vec![numeric_field()?],
159                            ),
160                            Stdevp => (
161                                WindowFunctionDefinition::AggregateUDF(stddev_pop_udaf()),
162                                vec![numeric_field()?],
163                            ),
164                            // ArrayAgg only available on master right now
165                            // Values => (aggregates::AggregateFunction::ArrayAgg, unescaped_col(field)),
166                            _ => {
167                                return Err(VegaFusionError::compilation(format!(
168                                    "Unsupported window aggregate: {op:?}"
169                                )))
170                            }
171                        }
172                    }
173                    window_transform_op::Op::WindowOp(op) => {
174                        let op = WindowOp::try_from(*op).unwrap();
175                        let _param = self.params.get(i);
176
177                        let (window_fn, args) = match op {
178                            WindowOp::RowNumber => (
179                                WindowFunctionDefinition::WindowUDF(Arc::new(
180                                    RowNumber::new().into(),
181                                )),
182                                Vec::new(),
183                            ),
184                            WindowOp::Rank => (
185                                WindowFunctionDefinition::WindowUDF(Arc::new(Rank::basic().into())),
186                                Vec::new(),
187                            ),
188                            WindowOp::DenseRank => (
189                                WindowFunctionDefinition::WindowUDF(Arc::new(
190                                    Rank::dense_rank().into(),
191                                )),
192                                Vec::new(),
193                            ),
194                            WindowOp::PercentileRank => (
195                                WindowFunctionDefinition::WindowUDF(Arc::new(
196                                    Rank::percent_rank().into(),
197                                )),
198                                Vec::new(),
199                            ),
200                            WindowOp::CumeDist => (
201                                WindowFunctionDefinition::WindowUDF(Arc::new(
202                                    CumeDist::new().into(),
203                                )),
204                                Vec::new(),
205                            ),
206                            WindowOp::FirstValue => (
207                                WindowFunctionDefinition::WindowUDF(first_value_udwf()),
208                                vec![unescaped_col(field)],
209                            ),
210                            WindowOp::LastValue => (
211                                WindowFunctionDefinition::WindowUDF(last_value_udwf()),
212                                vec![unescaped_col(field)],
213                            ),
214                            _ => {
215                                return Err(VegaFusionError::compilation(format!(
216                                    "Unsupported window function: {op:?}"
217                                )))
218                            }
219                        };
220                        (window_fn, args)
221                    }
222                };
223
224                let window_expr = Expr::WindowFunction(Box::new(expr::WindowFunction {
225                    fun: window_fn,
226                    params: WindowFunctionParams {
227                        args,
228                        partition_by: partition_by.clone(),
229                        order_by: order_by.clone(),
230                        window_frame: window_frame.clone(),
231                        null_treatment: None,
232                    },
233                }));
234
235                if let Some(alias) = self.aliases.get(i) {
236                    Ok(window_expr.alias(alias))
237                } else {
238                    Ok(window_expr)
239                }
240            })
241            .collect::<Result<Vec<_>>>()?;
242
243        // Add window expressions to original selections
244        selections.extend(window_exprs);
245
246        let dataframe = dataframe.select(selections)?;
247
248        Ok((dataframe, Vec::new()))
249    }
250}