term_guard/analyzers/advanced/
histogram.rs1use arrow::array::Array;
4use async_trait::async_trait;
5use datafusion::prelude::*;
6use serde::{Deserialize, Serialize};
7use tracing::instrument;
8
9use crate::analyzers::{
10 types::HistogramBucket, Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState,
11 MetricDistribution, MetricValue,
12};
13use crate::core::current_validation_context;
14
/// Analyzer that builds an equal-width histogram over a single column.
///
/// Instances are created through [`HistogramAnalyzer::new`], which clamps the
/// requested bucket count to the range `1..=1000`.
#[derive(Clone, Debug)]
pub struct HistogramAnalyzer {
    /// Name of the column whose values are bucketed.
    column: String,
    /// Number of buckets the value range is divided into.
    num_buckets: usize,
}
52
53impl HistogramAnalyzer {
54 pub fn new(column: impl Into<String>, num_buckets: usize) -> Self {
61 Self {
62 column: column.into(),
63 num_buckets: num_buckets.clamp(1, 1000),
64 }
65 }
66
67 pub fn column(&self) -> &str {
69 &self.column
70 }
71
72 pub fn num_buckets(&self) -> usize {
74 self.num_buckets
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct HistogramState {
81 pub buckets: Vec<HistogramBucket>,
83 pub min_value: f64,
85 pub max_value: f64,
87 pub total_count: u64,
89 pub sum: f64,
91 pub sum_squared: f64,
93}
94
95impl HistogramState {
96 pub fn mean(&self) -> Option<f64> {
98 if self.total_count > 0 {
99 Some(self.sum / self.total_count as f64)
100 } else {
101 None
102 }
103 }
104
105 pub fn std_dev(&self) -> Option<f64> {
107 if self.total_count > 1 {
108 let mean = self.mean()?;
109 let variance = (self.sum_squared / self.total_count as f64) - (mean * mean);
110 Some(variance.sqrt())
111 } else {
112 None
113 }
114 }
115}
116
117impl AnalyzerState for HistogramState {
118 fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
119 if states.is_empty() {
120 return Err(AnalyzerError::state_merge("No states to merge"));
121 }
122
123 let first = &states[0];
126 let mut merged_buckets = first.buckets.clone();
127
128 for state in &states[1..] {
130 if state.buckets.len() == merged_buckets.len() {
131 for (i, bucket) in state.buckets.iter().enumerate() {
132 merged_buckets[i] = HistogramBucket::new(
133 merged_buckets[i].lower_bound,
134 merged_buckets[i].upper_bound,
135 merged_buckets[i].count + bucket.count,
136 );
137 }
138 }
139 }
140
141 let min_value = states
142 .iter()
143 .map(|s| s.min_value)
144 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
145 .unwrap_or(0.0);
146
147 let max_value = states
148 .iter()
149 .map(|s| s.max_value)
150 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
151 .unwrap_or(0.0);
152
153 let total_count = states.iter().map(|s| s.total_count).sum();
154 let sum = states.iter().map(|s| s.sum).sum();
155 let sum_squared = states.iter().map(|s| s.sum_squared).sum();
156
157 Ok(HistogramState {
158 buckets: merged_buckets,
159 min_value,
160 max_value,
161 total_count,
162 sum,
163 sum_squared,
164 })
165 }
166
167 fn is_empty(&self) -> bool {
168 self.total_count == 0
169 }
170}
171
172#[async_trait]
173impl Analyzer for HistogramAnalyzer {
174 type State = HistogramState;
175 type Metric = MetricValue;
176
177 #[instrument(skip(ctx), fields(analyzer = "histogram", column = %self.column, buckets = %self.num_buckets))]
178 async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
179 let validation_ctx = current_validation_context();
181 let table_name = validation_ctx.table_name();
182
183 let stats_sql = format!(
185 "SELECT
186 MIN({0}) as min_val,
187 MAX({0}) as max_val,
188 COUNT({0}) as count,
189 SUM({0}) as sum,
190 SUM({0} * {0}) as sum_squared
191 FROM {table_name}
192 WHERE {0} IS NOT NULL",
193 self.column
194 );
195
196 let stats_df = ctx.sql(&stats_sql).await?;
197 let stats_batches = stats_df.collect().await?;
198
199 let (min_value, max_value, total_count, sum, sum_squared) = if let Some(batch) =
200 stats_batches.first()
201 {
202 if batch.num_rows() > 0 && !batch.column(0).is_null(0) {
203 let min_val = batch
204 .column(0)
205 .as_any()
206 .downcast_ref::<arrow::array::Float64Array>()
207 .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for min"))?
208 .value(0);
209
210 let max_val = batch
211 .column(1)
212 .as_any()
213 .downcast_ref::<arrow::array::Float64Array>()
214 .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for max"))?
215 .value(0);
216
217 let count = batch
218 .column(2)
219 .as_any()
220 .downcast_ref::<arrow::array::Int64Array>()
221 .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for count"))?
222 .value(0) as u64;
223
224 let sum_val = batch
225 .column(3)
226 .as_any()
227 .downcast_ref::<arrow::array::Float64Array>()
228 .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for sum"))?
229 .value(0);
230
231 let sum_sq = batch
232 .column(4)
233 .as_any()
234 .downcast_ref::<arrow::array::Float64Array>()
235 .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for sum_squared"))?
236 .value(0);
237
238 (min_val, max_val, count, sum_val, sum_sq)
239 } else {
240 return Ok(HistogramState {
242 buckets: vec![],
243 min_value: 0.0,
244 max_value: 0.0,
245 total_count: 0,
246 sum: 0.0,
247 sum_squared: 0.0,
248 });
249 }
250 } else {
251 return Err(AnalyzerError::NoData);
252 };
253
254 let range = max_value - min_value;
256 let bucket_width = if range > 0.0 && self.num_buckets > 1 {
257 range / self.num_buckets as f64
258 } else {
259 1.0
260 };
261
262 let mut case_clauses = Vec::new();
264 for i in 0..self.num_buckets {
265 let lower = min_value + (i as f64 * bucket_width);
266 let upper = if i == self.num_buckets - 1 {
267 max_value + bucket_width * 0.001
268 } else {
269 min_value + ((i + 1) as f64 * bucket_width)
270 };
271 case_clauses.push(format!(
272 "WHEN {0} >= {1} AND {0} < {2} THEN {3}",
273 self.column,
274 lower,
275 upper,
276 i + 1
277 ));
278 }
279
280 let histogram_sql = format!(
281 "SELECT
282 CASE
283 {}
284 ELSE {}
285 END as bucket_num,
286 COUNT(*) as count
287 FROM {table_name}
288 WHERE {} IS NOT NULL
289 GROUP BY bucket_num
290 ORDER BY bucket_num",
291 case_clauses.join(" "),
292 self.num_buckets,
293 self.column
294 );
295
296 let hist_df = ctx.sql(&histogram_sql).await?;
297 let hist_batches = hist_df.collect().await?;
298
299 let mut buckets = vec![HistogramBucket::new(0.0, 0.0, 0); self.num_buckets];
301
302 for (i, bucket) in buckets.iter_mut().enumerate() {
304 let lower = min_value + (i as f64 * bucket_width);
305 let upper = if i == self.num_buckets - 1 {
306 max_value + bucket_width * 0.001
307 } else {
308 min_value + ((i + 1) as f64 * bucket_width)
309 };
310 *bucket = HistogramBucket::new(lower, upper, 0);
311 }
312
313 for batch in &hist_batches {
315 let bucket_array = batch
316 .column(0)
317 .as_any()
318 .downcast_ref::<arrow::array::Int64Array>()
319 .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for bucket_num"))?;
320
321 let count_array = batch
322 .column(1)
323 .as_any()
324 .downcast_ref::<arrow::array::Int64Array>()
325 .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for count"))?;
326
327 for i in 0..batch.num_rows() {
328 let bucket_idx = (bucket_array.value(i) - 1) as usize;
329 let count = count_array.value(i) as u64;
330
331 if bucket_idx < buckets.len() {
332 buckets[bucket_idx] = HistogramBucket::new(
333 buckets[bucket_idx].lower_bound,
334 buckets[bucket_idx].upper_bound,
335 count,
336 );
337 }
338 }
339 }
340
341 Ok(HistogramState {
342 buckets,
343 min_value,
344 max_value,
345 total_count,
346 sum,
347 sum_squared,
348 })
349 }
350
351 fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
352 let distribution = MetricDistribution::from_buckets(state.buckets.clone()).with_stats(
353 state.min_value,
354 state.max_value,
355 state.mean().unwrap_or(0.0),
356 state.std_dev().unwrap_or(0.0),
357 );
358
359 Ok(MetricValue::Histogram(distribution))
360 }
361
362 fn name(&self) -> &str {
363 "histogram"
364 }
365
366 fn description(&self) -> &str {
367 "Computes value distribution histogram"
368 }
369
370 fn columns(&self) -> Vec<&str> {
371 vec![&self.column]
372 }
373}