vegafusion_core/spec/transform/
joinaggregate.rs

1use crate::expression::column_usage::{ColumnUsage, DatasetsColumnUsage, VlSelectionFields};
2use crate::spec::transform::aggregate::AggregateOpSpec;
3use crate::spec::transform::{TransformColumns, TransformSpecTrait};
4use crate::spec::values::Field;
5use crate::task_graph::graph::ScopedVariable;
6use crate::task_graph::scope::TaskScope;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use vegafusion_common::escape::unescape_field;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
13pub struct JoinAggregateTransformSpec {
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub groupby: Option<Vec<Field>>,
16
17    pub fields: Vec<Option<Field>>,
18    pub ops: Vec<AggregateOpSpec>,
19
20    #[serde(rename = "as", skip_serializing_if = "Option::is_none")]
21    pub as_: Option<Vec<Option<String>>>,
22
23    #[serde(flatten)]
24    pub extra: HashMap<String, Value>,
25}
26
27impl TransformSpecTrait for JoinAggregateTransformSpec {
28    fn supported(&self) -> bool {
29        // Check for supported aggregation op
30        use AggregateOpSpec::*;
31        for op in &self.ops {
32            if !matches!(
33                op,
34                Count
35                    | Valid
36                    | Missing
37                    | Distinct
38                    | Sum
39                    | Mean
40                    | Average
41                    | Min
42                    | Max
43                    | Variance
44                    | Variancep
45                    | Stdev
46                    | Stdevp
47                    | Median
48                    | Q1
49                    | Q3
50            ) {
51                // Unsupported aggregation op
52                return false;
53            }
54        }
55
56        true
57    }
58
59    fn transform_columns(
60        &self,
61        datum_var: &Option<ScopedVariable>,
62        _usage_scope: &[u32],
63        _task_scope: &TaskScope,
64        _vl_selection_fields: &VlSelectionFields,
65    ) -> TransformColumns {
66        if let Some(datum_var) = datum_var {
67            // Compute produced columns
68            // Only handle the case where "as" contains a list of strings with length matching ops
69            let ops = self.ops.clone();
70            let as_: Vec<_> = self
71                .as_
72                .clone()
73                .unwrap_or_default()
74                .iter()
75                .cloned()
76                .collect::<Option<Vec<_>>>()
77                .unwrap_or_default();
78            let produced = if ops.len() == as_.len() {
79                ColumnUsage::from(as_.as_slice())
80            } else {
81                ColumnUsage::Unknown
82            };
83
84            // Compute used columns (both groupby and fields)
85            let mut usage_cols: Vec<_> = self
86                .groupby
87                .clone()
88                .unwrap_or_default()
89                .iter()
90                .map(|field| unescape_field(&field.field()))
91                .collect();
92            for field in self.fields.iter().flatten() {
93                usage_cols.push(unescape_field(&field.field()))
94            }
95            let col_usage = ColumnUsage::from(usage_cols.as_slice());
96            let usage = DatasetsColumnUsage::empty().with_column_usage(datum_var, col_usage);
97            TransformColumns::PassThrough { usage, produced }
98        } else {
99            TransformColumns::Unknown
100        }
101    }
102
103    fn local_datetime_columns_produced(
104        &self,
105        input_local_datetime_columns: &[String],
106    ) -> Vec<String> {
107        // Keep input local datetime columns as joinaggregate passes through all input columns,
108        // and the new columns created by joinaggregate will never be local datetimes
109        Vec::from(input_local_datetime_columns)
110    }
111}