vegafusion_core/spec/transform/aggregate.rs

1use crate::expression::column_usage::{ColumnUsage, DatasetsColumnUsage, VlSelectionFields};
2use crate::spec::transform::{TransformColumns, TransformSpecTrait};
3use crate::spec::values::Field;
4use crate::task_graph::graph::ScopedVariable;
5use crate::task_graph::scope::TaskScope;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use vegafusion_common::escape::unescape_field;
10
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12pub struct AggregateTransformSpec {
13    pub groupby: Vec<Field>,
14
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub fields: Option<Vec<Option<Field>>>,
17
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub ops: Option<Vec<AggregateOpSpec>>,
20
21    #[serde(rename = "as", skip_serializing_if = "Option::is_none")]
22    pub as_: Option<Vec<Option<String>>>,
23
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub cross: Option<bool>,
26
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub drop: Option<bool>,
29
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub key: Option<Field>,
32
33    #[serde(flatten)]
34    pub extra: HashMap<String, Value>,
35}
36
37#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
38#[serde(rename_all = "lowercase")]
39pub enum AggregateOpSpec {
40    Count,
41    Valid,
42    Missing,
43    Distinct,
44    Sum,
45    Product,
46    Mean,
47    Average,
48    Variance,
49    Variancep,
50    Stdev,
51    Stdevp,
52    Stderr,
53    Median,
54    Q1,
55    Q3,
56    Ci0,
57    Ci1,
58    Min,
59    Max,
60    Argmin,
61    Argmax,
62    Values,
63}
64
65impl AggregateOpSpec {
66    pub fn name(&self) -> String {
67        serde_json::to_value(self)
68            .unwrap()
69            .as_str()
70            .unwrap()
71            .to_string()
72    }
73}
74
75impl TransformSpecTrait for AggregateTransformSpec {
76    fn supported(&self) -> bool {
77        // Check for supported aggregation op
78        use AggregateOpSpec::*;
79        let ops = self.ops.clone().unwrap_or_else(|| vec![Count]);
80        for op in &ops {
81            if !matches!(
82                op,
83                Count
84                    | Valid
85                    | Missing
86                    | Distinct
87                    | Sum
88                    | Mean
89                    | Average
90                    | Min
91                    | Max
92                    | Variance
93                    | Variancep
94                    | Stdev
95                    | Stdevp
96                    | Median
97                    | Q1
98                    | Q3
99            ) {
100                // Unsupported aggregation op
101                return false;
102            }
103        }
104
105        // Cross aggregation not supported
106        if let Some(true) = &self.cross {
107            return false;
108        }
109
110        // drop=false not support
111        if let Some(false) = &self.drop {
112            return false;
113        }
114        true
115    }
116
117    fn transform_columns(
118        &self,
119        datum_var: &Option<ScopedVariable>,
120        _usage_scope: &[u32],
121        _task_scope: &TaskScope,
122        _vl_selection_fields: &VlSelectionFields,
123    ) -> TransformColumns {
124        if let Some(datum_var) = datum_var {
125            // Compute produced columns
126            // Only handle the case where "as" contains a list of strings with length matching ops
127            let ops = self
128                .ops
129                .clone()
130                .unwrap_or_else(|| vec![AggregateOpSpec::Count]);
131            let as_: Vec<_> = self
132                .as_
133                .clone()
134                .unwrap_or_default()
135                .iter()
136                .cloned()
137                .collect::<Option<Vec<_>>>()
138                .unwrap_or_default();
139            let produced = if ops.len() == as_.len() {
140                ColumnUsage::from(as_.as_slice())
141            } else {
142                ColumnUsage::Unknown
143            };
144
145            // Compute used columns (both groupby and fields)
146            let mut usage_cols: Vec<_> = self
147                .groupby
148                .iter()
149                .map(|field| unescape_field(&field.field()))
150                .collect();
151            for field in self
152                .fields
153                .clone()
154                .unwrap_or_default()
155                .into_iter()
156                .flatten()
157            {
158                usage_cols.push(unescape_field(&field.field()))
159            }
160            let col_usage = ColumnUsage::from(usage_cols.as_slice());
161            let usage = DatasetsColumnUsage::empty().with_column_usage(datum_var, col_usage);
162            TransformColumns::Overwrite { usage, produced }
163        } else {
164            TransformColumns::Unknown
165        }
166    }
167
168    fn local_datetime_columns_produced(
169        &self,
170        input_local_datetime_columns: &[String],
171    ) -> Vec<String> {
172        // Keep input local datetime columns that are used as grouping fields
173        self.groupby
174            .iter()
175            .filter_map(|groupby_field| {
176                let groupby_field_name = groupby_field.field();
177                let unescaped = unescape_field(&groupby_field_name);
178                if input_local_datetime_columns.contains(&unescaped) {
179                    Some(unescaped)
180                } else {
181                    None
182                }
183            })
184            .collect::<Vec<_>>()
185    }
186}