term_guard/analyzers/advanced/
histogram.rs

//! Histogram analyzer for computing value distributions.

use arrow::array::Array;
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::instrument;

use crate::analyzers::{
    types::HistogramBucket, Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState,
    MetricDistribution, MetricValue,
};
use crate::core::current_validation_context;

/// Analyzer that computes histogram distributions for numeric columns.
///
/// This analyzer builds a histogram with a configurable number of buckets,
/// providing insight into how values are distributed. Because it uses a fixed
/// number of buckets, it stays memory-efficient even for high-cardinality columns.
///
/// # Example
///
/// ```rust,ignore
/// use term_guard::analyzers::advanced::HistogramAnalyzer;
/// use datafusion::prelude::*;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let ctx = SessionContext::new();
/// // Register your data table
///
/// let analyzer = HistogramAnalyzer::new("price", 10); // 10 buckets
/// let state = analyzer.compute_state_from_data(&ctx).await?;
/// let metric = analyzer.compute_metric_from_state(&state)?;
///
/// if let MetricValue::Histogram(distribution) = metric {
///     println!("Price distribution: {} buckets", distribution.buckets.len());
///     for bucket in &distribution.buckets {
///         println!("[{:.2}, {:.2}): {} items",
///             bucket.lower_bound, bucket.upper_bound, bucket.count);
///     }
/// }
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct HistogramAnalyzer {
    /// The column to analyze.
    column: String,
    /// Number of histogram buckets.
    num_buckets: usize,
}

impl HistogramAnalyzer {
    /// Creates a new histogram analyzer with the specified number of buckets.
    ///
    /// # Arguments
    ///
    /// * `column` - The column to analyze
    /// * `num_buckets` - Number of histogram buckets (clamped between 1 and 1000)
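    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// // A small sketch of the clamping described above.
    /// let a = HistogramAnalyzer::new("price", 0);    // clamped up to 1 bucket
    /// let b = HistogramAnalyzer::new("price", 5000); // clamped down to 1000 buckets
    /// assert_eq!(a.num_buckets(), 1);
    /// assert_eq!(b.num_buckets(), 1000);
    /// ```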
    pub fn new(column: impl Into<String>, num_buckets: usize) -> Self {
        Self {
            column: column.into(),
            num_buckets: num_buckets.clamp(1, 1000),
        }
    }

    /// Returns the column being analyzed.
    pub fn column(&self) -> &str {
        &self.column
    }

    /// Returns the number of buckets.
    pub fn num_buckets(&self) -> usize {
        self.num_buckets
    }
}

/// State for the histogram analyzer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramState {
    /// The histogram buckets.
    pub buckets: Vec<HistogramBucket>,
    /// Minimum value in the dataset.
    pub min_value: f64,
    /// Maximum value in the dataset.
    pub max_value: f64,
    /// Total count of non-null values.
    pub total_count: u64,
    /// Sum of all values (for mean calculation).
    pub sum: f64,
    /// Sum of squared values (for std dev calculation).
    pub sum_squared: f64,
}

impl HistogramState {
    /// Calculates the mean value.
    pub fn mean(&self) -> Option<f64> {
        if self.total_count > 0 {
            Some(self.sum / self.total_count as f64)
        } else {
            None
        }
    }

    /// Calculates the population standard deviation.
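    ///
    /// Uses the identity `Var(X) = E[X^2] - E[X]^2` over the accumulated `sum` and
    /// `sum_squared`, and returns `None` when fewer than two values were observed.
    /// For example, with (hypothetical) values `[1.0, 2.0, 3.0]`: `sum = 6.0`,
    /// `sum_squared = 14.0`, `mean = 2.0`, so the variance is `14.0 / 3.0 - 4.0`
    /// (about `0.667`) and the standard deviation is about `0.816`.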
    pub fn std_dev(&self) -> Option<f64> {
        if self.total_count > 1 {
            let mean = self.mean()?;
            // Clamp at zero so floating-point rounding cannot produce a NaN from sqrt.
            let variance =
                ((self.sum_squared / self.total_count as f64) - (mean * mean)).max(0.0);
            Some(variance.sqrt())
        } else {
            None
        }
    }
}

impl AnalyzerState for HistogramState {
    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
        if states.is_empty() {
            return Err(AnalyzerError::state_merge("No states to merge"));
        }

        // For simplicity, take the first state's bucket structure as-is.
        // In production we'd re-bucket based on the global min/max.
        let first = &states[0];
        let mut merged_buckets = first.buckets.clone();

        // Merge bucket counts; states with a different bucket layout are skipped.
        for state in &states[1..] {
            if state.buckets.len() == merged_buckets.len() {
                for (i, bucket) in state.buckets.iter().enumerate() {
                    merged_buckets[i] = HistogramBucket::new(
                        merged_buckets[i].lower_bound,
                        merged_buckets[i].upper_bound,
                        merged_buckets[i].count + bucket.count,
                    );
                }
            }
        }

        let min_value = states
            .iter()
            .map(|s| s.min_value)
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0);

        let max_value = states
            .iter()
            .map(|s| s.max_value)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0);

        let total_count = states.iter().map(|s| s.total_count).sum();
        let sum = states.iter().map(|s| s.sum).sum();
        let sum_squared = states.iter().map(|s| s.sum_squared).sum();

        Ok(HistogramState {
            buckets: merged_buckets,
            min_value,
            max_value,
            total_count,
            sum,
            sum_squared,
        })
    }

    fn is_empty(&self) -> bool {
        self.total_count == 0
    }
}

#[async_trait]
impl Analyzer for HistogramAnalyzer {
    type State = HistogramState;
    type Metric = MetricValue;

    #[instrument(skip(ctx), fields(analyzer = "histogram", column = %self.column, buckets = %self.num_buckets))]
    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
        // Get the table name from the validation context
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();

        // First, get min/max values to determine bucket boundaries
        let stats_sql = format!(
            "SELECT
                MIN({0}) as min_val,
                MAX({0}) as max_val,
                COUNT({0}) as count,
                SUM({0}) as sum,
                SUM({0} * {0}) as sum_squared
            FROM {table_name}
            WHERE {0} IS NOT NULL",
            self.column
        );
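
        // For illustration, with a (hypothetical) column `price` and table `orders`,
        // the rendered statement is roughly:
        //   SELECT MIN(price) as min_val, MAX(price) as max_val, COUNT(price) as count,
        //          SUM(price) as sum, SUM(price * price) as sum_squared
        //   FROM orders WHERE price IS NOT NULL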

        let stats_df = ctx.sql(&stats_sql).await?;
        let stats_batches = stats_df.collect().await?;

        let (min_value, max_value, total_count, sum, sum_squared) = if let Some(batch) =
            stats_batches.first()
        {
            if batch.num_rows() > 0 && !batch.column(0).is_null(0) {
                let min_val = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<arrow::array::Float64Array>()
                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for min"))?
                    .value(0);

                let max_val = batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<arrow::array::Float64Array>()
                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for max"))?
                    .value(0);

                let count = batch
                    .column(2)
                    .as_any()
                    .downcast_ref::<arrow::array::Int64Array>()
                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for count"))?
                    .value(0) as u64;

                let sum_val = batch
                    .column(3)
                    .as_any()
                    .downcast_ref::<arrow::array::Float64Array>()
                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for sum"))?
                    .value(0);

                let sum_sq = batch
                    .column(4)
                    .as_any()
                    .downcast_ref::<arrow::array::Float64Array>()
                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Float64 for sum_squared"))?
                    .value(0);

                (min_val, max_val, count, sum_val, sum_sq)
            } else {
                // No data
                return Ok(HistogramState {
                    buckets: vec![],
                    min_value: 0.0,
                    max_value: 0.0,
                    total_count: 0,
                    sum: 0.0,
                    sum_squared: 0.0,
                });
            }
        } else {
            return Err(AnalyzerError::NoData);
        };

        // Calculate bucket width
        let range = max_value - min_value;
        let bucket_width = if range > 0.0 && self.num_buckets > 1 {
            range / self.num_buckets as f64
        } else {
            1.0
        };
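
        // For example (hypothetical values): min 0.0, max 100.0 and 10 buckets give a
        // bucket width of 10.0; a zero range (or a single bucket) falls back to 1.0.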

        // Build the histogram query with a CASE expression, since WIDTH_BUCKET is not available.
        // The last bucket's upper bound is nudged past the max so the maximum value is included.
        let mut case_clauses = Vec::new();
        for i in 0..self.num_buckets {
            let lower = min_value + (i as f64 * bucket_width);
            let upper = if i == self.num_buckets - 1 {
                max_value + bucket_width * 0.001
            } else {
                min_value + ((i + 1) as f64 * bucket_width)
            };
            case_clauses.push(format!(
                "WHEN {0} >= {1} AND {0} < {2} THEN {3}",
                self.column,
                lower,
                upper,
                i + 1
            ));
        }

        let histogram_sql = format!(
            "SELECT
                CASE
                    {}
                    ELSE {}
                END as bucket_num,
                COUNT(*) as count
            FROM {table_name}
            WHERE {} IS NOT NULL
            GROUP BY bucket_num
            ORDER BY bucket_num",
            case_clauses.join(" "),
            self.num_buckets,
            self.column
        );
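
        // For illustration, with a (hypothetical) column `price`, table `orders`,
        // min 0.0, max 100.0 and 2 buckets, the rendered statement is roughly:
        //   SELECT CASE WHEN price >= 0 AND price < 50 THEN 1
        //               WHEN price >= 50 AND price < 100.05 THEN 2
        //               ELSE 2 END as bucket_num, COUNT(*) as count
        //   FROM orders WHERE price IS NOT NULL
        //   GROUP BY bucket_num ORDER BY bucket_num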

        let hist_df = ctx.sql(&histogram_sql).await?;
        let hist_batches = hist_df.collect().await?;

        // Build buckets array
        let mut buckets = vec![HistogramBucket::new(0.0, 0.0, 0); self.num_buckets];

        // Initialize bucket boundaries
        for (i, bucket) in buckets.iter_mut().enumerate() {
            let lower = min_value + (i as f64 * bucket_width);
            let upper = if i == self.num_buckets - 1 {
                max_value + bucket_width * 0.001
            } else {
                min_value + ((i + 1) as f64 * bucket_width)
            };
            *bucket = HistogramBucket::new(lower, upper, 0);
        }

        // Fill in counts from query results
        for batch in &hist_batches {
            let bucket_array = batch
                .column(0)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for bucket_num"))?;

            let count_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for count"))?;

            for i in 0..batch.num_rows() {
                let bucket_idx = (bucket_array.value(i) - 1) as usize;
                let count = count_array.value(i) as u64;

                if bucket_idx < buckets.len() {
                    buckets[bucket_idx] = HistogramBucket::new(
                        buckets[bucket_idx].lower_bound,
                        buckets[bucket_idx].upper_bound,
                        count,
                    );
                }
            }
        }

        Ok(HistogramState {
            buckets,
            min_value,
            max_value,
            total_count,
            sum,
            sum_squared,
        })
    }

    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
        let distribution = MetricDistribution::from_buckets(state.buckets.clone()).with_stats(
            state.min_value,
            state.max_value,
            state.mean().unwrap_or(0.0),
            state.std_dev().unwrap_or(0.0),
        );

        Ok(MetricValue::Histogram(distribution))
    }

    fn name(&self) -> &str {
        "histogram"
    }

    fn description(&self) -> &str {
        "Computes value distribution histogram"
    }

    fn columns(&self) -> Vec<&str> {
        vec![&self.column]
    }
}
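
// A minimal sanity-check sketch for the state math above (not exercised against
// DataFusion); it only assumes the `HistogramBucket::new(lower, upper, count)`
// constructor and the `HistogramState` fields used in this file.
#[cfg(test)]
mod tests {
    use super::*;

    fn state_from(buckets: Vec<HistogramBucket>, values: &[f64]) -> HistogramState {
        HistogramState {
            buckets,
            min_value: values.iter().cloned().fold(f64::INFINITY, f64::min),
            max_value: values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
            total_count: values.len() as u64,
            sum: values.iter().sum(),
            sum_squared: values.iter().map(|v| v * v).sum(),
        }
    }

    #[test]
    fn mean_and_std_dev_from_accumulated_sums() {
        let s = state_from(vec![], &[1.0, 2.0, 3.0]);
        assert_eq!(s.mean(), Some(2.0));
        // Population standard deviation of [1, 2, 3] is sqrt(2/3).
        let std = s.std_dev().expect("more than one value");
        assert!((std - (2.0f64 / 3.0).sqrt()).abs() < 1e-9);
    }

    #[test]
    fn merge_adds_counts_for_matching_bucket_layouts() {
        let layout = |count: u64| {
            vec![
                HistogramBucket::new(0.0, 5.0, count),
                HistogramBucket::new(5.0, 10.0, count),
            ]
        };
        let a = state_from(layout(2), &[1.0, 2.0, 6.0, 7.0]);
        let b = state_from(layout(3), &[0.0, 3.0, 4.0, 8.0, 9.0, 9.5]);
        let merged = HistogramState::merge(vec![a, b]).expect("merge succeeds");
        assert_eq!(merged.total_count, 10);
        assert_eq!(merged.buckets[0].count, 5);
        assert_eq!(merged.buckets[1].count, 5);
        assert_eq!(merged.min_value, 0.0);
        assert_eq!(merged.max_value, 9.5);
    }
}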