// vegafusion_core/spec/transform/aggregate.rs
use crate::expression::column_usage::{ColumnUsage, DatasetsColumnUsage, VlSelectionFields};
use crate::spec::transform::{TransformColumns, TransformSpecTrait};
use crate::spec::values::Field;
use crate::task_graph::graph::ScopedVariable;
use crate::task_graph::scope::TaskScope;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use vegafusion_common::escape::unescape_field;

11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12pub struct AggregateTransformSpec {
13 pub groupby: Vec<Field>,
14
15 #[serde(skip_serializing_if = "Option::is_none")]
16 pub fields: Option<Vec<Option<Field>>>,
17
18 #[serde(skip_serializing_if = "Option::is_none")]
19 pub ops: Option<Vec<AggregateOpSpec>>,
20
21 #[serde(rename = "as", skip_serializing_if = "Option::is_none")]
22 pub as_: Option<Vec<Option<String>>>,
23
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub cross: Option<bool>,
26
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub drop: Option<bool>,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub key: Option<Field>,
32
33 #[serde(flatten)]
34 pub extra: HashMap<String, Value>,
35}
36
/// An aggregation operation that may appear in an aggregate transform's `ops` list.
///
/// Variants (de)serialize as their lowercase names because of
/// `rename_all = "lowercase"`, e.g. `Variancep` <-> `"variancep"` and `Q1` <-> `"q1"`.
// NOTE(review): the names appear to mirror Vega's aggregate ops. Only a subset is
// evaluated by VegaFusion (see `supported` on `AggregateTransformSpec`).
// Keep the variant order stable. Reordering changes the implicit discriminants and
// therefore the values produced by the derived `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[serde(rename_all = "lowercase")]
pub enum AggregateOpSpec {
    Count,
    Valid,
    Missing,
    Distinct,
    Sum,
    Product,
    Mean,
    Average,
    Variance,
    Variancep,
    Stdev,
    Stdevp,
    Stderr,
    Median,
    Q1,
    Q3,
    Ci0,
    Ci1,
    Min,
    Max,
    Argmin,
    Argmax,
    Values,
}

65impl AggregateOpSpec {
66 pub fn name(&self) -> String {
67 serde_json::to_value(self)
68 .unwrap()
69 .as_str()
70 .unwrap()
71 .to_string()
72 }
73}
74
75impl TransformSpecTrait for AggregateTransformSpec {
76 fn supported(&self) -> bool {
77 use AggregateOpSpec::*;
79 let ops = self.ops.clone().unwrap_or_else(|| vec![Count]);
80 for op in &ops {
81 if !matches!(
82 op,
83 Count
84 | Valid
85 | Missing
86 | Distinct
87 | Sum
88 | Mean
89 | Average
90 | Min
91 | Max
92 | Variance
93 | Variancep
94 | Stdev
95 | Stdevp
96 | Median
97 | Q1
98 | Q3
99 ) {
100 return false;
102 }
103 }
104
105 if let Some(true) = &self.cross {
107 return false;
108 }
109
110 if let Some(false) = &self.drop {
112 return false;
113 }
114 true
115 }
116
117 fn transform_columns(
118 &self,
119 datum_var: &Option<ScopedVariable>,
120 _usage_scope: &[u32],
121 _task_scope: &TaskScope,
122 _vl_selection_fields: &VlSelectionFields,
123 ) -> TransformColumns {
124 if let Some(datum_var) = datum_var {
125 let ops = self
128 .ops
129 .clone()
130 .unwrap_or_else(|| vec![AggregateOpSpec::Count]);
131 let as_: Vec<_> = self
132 .as_
133 .clone()
134 .unwrap_or_default()
135 .iter()
136 .cloned()
137 .collect::<Option<Vec<_>>>()
138 .unwrap_or_default();
139 let produced = if ops.len() == as_.len() {
140 ColumnUsage::from(as_.as_slice())
141 } else {
142 ColumnUsage::Unknown
143 };
144
145 let mut usage_cols: Vec<_> = self
147 .groupby
148 .iter()
149 .map(|field| unescape_field(&field.field()))
150 .collect();
151 for field in self
152 .fields
153 .clone()
154 .unwrap_or_default()
155 .into_iter()
156 .flatten()
157 {
158 usage_cols.push(unescape_field(&field.field()))
159 }
160 let col_usage = ColumnUsage::from(usage_cols.as_slice());
161 let usage = DatasetsColumnUsage::empty().with_column_usage(datum_var, col_usage);
162 TransformColumns::Overwrite { usage, produced }
163 } else {
164 TransformColumns::Unknown
165 }
166 }
167
168 fn local_datetime_columns_produced(
169 &self,
170 input_local_datetime_columns: &[String],
171 ) -> Vec<String> {
172 self.groupby
174 .iter()
175 .filter_map(|groupby_field| {
176 let groupby_field_name = groupby_field.field();
177 let unescaped = unescape_field(&groupby_field_name);
178 if input_local_datetime_columns.contains(&unescaped) {
179 Some(unescaped)
180 } else {
181 None
182 }
183 })
184 .collect::<Vec<_>>()
185 }
186}