// term_guard/analyzers/basic/grouped_completeness.rs

//! Grouped completeness analyzer implementation.

3use async_trait::async_trait;
4use datafusion::arrow::array::{StringArray, UInt64Array};
5use datafusion::arrow::record_batch::RecordBatch;
6use datafusion::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::BTreeMap;
9use tracing::{debug, instrument};
10
11use crate::analyzers::{
12    grouped::{
13        GroupedAnalyzer, GroupedAnalyzerState, GroupedMetadata, GroupedMetrics, GroupingConfig,
14    },
15    AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue,
16};
17use crate::core::current_validation_context;
18
19use super::completeness::{CompletenessAnalyzer, CompletenessState};
20
21/// State for grouped completeness analysis.
/// State for grouped completeness analysis.
///
/// Holds one [`CompletenessState`] per group key — the ordered string values of
/// the grouping columns — plus an optional aggregate across all groups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupedCompletenessState {
    /// Map of group keys to completeness states. A `BTreeMap` keeps groups in
    /// deterministic (lexicographic) key order across runs and serializations.
    pub groups: BTreeMap<Vec<String>, CompletenessState>,

    /// Overall state across all groups (populated only when the grouping
    /// config requests `include_overall`).
    pub overall: Option<CompletenessState>,

    /// Metadata about the computation: grouping columns, total vs. included
    /// group counts, and the truncation flag.
    pub metadata: GroupedMetadata,
}
33
// Marker impl: flags this state as consumable by the grouped-analyzer machinery.
impl GroupedAnalyzerState for GroupedCompletenessState {}
35
36impl AnalyzerState for GroupedCompletenessState {
37    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
38        if states.is_empty() {
39            return Err(AnalyzerError::InvalidConfiguration(
40                "Cannot merge empty states".to_string(),
41            ));
42        }
43
44        let mut merged_groups: BTreeMap<Vec<String>, Vec<CompletenessState>> = BTreeMap::new();
45        let mut overall_states = Vec::new();
46        let mut metadata = states[0].metadata.clone();
47
48        // Collect states by group
49        for state in states {
50            for (key, group_state) in state.groups {
51                merged_groups.entry(key).or_default().push(group_state);
52            }
53
54            if let Some(overall) = state.overall {
55                overall_states.push(overall);
56            }
57
58            // Update metadata
59            metadata.total_groups = metadata.total_groups.max(state.metadata.total_groups);
60        }
61
62        // Merge states for each group
63        let mut final_groups = BTreeMap::new();
64        for (key, states) in merged_groups {
65            let merged = CompletenessState::merge(states)?;
66            final_groups.insert(key, merged);
67        }
68
69        // Merge overall states
70        let overall = if !overall_states.is_empty() {
71            Some(CompletenessState::merge(overall_states)?)
72        } else {
73            None
74        };
75
76        metadata.included_groups = final_groups.len();
77        metadata.truncated = metadata.total_groups > metadata.included_groups;
78
79        Ok(GroupedCompletenessState {
80            groups: final_groups,
81            overall,
82            metadata,
83        })
84    }
85
86    fn is_empty(&self) -> bool {
87        self.groups.is_empty()
88    }
89}
90
#[async_trait]
impl GroupedAnalyzer for CompletenessAnalyzer {
    type GroupedState = GroupedCompletenessState;

