term_guard/analyzers/advanced/
standard_deviation.rs1use arrow::array::Array;
4use async_trait::async_trait;
5use datafusion::prelude::*;
6use serde::{Deserialize, Serialize};
7use tracing::instrument;
8
9use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
10
11use crate::core::current_validation_context;
/// Analyzer that reports standard deviation and variance statistics for a
/// single column.
///
/// The analyzer itself only stores the name of the column to inspect.
#[derive(Debug, Clone)]
pub struct StandardDeviationAnalyzer {
    /// Name of the column whose values are analyzed.
    column: String,
}
46
47impl StandardDeviationAnalyzer {
48 pub fn new(column: impl Into<String>) -> Self {
50 Self {
51 column: column.into(),
52 }
53 }
54
55 pub fn column(&self) -> &str {
57 &self.column
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct StandardDeviationState {
64 pub count: u64,
66 pub sum: f64,
68 pub sum_squared: f64,
70 pub mean: f64,
72}
73
74impl StandardDeviationState {
75 pub fn population_std_dev(&self) -> Option<f64> {
77 if self.count == 0 {
78 None
79 } else {
80 let variance = self.population_variance()?;
81 Some(variance.sqrt())
82 }
83 }
84
85 pub fn sample_std_dev(&self) -> Option<f64> {
87 if self.count <= 1 {
88 None
89 } else {
90 let variance = self.sample_variance()?;
91 Some(variance.sqrt())
92 }
93 }
94
95 pub fn population_variance(&self) -> Option<f64> {
97 if self.count == 0 {
98 None
99 } else {
100 let mean_of_squares = self.sum_squared / self.count as f64;
102 let variance = mean_of_squares - (self.mean * self.mean);
103 Some(variance.max(0.0))
105 }
106 }
107
108 pub fn sample_variance(&self) -> Option<f64> {
110 if self.count <= 1 {
111 None
112 } else {
113 let sum_of_squared_deviations =
115 self.sum_squared - (self.sum * self.sum / self.count as f64);
116 let variance = sum_of_squared_deviations / (self.count - 1) as f64;
117 Some(variance.max(0.0))
118 }
119 }
120
121 pub fn coefficient_of_variation(&self) -> Option<f64> {
123 let std_dev = self.population_std_dev()?;
124 if self.mean.abs() < f64::EPSILON {
125 None
126 } else {
127 Some(std_dev / self.mean.abs())
128 }
129 }
130}
131
132impl AnalyzerState for StandardDeviationState {
133 fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
134 if states.is_empty() {
135 return Err(AnalyzerError::state_merge("No states to merge"));
136 }
137
138 let count: u64 = states.iter().map(|s| s.count).sum();
139 let sum: f64 = states.iter().map(|s| s.sum).sum();
140 let sum_squared: f64 = states.iter().map(|s| s.sum_squared).sum();
141
142 let mean = if count > 0 { sum / count as f64 } else { 0.0 };
143
144 Ok(StandardDeviationState {
145 count,
146 sum,
147 sum_squared,
148 mean,
149 })
150 }
151
152 fn is_empty(&self) -> bool {
153 self.count == 0
154 }
155}
156
157#[async_trait]
158impl Analyzer for StandardDeviationAnalyzer {
159 type State = StandardDeviationState;
160 type Metric = MetricValue;
161
162 #[instrument(skip(ctx), fields(analyzer = "standard_deviation", column = %self.column))]
163 async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
164 let validation_ctx = current_validation_context();
168
169 let table_name = validation_ctx.table_name();
170
171 let sql = format!(
172 "SELECT
173 COUNT({0}) as count,
174 AVG({0}) as mean,
175 SUM({0}) as sum,
176 SUM({0} * {0}) as sum_squared
177 FROM {table_name}
178 WHERE {0} IS NOT NULL",
179 self.column
180 );
181
182 let df = ctx.sql(&sql).await?;
184 let batches = df.collect().await?;
185
186 let (count, mean, sum, sum_squared) = if let Some(batch) = batches.first() {
188 if batch.num_rows() > 0 && !batch.column(0).is_null(0) {
189 let count_array = batch
190 .column(0)
191 .as_any()
192 .downcast_ref::<arrow::array::Int64Array>()
193 .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for count"))?;
194 let count = count_array.value(0) as u64;
195
196 if count == 0 {
197 (0, 0.0, 0.0, 0.0)
198 } else {
199 let mean_array = batch
200 .column(1)
201 .as_any()
202 .downcast_ref::<arrow::array::Float64Array>()
203 .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for mean"))?;
204 let mean = mean_array.value(0);
205
206 let sum_array = batch
207 .column(2)
208 .as_any()
209 .downcast_ref::<arrow::array::Float64Array>()
210 .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for sum"))?;
211 let sum = sum_array.value(0);
212
213 let sum_squared_array = batch
214 .column(3)
215 .as_any()
216 .downcast_ref::<arrow::array::Float64Array>()
217 .ok_or_else(|| {
218 AnalyzerError::invalid_data("Expected Float64 for sum_squared")
219 })?;
220 let sum_squared = sum_squared_array.value(0);
221
222 (count, mean, sum, sum_squared)
223 }
224 } else {
225 (0, 0.0, 0.0, 0.0)
226 }
227 } else {
228 return Err(AnalyzerError::NoData);
229 };
230
231 Ok(StandardDeviationState {
232 count,
233 sum,
234 sum_squared,
235 mean,
236 })
237 }
238
239 fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
240 use std::collections::HashMap;
241
242 let mut stats = HashMap::new();
243
244 stats.insert("count".to_string(), MetricValue::Long(state.count as i64));
246 stats.insert("mean".to_string(), MetricValue::Double(state.mean));
247
248 if let Some(pop_std_dev) = state.population_std_dev() {
250 stats.insert("std_dev".to_string(), MetricValue::Double(pop_std_dev));
251 }
252
253 if let Some(sample_std_dev) = state.sample_std_dev() {
254 stats.insert(
255 "sample_std_dev".to_string(),
256 MetricValue::Double(sample_std_dev),
257 );
258 }
259
260 if let Some(pop_variance) = state.population_variance() {
261 stats.insert("variance".to_string(), MetricValue::Double(pop_variance));
262 }
263
264 if let Some(sample_variance) = state.sample_variance() {
265 stats.insert(
266 "sample_variance".to_string(),
267 MetricValue::Double(sample_variance),
268 );
269 }
270
271 if let Some(cv) = state.coefficient_of_variation() {
272 stats.insert(
273 "coefficient_of_variation".to_string(),
274 MetricValue::Double(cv),
275 );
276 }
277
278 Ok(MetricValue::Map(stats))
279 }
280
281 fn name(&self) -> &str {
282 "standard_deviation"
283 }
284
285 fn description(&self) -> &str {
286 "Computes standard deviation and variance metrics"
287 }
288
289 fn columns(&self) -> Vec<&str> {
290 vec![&self.column]
291 }
292}