term_guard/analyzers/basic/
mean.rs

1//! Mean analyzer for computing average values.
2
3use async_trait::async_trait;
4use datafusion::prelude::*;
5use serde::{Deserialize, Serialize};
6use tracing::instrument;
7
8use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
9
10use crate::core::current_validation_context;
/// Computes the arithmetic mean of a single numeric column.
///
/// Rather than averaging rows directly, the analyzer records a running sum
/// and a non-null count (see [`MeanState`]). Partial results can therefore be
/// merged before the final division takes place.
///
/// # Example
///
/// ```rust,ignore
/// use term_guard::analyzers::basic::MeanAnalyzer;
/// use term_guard::analyzers::{Analyzer, MetricValue};
/// use datafusion::prelude::*;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let ctx = SessionContext::new();
/// // Register the table to analyze on `ctx` here.
///
/// let analyzer = MeanAnalyzer::new("age");
/// let state = analyzer.compute_state_from_data(&ctx).await?;
/// let metric = analyzer.compute_metric_from_state(&state)?;
///
/// if let MetricValue::Double(mean) = metric {
///     println!("Average age: {:.2}", mean);
/// }
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct MeanAnalyzer {
    /// Name of the column whose values are averaged.
    column: String,
}
41
42impl MeanAnalyzer {
43    /// Creates a new mean analyzer for the specified column.
44    pub fn new(column: impl Into<String>) -> Self {
45        Self {
46            column: column.into(),
47        }
48    }
49
50    /// Returns the column being analyzed.
51    pub fn column(&self) -> &str {
52        &self.column
53    }
54}
55
56/// State for the mean analyzer supporting incremental computation.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct MeanState {
59    /// Sum of all values.
60    pub sum: f64,
61    /// Count of non-null values.
62    pub count: u64,
63}
64
65impl MeanState {
66    /// Calculates the mean value.
67    pub fn mean(&self) -> Option<f64> {
68        if self.count == 0 {
69            None
70        } else {
71            Some(self.sum / self.count as f64)
72        }
73    }
74}
75
76impl AnalyzerState for MeanState {
77    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
78        let sum = states.iter().map(|s| s.sum).sum();
79        let count = states.iter().map(|s| s.count).sum();
80
81        Ok(MeanState { sum, count })
82    }
83
84    fn is_empty(&self) -> bool {
85        self.count == 0
86    }
87}
88
89#[async_trait]
90impl Analyzer for MeanAnalyzer {
91    type State = MeanState;
92    type Metric = MetricValue;
93
94    #[instrument(skip(ctx), fields(analyzer = "mean", column = %self.column))]
95    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
96        // Build SQL query to compute sum and count for incremental mean calculation
97        // Get the table name from the validation context
98
99        let validation_ctx = current_validation_context();
100
101        let table_name = validation_ctx.table_name();
102
103        let sql = format!(
104            "SELECT SUM({0}) as sum, COUNT({0}) as count FROM {table_name}",
105            self.column
106        );
107
108        // Execute query
109        let df = ctx.sql(&sql).await?;
110        let batches = df.collect().await?;
111
112        // Extract sum and count from result
113        let (sum, count) = if let Some(batch) = batches.first() {
114            if batch.num_rows() > 0 {
115                // Sum can be Float64 or null
116                let sum = if batch.column(0).is_null(0) {
117                    0.0
118                } else {
119                    let sum_array = batch
120                        .column(0)
121                        .as_any()
122                        .downcast_ref::<arrow::array::Float64Array>()
123                        .ok_or_else(|| {
124                            AnalyzerError::invalid_data("Expected Float64 array for sum")
125                        })?;
126                    sum_array.value(0)
127                };
128
129                let count_array = batch
130                    .column(1)
131                    .as_any()
132                    .downcast_ref::<arrow::array::Int64Array>()
133                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 array for count"))?;
134                let count = count_array.value(0) as u64;
135
136                (sum, count)
137            } else {
138                (0.0, 0)
139            }
140        } else {
141            (0.0, 0)
142        };
143
144        Ok(MeanState { sum, count })
145    }
146
147    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
148        match state.mean() {
149            Some(mean) => Ok(MetricValue::Double(mean)),
150            None => Err(AnalyzerError::NoData),
151        }
152    }
153
154    fn name(&self) -> &str {
155        "mean"
156    }
157
158    fn description(&self) -> &str {
159        "Computes the average value of a numeric column"
160    }
161
162    fn metric_key(&self) -> String {
163        format!("{}.{}", self.name(), self.column)
164    }
165
166    fn columns(&self) -> Vec<&str> {
167        vec![&self.column]
168    }
169}