    /// Computes per-group completeness with a single grouped SQL aggregate:
    /// `COUNT(*)` (total rows) vs `COUNT(column)` (non-null rows) per group.
    #[instrument(skip(ctx), fields(
        analyzer = "grouped_completeness",
        column = %self.column()
    ))]
    async fn compute_grouped_state_from_data(
        &self,
        ctx: &SessionContext,
        config: &GroupingConfig,
    ) -> AnalyzerResult<Self::GroupedState> {
        // The table name comes from the ambient validation context, not config.
        let table_name = current_validation_context().table_name().to_string();

        // Build grouped SQL query. With no grouping columns, every clause below
        // is empty and the query collapses to a whole-table aggregate.
        let group_columns = if config.columns.is_empty() {
            String::new()
        } else {
            // Trailing ", " so the projection splices cleanly before COUNT(*).
            format!("{}, ", config.columns.join(", "))
        };

        let column = self.column();
        let group_by = if config.columns.is_empty() {
            String::new()
        } else {
            format!("GROUP BY {}", config.columns.join(", "))
        };
        let order_by = if config.columns.is_empty() {
            String::new()
        } else {
            // NOTE(review): DESC orders the *most complete* groups first, but
            // the stated intent elsewhere was "worst groups first if truncated",
            // which would require ASC — confirm which order truncation should
            // favor before relying on which groups survive the LIMIT.
            format!("ORDER BY COUNT({column}) * 1.0 / COUNT(*) DESC")
        };
        let limit = if let Some(max) = config.max_groups {
            format!("LIMIT {}", max + 1) // +1 to detect truncation
        } else {
            String::new()
        };

        let sql = format!(
            "SELECT {group_columns}
                COUNT(*) as total_count,
                COUNT({column}) as non_null_count
             FROM {table_name}
             {group_by}
             {order_by}
             {limit}"
        );

        debug!("Executing grouped completeness query: {}", sql);

        let df = ctx
            .sql(&sql)
            .await
            .map_err(|e| AnalyzerError::Custom(format!("Failed to execute grouped query: {e}")))?;

        let batches = df
            .collect()
            .await
            .map_err(|e| AnalyzerError::Custom(format!("Failed to collect results: {e}")))?;

        let mut groups = BTreeMap::new();
        // Counts rows returned by the (LIMIT-capped) query, so this is a lower
        // bound on the true group count rather than the exact total.
        let mut total_groups = 0;
        let mut overall_total = 0u64;
        let mut overall_non_null = 0u64;

        for batch in &batches {
            total_groups += batch.num_rows();

            // Stop reading further batches once the group cap has been hit.
            if let Some(max_groups) = config.max_groups {
                if groups.len() >= max_groups {
                    break;
                }
            }

            // Extract group keys and both count columns for this batch.
            let group_values = extract_group_values(batch, &config.columns)?;
            let totals = extract_counts(batch, "total_count")?;
            let non_nulls = extract_counts(batch, "non_null_count")?;

            for i in 0..batch.num_rows() {
                // Re-check the cap per row: LIMIT is max_groups + 1, so the
                // sentinel row used for truncation detection is dropped here.
                if let Some(max_groups) = config.max_groups {
                    if groups.len() >= max_groups {
                        break;
                    }
                }

                let key = group_values
                    .get(i)
                    .cloned()
                    .unwrap_or_else(|| vec!["NULL".to_string(); config.columns.len()]);

                let state = CompletenessState {
                    total_count: totals[i],
                    non_null_count: non_nulls[i],
                };

                groups.insert(key, state);

                // NOTE(review): the overall tally accumulates only *included*
                // groups — when truncation kicks in, rows from excluded groups
                // are missing from `overall`. Confirm this is intended.
                if config.include_overall {
                    overall_total += totals[i];
                    overall_non_null += non_nulls[i];
                }
            }
        }

        let overall = if config.include_overall {
            Some(CompletenessState {
                total_count: overall_total,
                non_null_count: overall_non_null,
            })
        } else {
            None
        };

        let metadata = GroupedMetadata::new(config.columns.clone(), total_groups, groups.len());

        Ok(GroupedCompletenessState {
            groups,
            overall,
            metadata,
        })
    }

    /// Converts each group's stored counts into a completeness-ratio metric.
    fn compute_grouped_metrics_from_state(
        &self,
        state: &Self::GroupedState,
    ) -> AnalyzerResult<GroupedMetrics> {
        let mut metric_groups = BTreeMap::new();

        // Convert each group's state to a metric
        for (key, group_state) in &state.groups {
            let completeness = group_state.completeness();
            metric_groups.insert(key.clone(), MetricValue::Double(completeness));
        }

        // Convert overall state to metric (None when overall wasn't requested)
        let overall_metric = state
            .overall
            .as_ref()
            .map(|s| MetricValue::Double(s.completeness()));

        Ok(GroupedMetrics::new(
            metric_groups,
            overall_metric,
            state.metadata.clone(),
        ))
    }
}
241
242/// Helper function to extract group values from a record batch.
243fn extract_group_values(
244    batch: &RecordBatch,
245    group_columns: &[String],
246) -> AnalyzerResult<Vec<Vec<String>>> {
247    let mut result = vec![vec![]; batch.num_rows()];
248
249    for col_name in group_columns {
250        let col_idx = batch
251            .schema()
252            .index_of(col_name)
253            .map_err(|_| AnalyzerError::Custom(format!("Column {col_name} not found")))?;
254
255        let array = batch.column(col_idx);
256
257        // Convert to string representation
258        if let Some(string_array) = array.as_any().downcast_ref::<StringArray>() {
259            for (i, row) in result.iter_mut().enumerate().take(batch.num_rows()) {
260                let value = string_array.value(i).to_string();
261                row.push(value);
262            }
263        } else {
264            // Try to convert other types to string
265            for row in result.iter_mut().take(batch.num_rows()) {
266                row.push(format!("{array:?}"));
267            }
268        }
269    }
270
271    Ok(result)
272}
273
274/// Helper function to extract count values from a record batch.
275fn extract_counts(batch: &RecordBatch, column_name: &str) -> AnalyzerResult<Vec<u64>> {
276    let col_idx = batch
277        .schema()
278        .index_of(column_name)
279        .map_err(|_| AnalyzerError::Custom(format!("Column {column_name} not found")))?;
280
281    let array = batch.column(col_idx);
282
283    // DataFusion's COUNT(*) can return either Int64 or UInt64, so handle both
284    if let Some(uint_array) = array.as_any().downcast_ref::<UInt64Array>() {
285        Ok((0..batch.num_rows()).map(|i| uint_array.value(i)).collect())
286    } else if let Some(int_array) = array
287        .as_any()
288        .downcast_ref::<datafusion::arrow::array::Int64Array>()
289    {
290        Ok((0..batch.num_rows())
291            .map(|i| int_array.value(i) as u64)
292            .collect())
293    } else {
294        Err(AnalyzerError::Custom(format!(
295            "Expected Int64Array or UInt64Array for {column_name}, got {:?}",
296            array.data_type()
297        )))
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    #[tokio::test]
    async fn test_grouped_completeness() {
        // Six rows spread over four (region, product) groups, with exactly one
        // null sales value, located in group EU-A.
        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("product", DataType::Utf8, false),
            Field::new("sales", DataType::Int32, true),
        ]));

        let region_col = StringArray::from(vec!["US", "US", "EU", "EU", "US", "EU"]);
        let product_col = StringArray::from(vec!["A", "B", "A", "B", "A", "A"]);
        let sales_col = Int32Array::from(vec![
            Some(100),
            Some(200),
            None, // the single missing value: group EU-A
            Some(150),
            Some(250),
            Some(300),
        ]);

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(region_col),
                Arc::new(product_col),
                Arc::new(sales_col),
            ],
        )
        .unwrap();

        // Register the batch as an in-memory table the analyzer can query.
        let ctx = SessionContext::new();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(table)).unwrap();

        // Run grouped completeness over (region, product).
        let analyzer = CompletenessAnalyzer::new("sales");
        let config = GroupingConfig::new(vec!["region".to_string(), "product".to_string()]);

        let state = analyzer
            .compute_grouped_state_from_data(&ctx, &config)
            .await
            .unwrap();

        // All four observed groups should be present: US-A, US-B, EU-A, EU-B.
        assert_eq!(state.groups.len(), 4);

        // US-A: both rows populated -> fully complete.
        let us_a_key = vec!["US".to_string(), "A".to_string()];
        assert_eq!(state.groups.get(&us_a_key).unwrap().completeness(), 1.0);

        // EU-A: one of two rows is null -> half complete.
        let eu_a_key = vec!["EU".to_string(), "A".to_string()];
        assert_eq!(state.groups.get(&eu_a_key).unwrap().completeness(), 0.5);
    }
}
366}