1use async_trait::async_trait;
7use datafusion::prelude::*;
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, HashMap};
10use std::fmt;
11use tracing::{debug, instrument};
12
13use super::{Analyzer, AnalyzerResult, AnalyzerState, MetricValue};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct GroupingConfig {
18 pub columns: Vec<String>,
20
21 pub max_groups: Option<usize>,
23
24 pub include_overall: bool,
26
27 pub overflow_strategy: OverflowStrategy,
29}
30
31impl GroupingConfig {
32 pub fn new(columns: Vec<String>) -> Self {
34 Self {
35 columns,
36 max_groups: Some(10000),
37 include_overall: true,
38 overflow_strategy: OverflowStrategy::TopK,
39 }
40 }
41
42 pub fn with_max_groups(mut self, max: usize) -> Self {
44 self.max_groups = Some(max);
45 self
46 }
47
48 pub fn with_overall(mut self, include: bool) -> Self {
50 self.include_overall = include;
51 self
52 }
53
54 pub fn with_overflow_strategy(mut self, strategy: OverflowStrategy) -> Self {
56 self.overflow_strategy = strategy;
57 self
58 }
59
60 pub fn group_by_sql(&self) -> String {
62 self.columns.join(", ")
63 }
64
65 pub fn select_group_columns_sql(&self) -> String {
67 self.columns
68 .iter()
69 .map(|col| format!("{col} as group_{col}"))
70 .collect::<Vec<_>>()
71 .join(", ")
72 }
73}
74
75#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
77pub enum OverflowStrategy {
78 TopK,
80
81 BottomK,
83
84 Sample,
86
87 Fail,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct GroupedMetrics {
94 pub groups: BTreeMap<Vec<String>, MetricValue>,
97
98 pub overall: Option<MetricValue>,
100
101 pub metadata: GroupedMetadata,
103}
104
105impl GroupedMetrics {
106 pub fn new(
108 groups: BTreeMap<Vec<String>, MetricValue>,
109 overall: Option<MetricValue>,
110 metadata: GroupedMetadata,
111 ) -> Self {
112 Self {
113 groups,
114 overall,
115 metadata,
116 }
117 }
118
119 pub fn group_count(&self) -> usize {
121 self.groups.len()
122 }
123
124 pub fn get_group(&self, key: &[String]) -> Option<&MetricValue> {
126 self.groups.get(key)
127 }
128
129 pub fn is_truncated(&self) -> bool {
131 self.metadata.truncated
132 }
133
134 pub fn to_metric_value(&self) -> MetricValue {
136 let mut map = HashMap::new();
137
138 for (key, value) in &self.groups {
140 let key_str = key.join("_");
141 map.insert(key_str, value.clone());
142 }
143
144 if let Some(ref overall) = self.overall {
146 map.insert("__overall__".to_string(), overall.clone());
147 }
148
149 map.insert(
151 "__metadata__".to_string(),
152 MetricValue::String(serde_json::to_string(&self.metadata).unwrap_or_default()),
153 );
154
155 MetricValue::Map(map)
156 }
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct GroupedMetadata {
162 pub group_columns: Vec<String>,
164
165 pub total_groups: usize,
167
168 pub included_groups: usize,
170
171 pub truncated: bool,
173
174 pub overflow_strategy: Option<OverflowStrategy>,
176}
177
178impl GroupedMetadata {
179 pub fn new(group_columns: Vec<String>, total_groups: usize, included_groups: usize) -> Self {
181 Self {
182 group_columns,
183 total_groups,
184 included_groups,
185 truncated: total_groups > included_groups,
186 overflow_strategy: None,
187 }
188 }
189}
190
191#[async_trait]
193pub trait GroupedAnalyzer: Analyzer {
194 type GroupedState: GroupedAnalyzerState;
196
197 fn with_grouping(self, config: GroupingConfig) -> GroupedAnalyzerWrapper<Self>
199 where
200 Self: Sized + 'static,
201 {
202 GroupedAnalyzerWrapper::new(self, config)
203 }
204
205 async fn compute_grouped_state_from_data(
210 &self,
211 ctx: &SessionContext,
212 config: &GroupingConfig,
213 ) -> AnalyzerResult<Self::GroupedState>;
214
215 fn compute_grouped_metrics_from_state(
217 &self,
218 state: &Self::GroupedState,
219 ) -> AnalyzerResult<GroupedMetrics>;
220}
221
222pub trait GroupedAnalyzerState: AnalyzerState {}
226
227pub struct GroupedAnalyzerWrapper<A: GroupedAnalyzer> {
229 analyzer: A,
231
232 config: GroupingConfig,
234}
235
236impl<A: GroupedAnalyzer> GroupedAnalyzerWrapper<A> {
237 pub fn new(analyzer: A, config: GroupingConfig) -> Self {
239 Self { analyzer, config }
240 }
241}
242
243impl<A: GroupedAnalyzer> fmt::Debug for GroupedAnalyzerWrapper<A> {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 f.debug_struct("GroupedAnalyzerWrapper")
246 .field("analyzer", &self.analyzer.name())
247 .field("group_columns", &self.config.columns)
248 .finish()
249 }
250}
251
252#[async_trait]
253impl<A> Analyzer for GroupedAnalyzerWrapper<A>
254where
255 A: GroupedAnalyzer + Send + Sync + 'static,
256 A::GroupedState: AnalyzerState + 'static,
257{
258 type State = A::GroupedState;
259 type Metric = MetricValue;
260
261 #[instrument(skip(ctx), fields(
262 analyzer = %self.analyzer.name(),
263 group_columns = ?self.config.columns
264 ))]
265 async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
266 debug!(
267 "Computing grouped state for {} analyzer",
268 self.analyzer.name()
269 );
270 self.analyzer
271 .compute_grouped_state_from_data(ctx, &self.config)
272 .await
273 }
274
275 fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
276 let grouped_metrics = self.analyzer.compute_grouped_metrics_from_state(state)?;
277 Ok(grouped_metrics.to_metric_value())
278 }
279
280 fn name(&self) -> &str {
281 self.analyzer.name()
282 }
283
284 fn description(&self) -> &str {
285 self.analyzer.description()
286 }
287
288 fn metric_key(&self) -> String {
289 format!(
290 "{}_grouped_by_{}",
291 self.analyzer.metric_key(),
292 self.config.columns.join("_")
293 )
294 }
295
296 fn columns(&self) -> Vec<&str> {
297 let mut cols = self.analyzer.columns();
298 for col in &self.config.columns {
299 cols.push(col);
300 }
301 cols
302 }
303}
304
305pub mod sql_helpers {
307 use super::GroupingConfig;
308
309 pub fn build_group_by_clause(config: &GroupingConfig) -> String {
311 if config.columns.is_empty() {
312 String::new()
313 } else {
314 format!(" GROUP BY {}", config.group_by_sql())
315 }
316 }
317
318 pub fn build_group_select(config: &GroupingConfig, metric_sql: &str) -> String {
320 if config.columns.is_empty() {
321 metric_sql.to_string()
322 } else {
323 format!("{}, {metric_sql}", config.select_group_columns_sql())
324 }
325 }
326
327 pub fn build_limit_clause(config: &GroupingConfig) -> String {
329 if let Some(max) = config.max_groups {
330 format!(" LIMIT {max}")
331 } else {
332 String::new()
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned strings the API expects.
    fn owned(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    /// Builder methods and SQL fragments of `GroupingConfig`.
    #[test]
    fn test_grouping_config() {
        let config = GroupingConfig::new(owned(&["country", "city"]))
            .with_max_groups(1000)
            .with_overall(false);

        assert_eq!(config.group_by_sql(), "country, city");
        assert_eq!(
            config.select_group_columns_sql(),
            "country as group_country, city as group_city"
        );
        assert_eq!(config.max_groups, Some(1000));
        assert!(!config.include_overall);
    }

    /// Counting, truncation flag and lookup on `GroupedMetrics`.
    #[test]
    fn test_grouped_metrics() {
        let groups = BTreeMap::from([
            (owned(&["US", "NYC"]), MetricValue::Double(0.95)),
            (owned(&["US", "LA"]), MetricValue::Double(0.92)),
        ]);
        let metadata = GroupedMetadata::new(owned(&["country", "city"]), 2, 2);
        let overall = Some(MetricValue::Double(0.935));

        let grouped = GroupedMetrics::new(groups, overall, metadata);

        assert_eq!(grouped.group_count(), 2);
        assert!(!grouped.is_truncated());

        let lookup_key = owned(&["US", "NYC"]);
        let us_nyc = grouped.get_group(&lookup_key);
        assert_eq!(us_nyc, Some(&MetricValue::Double(0.95)));
    }
}