term_guard/analyzers/basic/grouped_completeness.rs
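
//! Grouped completeness analysis.
//!
//! Extends [`CompletenessAnalyzer`] so that the completeness (non-null ratio)
//! of a column can be computed per group of one or more grouping columns via a
//! single `GROUP BY` query, with an optional overall metric and metadata about
//! how many groups were found versus retained.
//!
//! A minimal usage sketch (mirroring the test at the bottom of this file; the
//! table name is assumed to come from the current validation context, and `ctx`
//! is a DataFusion `SessionContext` with that table registered):
//!
//! ```ignore
//! let analyzer = CompletenessAnalyzer::new("sales");
//! let config = GroupingConfig::new(vec!["region".to_string(), "product".to_string()]);
//! let state = analyzer.compute_grouped_state_from_data(&ctx, &config).await?;
//! let metrics = analyzer.compute_grouped_metrics_from_state(&state)?;
//! ```
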
use async_trait::async_trait;
use datafusion::arrow::array::{StringArray, UInt64Array};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use tracing::{debug, instrument};

use crate::analyzers::{
    grouped::{
        GroupedAnalyzer, GroupedAnalyzerState, GroupedMetadata, GroupedMetrics, GroupingConfig,
    },
    AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue,
};
use crate::core::current_validation_context;

use super::completeness::{CompletenessAnalyzer, CompletenessState};
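
/// State for grouped completeness analysis: a [`CompletenessState`] per group
/// key (the values of the grouping columns, in order), an optional overall
/// state aggregated across the included groups, and [`GroupedMetadata`]
/// recording how many groups were found versus retained.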
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupedCompletenessState {
    pub groups: BTreeMap<Vec<String>, CompletenessState>,

    pub overall: Option<CompletenessState>,

    pub metadata: GroupedMetadata,
}

impl GroupedAnalyzerState for GroupedCompletenessState {}

impl AnalyzerState for GroupedCompletenessState {
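    // Merging folds partial states together group-by-group; `total_groups`
    // takes the maximum reported by any partial state, and truncation is
    // re-derived from the merged result.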
    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
        if states.is_empty() {
            return Err(AnalyzerError::InvalidConfiguration(
                "Cannot merge empty states".to_string(),
            ));
        }

        let mut merged_groups: BTreeMap<Vec<String>, Vec<CompletenessState>> = BTreeMap::new();
        let mut overall_states = Vec::new();
        let mut metadata = states[0].metadata.clone();

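        // Bucket every partial state's per-group states by group key, and
        // collect any overall states for a separate merge.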
        for state in states {
            for (key, group_state) in state.groups {
                merged_groups.entry(key).or_default().push(group_state);
            }

            if let Some(overall) = state.overall {
                overall_states.push(overall);
            }

            metadata.total_groups = metadata.total_groups.max(state.metadata.total_groups);
        }

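        // Merge the accumulated partial states for each group key.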
        let mut final_groups = BTreeMap::new();
        for (key, states) in merged_groups {
            let merged = CompletenessState::merge(states)?;
            final_groups.insert(key, merged);
        }

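        // Merge the overall states, if any partial state carried one.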
        let overall = if !overall_states.is_empty() {
            Some(CompletenessState::merge(overall_states)?)
        } else {
            None
        };

        metadata.included_groups = final_groups.len();
        metadata.truncated = metadata.total_groups > metadata.included_groups;

        Ok(GroupedCompletenessState {
            groups: final_groups,
            overall,
            metadata,
        })
    }

    fn is_empty(&self) -> bool {
        self.groups.is_empty()
    }
}

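// Grouped computation for `CompletenessAnalyzer`: per-group counts are gathered
// with a single GROUP BY query, ordered so that the most complete groups come
// first, and limited to one row more than `max_groups` so truncation can be
// detected.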
#[async_trait]
impl GroupedAnalyzer for CompletenessAnalyzer {
    type GroupedState = GroupedCompletenessState;

    #[instrument(skip(ctx), fields(
        analyzer = "grouped_completeness",
        column = %self.column()
    ))]
    async fn compute_grouped_state_from_data(
        &self,
        ctx: &SessionContext,
        config: &GroupingConfig,
    ) -> AnalyzerResult<Self::GroupedState> {
        let table_name = current_validation_context().table_name().to_string();

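        // Build the optional SQL fragments. When no grouping columns are
        // configured, each fragment collapses to an empty string and the query
        // degenerates to a plain column-level completeness count.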
        let group_columns = if config.columns.is_empty() {
            String::new()
        } else {
            format!("{}, ", config.columns.join(", "))
        };

        let column = self.column();
        let group_by = if config.columns.is_empty() {
            String::new()
        } else {
            format!("GROUP BY {}", config.columns.join(", "))
        };
        let order_by = if config.columns.is_empty() {
            String::new()
        } else {
            format!("ORDER BY COUNT({column}) * 1.0 / COUNT(*) DESC")
        };
        let limit = if let Some(max) = config.max_groups {
            format!("LIMIT {}", max + 1)
        } else {
            String::new()
        };

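        // `group_columns` already carries a trailing ", " when present, so the
        // SELECT list below stays well-formed with or without grouping columns.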
        let sql = format!(
            "SELECT {group_columns}
                COUNT(*) as total_count,
                COUNT({column}) as non_null_count
             FROM {table_name}
             {group_by}
             {order_by}
             {limit}"
        );

        debug!("Executing grouped completeness query: {}", sql);

        let df = ctx
            .sql(&sql)
            .await
            .map_err(|e| AnalyzerError::Custom(format!("Failed to execute grouped query: {e}")))?;

        let batches = df
            .collect()
            .await
            .map_err(|e| AnalyzerError::Custom(format!("Failed to collect results: {e}")))?;

        let mut groups = BTreeMap::new();
        let mut total_groups = 0;
        let mut overall_total = 0u64;
        let mut overall_non_null = 0u64;

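        // Convert each result row into a per-group CompletenessState, stopping
        // once `max_groups` groups have been captured (the query may have
        // returned one extra row).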
        for batch in &batches {
            total_groups += batch.num_rows();

            if let Some(max_groups) = config.max_groups {
                if groups.len() >= max_groups {
                    break;
                }
            }

            let group_values = extract_group_values(batch, &config.columns)?;
            let totals = extract_counts(batch, "total_count")?;
            let non_nulls = extract_counts(batch, "non_null_count")?;

            for i in 0..batch.num_rows() {
                if let Some(max_groups) = config.max_groups {
                    if groups.len() >= max_groups {
                        break;
                    }
                }

                let key = group_values
                    .get(i)
                    .cloned()
                    .unwrap_or_else(|| vec!["NULL".to_string(); config.columns.len()]);

                let state = CompletenessState {
                    total_count: totals[i],
                    non_null_count: non_nulls[i],
                };

                groups.insert(key, state);

                if config.include_overall {
                    overall_total += totals[i];
                    overall_non_null += non_nulls[i];
                }
            }
        }

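        // The overall state aggregates counts over the groups that were
        // actually included above.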
        let overall = if config.include_overall {
            Some(CompletenessState {
                total_count: overall_total,
                non_null_count: overall_non_null,
            })
        } else {
            None
        };

        let metadata = GroupedMetadata::new(config.columns.clone(), total_groups, groups.len());

        Ok(GroupedCompletenessState {
            groups,
            overall,
            metadata,
        })
    }

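    // Metrics are the completeness ratio per group, plus the optional overall
    // ratio, wrapped together with the grouping metadata.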
    fn compute_grouped_metrics_from_state(
        &self,
        state: &Self::GroupedState,
    ) -> AnalyzerResult<GroupedMetrics> {
        let mut metric_groups = BTreeMap::new();

        for (key, group_state) in &state.groups {
            let completeness = group_state.completeness();
            metric_groups.insert(key.clone(), MetricValue::Double(completeness));
        }

        let overall_metric = state
            .overall
            .as_ref()
            .map(|s| MetricValue::Double(s.completeness()));

        Ok(GroupedMetrics::new(
            metric_groups,
            overall_metric,
            state.metadata.clone(),
        ))
    }
}

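/// Extracts the group-key values for every row of a result batch. String
/// columns are read directly; other column types currently fall back to the
/// array's debug representation.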
fn extract_group_values(
    batch: &RecordBatch,
    group_columns: &[String],
) -> AnalyzerResult<Vec<Vec<String>>> {
    let mut result = vec![vec![]; batch.num_rows()];

    for col_name in group_columns {
        let col_idx = batch
            .schema()
            .index_of(col_name)
            .map_err(|_| AnalyzerError::Custom(format!("Column {col_name} not found")))?;

        let array = batch.column(col_idx);

        if let Some(string_array) = array.as_any().downcast_ref::<StringArray>() {
            for (i, row) in result.iter_mut().enumerate().take(batch.num_rows()) {
                let value = string_array.value(i).to_string();
                row.push(value);
            }
        } else {
            for row in result.iter_mut().take(batch.num_rows()) {
                row.push(format!("{array:?}"));
            }
        }
    }

    Ok(result)
}

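/// Reads a count column from a result batch as `u64`, accepting either
/// `UInt64Array` or `Int64Array` (DataFusion's COUNT returns Int64).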
fn extract_counts(batch: &RecordBatch, column_name: &str) -> AnalyzerResult<Vec<u64>> {
    let col_idx = batch
        .schema()
        .index_of(column_name)
        .map_err(|_| AnalyzerError::Custom(format!("Column {column_name} not found")))?;

    let array = batch.column(col_idx);

    if let Some(uint_array) = array.as_any().downcast_ref::<UInt64Array>() {
        Ok((0..batch.num_rows()).map(|i| uint_array.value(i)).collect())
    } else if let Some(int_array) = array
        .as_any()
        .downcast_ref::<datafusion::arrow::array::Int64Array>()
    {
        Ok((0..batch.num_rows())
            .map(|i| int_array.value(i) as u64)
            .collect())
    } else {
        Err(AnalyzerError::Custom(format!(
            "Expected Int64Array or UInt64Array for {column_name}, got {:?}",
            array.data_type()
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

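    // Six rows spread over four (region, product) groups, with one NULL sale
    // in the (EU, A) group.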
    #[tokio::test]
    async fn test_grouped_completeness() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("product", DataType::Utf8, false),
            Field::new("sales", DataType::Int32, true),
        ]));

        let regions = StringArray::from(vec!["US", "US", "EU", "EU", "US", "EU"]);
        let products = StringArray::from(vec!["A", "B", "A", "B", "A", "A"]);
        let sales = Int32Array::from(vec![
            Some(100),
            Some(200),
            None,
            Some(150),
            Some(250),
            Some(300),
        ]);

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(regions), Arc::new(products), Arc::new(sales)],
        )
        .unwrap();

        let ctx = SessionContext::new();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(table)).unwrap();

        let analyzer = CompletenessAnalyzer::new("sales");
        let config = GroupingConfig::new(vec!["region".to_string(), "product".to_string()]);

        let state = analyzer
            .compute_grouped_state_from_data(&ctx, &config)
            .await
            .unwrap();

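        // Expect four groups: (US, A), (US, B), and (EU, B) fully complete,
        // (EU, A) half complete.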
        assert_eq!(state.groups.len(), 4);

        let us_a = state
            .groups
            .get(&vec!["US".to_string(), "A".to_string()])
            .unwrap();
        assert_eq!(us_a.completeness(), 1.0);

        let eu_a = state
            .groups
            .get(&vec!["EU".to_string(), "A".to_string()])
            .unwrap();
        assert_eq!(eu_a.completeness(), 0.5);
    }
}