term_guard/analyzers/advanced/standard_deviation.rs

1//! Standard deviation analyzer for measuring data spread.
2
3use arrow::array::Array;
4use async_trait::async_trait;
5use datafusion::prelude::*;
6use serde::{Deserialize, Serialize};
7use tracing::instrument;
8
9use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
10
11use crate::core::current_validation_context;
/// Analyzer that computes standard deviation and variance.
///
/// This analyzer reports population and sample standard deviation, population
/// and sample variance, and the coefficient of variation of a single column.
/// It gathers the non-null count, mean, sum and sum of squares of the column
/// in one aggregate query; every statistic is then derived from those values.
///
/// Variance is derived from raw moments (count, sum and sum of squares). This
/// keeps the state small and trivially mergeable, but it is **not** a
/// numerically stable algorithm: when the magnitude of the mean is much
/// larger than the spread of the data, floating-point cancellation can make
/// the reported variance inaccurate.
///
/// # Example
///
/// ```rust,ignore
/// use term_guard::analyzers::advanced::StandardDeviationAnalyzer;
/// use term_guard::analyzers::{Analyzer, MetricValue};
/// use datafusion::prelude::*;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let ctx = SessionContext::new();
/// // Register your data table under the table name exposed by the current
/// // validation context; that is the table the analyzer queries.
///
/// let analyzer = StandardDeviationAnalyzer::new("temperature");
/// let state = analyzer.compute_state_from_data(&ctx).await?;
/// let metric = analyzer.compute_metric_from_state(&state)?;
///
/// if let MetricValue::Map(stats) = metric {
///     println!("Temperature statistics:");
///     println!("  Standard Deviation: {:?}", stats.get("std_dev"));
///     println!("  Variance: {:?}", stats.get("variance"));
///     println!("  Sample Std Dev: {:?}", stats.get("sample_std_dev"));
/// }
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct StandardDeviationAnalyzer {
    /// The column to analyze.
    column: String,
}
46
47impl StandardDeviationAnalyzer {
48    /// Creates a new standard deviation analyzer for the specified column.
49    pub fn new(column: impl Into<String>) -> Self {
50        Self {
51            column: column.into(),
52        }
53    }
54
55    /// Returns the column being analyzed.
56    pub fn column(&self) -> &str {
57        &self.column
58    }
59}
60
/// State for the standard deviation analyzer.
///
/// Holds the raw moments of the non-null values of the analyzed column.
/// `count`, `sum` and `sum_squared` are additive, so partial states can be
/// merged by summing those fields and recomputing `mean` from the totals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardDeviationState {
    /// Count of non-null values.
    pub count: u64,
    /// Sum of values.
    pub sum: f64,
    /// Sum of squared values (raw, not centered on the mean).
    pub sum_squared: f64,
    /// Mean value; expected to equal `sum / count`, and `0.0` when `count == 0`.
    pub mean: f64,
}
73
74impl StandardDeviationState {
75    /// Calculates the population standard deviation.
76    pub fn population_std_dev(&self) -> Option<f64> {
77        if self.count == 0 {
78            None
79        } else {
80            let variance = self.population_variance()?;
81            Some(variance.sqrt())
82        }
83    }
84
85    /// Calculates the sample standard deviation.
86    pub fn sample_std_dev(&self) -> Option<f64> {
87        if self.count <= 1 {
88            None
89        } else {
90            let variance = self.sample_variance()?;
91            Some(variance.sqrt())
92        }
93    }
94
95    /// Calculates the population variance.
96    pub fn population_variance(&self) -> Option<f64> {
97        if self.count == 0 {
98            None
99        } else {
100            // Var(X) = E[X²] - E[X]²
101            let mean_of_squares = self.sum_squared / self.count as f64;
102            let variance = mean_of_squares - (self.mean * self.mean);
103            // Ensure non-negative due to floating point precision
104            Some(variance.max(0.0))
105        }
106    }
107
108    /// Calculates the sample variance.
109    pub fn sample_variance(&self) -> Option<f64> {
110        if self.count <= 1 {
111            None
112        } else {
113            // Sample variance uses n-1 in denominator (Bessel's correction)
114            let sum_of_squared_deviations =
115                self.sum_squared - (self.sum * self.sum / self.count as f64);
116            let variance = sum_of_squared_deviations / (self.count - 1) as f64;
117            Some(variance.max(0.0))
118        }
119    }
120
121    /// Calculates the coefficient of variation (CV).
122    pub fn coefficient_of_variation(&self) -> Option<f64> {
123        let std_dev = self.population_std_dev()?;
124        if self.mean.abs() < f64::EPSILON {
125            None
126        } else {
127            Some(std_dev / self.mean.abs())
128        }
129    }
130}
131
132impl AnalyzerState for StandardDeviationState {
133    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
134        if states.is_empty() {
135            return Err(AnalyzerError::state_merge("No states to merge"));
136        }
137
138        let count: u64 = states.iter().map(|s| s.count).sum();
139        let sum: f64 = states.iter().map(|s| s.sum).sum();
140        let sum_squared: f64 = states.iter().map(|s| s.sum_squared).sum();
141
142        let mean = if count > 0 { sum / count as f64 } else { 0.0 };
143
144        Ok(StandardDeviationState {
145            count,
146            sum,
147            sum_squared,
148            mean,
149        })
150    }
151
152    fn is_empty(&self) -> bool {
153        self.count == 0
154    }
155}
156
157#[async_trait]
158impl Analyzer for StandardDeviationAnalyzer {
159    type State = StandardDeviationState;
160    type Metric = MetricValue;
161
162    #[instrument(skip(ctx), fields(analyzer = "standard_deviation", column = %self.column))]
163    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
164        // Build SQL query to compute statistics
165        // Get the table name from the validation context
166
167        let validation_ctx = current_validation_context();
168
169        let table_name = validation_ctx.table_name();
170
171        let sql = format!(
172            "SELECT 
173                COUNT({0}) as count,
174                AVG({0}) as mean,
175                SUM({0}) as sum,
176                SUM({0} * {0}) as sum_squared
177            FROM {table_name} 
178            WHERE {0} IS NOT NULL",
179            self.column
180        );
181
182        // Execute query
183        let df = ctx.sql(&sql).await?;
184        let batches = df.collect().await?;
185
186        // Extract statistics from result
187        let (count, mean, sum, sum_squared) = if let Some(batch) = batches.first() {
188            if batch.num_rows() > 0 && !batch.column(0).is_null(0) {
189                let count_array = batch
190                    .column(0)
191                    .as_any()
192                    .downcast_ref::<arrow::array::Int64Array>()
193                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for count"))?;
194                let count = count_array.value(0) as u64;
195
196                if count == 0 {
197                    (0, 0.0, 0.0, 0.0)
198                } else {
199                    let mean_array = batch
200                        .column(1)
201                        .as_any()
202                        .downcast_ref::<arrow::array::Float64Array>()
203                        .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for mean"))?;
204                    let mean = mean_array.value(0);
205
206                    let sum_array = batch
207                        .column(2)
208                        .as_any()
209                        .downcast_ref::<arrow::array::Float64Array>()
210                        .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for sum"))?;
211                    let sum = sum_array.value(0);
212
213                    let sum_squared_array = batch
214                        .column(3)
215                        .as_any()
216                        .downcast_ref::<arrow::array::Float64Array>()
217                        .ok_or_else(|| {
218                            AnalyzerError::invalid_data("Expected Float64 for sum_squared")
219                        })?;
220                    let sum_squared = sum_squared_array.value(0);
221
222                    (count, mean, sum, sum_squared)
223                }
224            } else {
225                (0, 0.0, 0.0, 0.0)
226            }
227        } else {
228            return Err(AnalyzerError::NoData);
229        };
230
231        Ok(StandardDeviationState {
232            count,
233            sum,
234            sum_squared,
235            mean,
236        })
237    }
238
239    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
240        use std::collections::HashMap;
241
242        let mut stats = HashMap::new();
243
244        // Add basic statistics
245        stats.insert("count".to_string(), MetricValue::Long(state.count as i64));
246        stats.insert("mean".to_string(), MetricValue::Double(state.mean));
247
248        // Add standard deviations and variances
249        if let Some(pop_std_dev) = state.population_std_dev() {
250            stats.insert("std_dev".to_string(), MetricValue::Double(pop_std_dev));
251        }
252
253        if let Some(sample_std_dev) = state.sample_std_dev() {
254            stats.insert(
255                "sample_std_dev".to_string(),
256                MetricValue::Double(sample_std_dev),
257            );
258        }
259
260        if let Some(pop_variance) = state.population_variance() {
261            stats.insert("variance".to_string(), MetricValue::Double(pop_variance));
262        }
263
264        if let Some(sample_variance) = state.sample_variance() {
265            stats.insert(
266                "sample_variance".to_string(),
267                MetricValue::Double(sample_variance),
268            );
269        }
270
271        if let Some(cv) = state.coefficient_of_variation() {
272            stats.insert(
273                "coefficient_of_variation".to_string(),
274                MetricValue::Double(cv),
275            );
276        }
277
278        Ok(MetricValue::Map(stats))
279    }
280
281    fn name(&self) -> &str {
282        "standard_deviation"
283    }
284
285    fn description(&self) -> &str {
286        "Computes standard deviation and variance metrics"
287    }
288
289    fn columns(&self) -> Vec<&str> {
290        vec![&self.column]
291    }
292}