term_guard/analyzers/advanced/entropy.rs

//! Entropy analyzer for information theory metrics.

use arrow::array::{Array, StringViewArray};
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::instrument;

use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
use crate::core::current_validation_context;

/// Analyzer that computes Shannon entropy and related information theory metrics.
///
/// Entropy measures the average information content or uncertainty in a dataset.
/// Higher entropy indicates more randomness/diversity, while lower entropy
/// indicates more predictability/uniformity.
///
/// # Metrics Computed
///
/// - **Shannon Entropy**: -Σ(p_i * log2(p_i)), where p_i is the probability of each value
/// - **Normalized Entropy**: entropy divided by log2(n), where n is the number of unique values
/// - **Gini Impurity**: 1 - Σ(p_i²), another measure of diversity
/// - **Effective Number of Values**: 2^entropy, interpretable as the effective cardinality
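///
/// For example, a column with two equally frequent values has entropy
/// -2 * (0.5 * log2(0.5)) = 1 bit, normalized entropy 1.0, Gini impurity 0.5,
/// and 2.0 effective values.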
///
/// # Example
///
/// ```rust,ignore
/// use term_guard::analyzers::advanced::EntropyAnalyzer;
/// use datafusion::prelude::*;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let ctx = SessionContext::new();
/// // Register your data table
///
/// let analyzer = EntropyAnalyzer::new("category");
/// let state = analyzer.compute_state_from_data(&ctx).await?;
/// let metric = analyzer.compute_metric_from_state(&state)?;
///
/// if let MetricValue::Map(metrics) = metric {
///     println!("Category entropy: {:?} bits", metrics.get("entropy"));
///     println!("Normalized entropy: {:?}", metrics.get("normalized_entropy"));
///     println!("Effective categories: {:?}", metrics.get("effective_values"));
/// }
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct EntropyAnalyzer {
    /// The column to analyze.
    column: String,
    /// Maximum number of unique values to track (for memory efficiency).
    max_unique_values: usize,
}

impl EntropyAnalyzer {
    /// Creates a new entropy analyzer for the specified column.
    pub fn new(column: impl Into<String>) -> Self {
        Self {
            column: column.into(),
            max_unique_values: 10_000,
        }
    }

    /// Creates a new entropy analyzer with a custom maximum unique values
    /// limit (values below 10 are clamped to 10).
    pub fn with_max_unique_values(column: impl Into<String>, max_unique_values: usize) -> Self {
        Self {
            column: column.into(),
            max_unique_values: max_unique_values.max(10),
        }
    }

    /// Returns the column being analyzed.
    pub fn column(&self) -> &str {
        &self.column
    }

    /// Returns the maximum unique values limit.
    pub fn max_unique_values(&self) -> usize {
        self.max_unique_values
    }
}

/// State for the entropy analyzer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntropyState {
    /// Count of occurrences for each unique value.
    pub value_counts: HashMap<String, u64>,
    /// Total count of non-null values.
    pub total_count: u64,
    /// Whether the unique value limit was exceeded.
    pub truncated: bool,
}

impl EntropyState {
    /// Calculates Shannon entropy in bits.
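    ///
    /// For instance, a fair two-value split yields exactly one bit (a minimal
    /// sketch with a hypothetical hand-built state, not data from a real table):
    ///
    /// ```rust,ignore
    /// use std::collections::HashMap;
    ///
    /// let state = EntropyState {
    ///     value_counts: HashMap::from([("heads".into(), 5), ("tails".into(), 5)]),
    ///     total_count: 10,
    ///     truncated: false,
    /// };
    /// assert!((state.entropy() - 1.0).abs() < 1e-12);
    /// ```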
    pub fn entropy(&self) -> f64 {
        if self.total_count == 0 {
            return 0.0;
        }

        let total = self.total_count as f64;
        self.value_counts
            .values()
            .map(|&count| {
                let p = count as f64 / total;
                if p > 0.0 {
                    -p * p.log2()
                } else {
                    0.0
                }
            })
            .sum()
    }

    /// Calculates normalized entropy (0 to 1).
    pub fn normalized_entropy(&self) -> f64 {
        let num_unique = self.value_counts.len();
        if num_unique <= 1 {
            0.0
        } else {
            let max_entropy = (num_unique as f64).log2();
            if max_entropy > 0.0 {
                self.entropy() / max_entropy
            } else {
                0.0
            }
        }
    }

    /// Calculates Gini impurity.
    pub fn gini_impurity(&self) -> f64 {
        if self.total_count == 0 {
            return 0.0;
        }

        let total = self.total_count as f64;
        let sum_squared_probs: f64 = self
            .value_counts
            .values()
            .map(|&count| {
                let p = count as f64 / total;
                p * p
            })
            .sum();

        1.0 - sum_squared_probs
    }

    /// Calculates the effective number of values (perplexity).
    pub fn effective_values(&self) -> f64 {
        2.0_f64.powf(self.entropy())
    }

    /// Returns the probability distribution.
    ///
    /// When no values have been observed, `value_counts` is empty, so this
    /// returns an empty map and the division by `total_count` never runs on zero.
    pub fn probability_distribution(&self) -> HashMap<String, f64> {
        let total = self.total_count as f64;
        self.value_counts
            .iter()
            .map(|(value, &count)| (value.clone(), count as f64 / total))
            .collect()
    }
}

impl AnalyzerState for EntropyState {
    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
        let mut merged_counts = HashMap::new();
        let mut total_count = 0;
        let mut truncated = false;

        for state in states {
            total_count += state.total_count;
            truncated |= state.truncated;

            for (value, count) in state.value_counts {
                *merged_counts.entry(value).or_insert(0) += count;
            }
        }

        Ok(EntropyState {
            value_counts: merged_counts,
            total_count,
            truncated,
        })
    }

    fn is_empty(&self) -> bool {
        self.total_count == 0
    }
}

#[async_trait]
impl Analyzer for EntropyAnalyzer {
    type State = EntropyState;
    type Metric = MetricValue;

    #[instrument(skip(ctx), fields(analyzer = "entropy", column = %self.column, max_unique = %self.max_unique_values))]
    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
        // Get the table name from the validation context
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();

        // First, count the distinct values to decide whether truncation is needed
        let count_distinct_sql = format!(
            "SELECT COUNT(DISTINCT {0}) as unique_count FROM {table_name} WHERE {0} IS NOT NULL",
            self.column
        );

        let count_df = ctx.sql(&count_distinct_sql).await?;
        let count_batches = count_df.collect().await?;

        let unique_count = if let Some(batch) = count_batches.first() {
            if batch.num_rows() > 0 {
                let count_array = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<arrow::array::Int64Array>()
                    .ok_or_else(|| {
                        AnalyzerError::invalid_data("Expected Int64 for unique count")
                    })?;
                count_array.value(0) as usize
            } else {
                0
            }
        } else {
            0
        };

        // If there are too many unique values, keep only the most frequent ones
        let (sql, truncated) = if unique_count > self.max_unique_values {
            // Keep the N most frequent values; `table_name` from above is still in scope
            let sql = format!(
                "SELECT
                    CAST({0} AS VARCHAR) as value,
                    COUNT(*) as count
                FROM {table_name}
                WHERE {0} IS NOT NULL
                GROUP BY CAST({0} AS VARCHAR)
                ORDER BY count DESC
                LIMIT {1}",
                self.column, self.max_unique_values
            );
            (sql, true)
        } else {
            // Get all values
            let sql = format!(
                "SELECT
                    CAST({0} AS VARCHAR) as value,
                    COUNT(*) as count
                FROM {table_name}
                WHERE {0} IS NOT NULL
                GROUP BY CAST({0} AS VARCHAR)",
                self.column
            );
            (sql, false)
        };

        // Execute query
        let df = ctx.sql(&sql).await?;
        let batches = df.collect().await?;

        // Build value counts map
        let mut value_counts = HashMap::new();
        let mut total_count = 0;

        for batch in &batches {
            let value_array = batch.column(0).as_any();

            // Handle both plain Utf8 and Utf8View string arrays
            let values: Vec<(String, bool)> =
                if let Some(arr) = value_array.downcast_ref::<arrow::array::StringArray>() {
                    (0..arr.len())
                        .map(|i| (arr.value(i).to_string(), arr.is_null(i)))
                        .collect()
                } else if let Some(arr) = value_array.downcast_ref::<StringViewArray>() {
                    (0..arr.len())
                        .map(|i| (arr.value(i).to_string(), arr.is_null(i)))
                        .collect()
                } else {
                    return Err(AnalyzerError::invalid_data(format!(
                        "Expected String array for values, got {:?}",
                        batch.column(0).data_type()
                    )));
                };

            let count_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 array for counts"))?;

            // Nulls are already filtered out in the SQL, so this check is defensive
            for (i, (value, is_null)) in values.iter().enumerate() {
                if !is_null {
                    let count = count_array.value(i) as u64;
                    value_counts.insert(value.clone(), count);
                    total_count += count;
                }
            }
        }

        // If truncated, fetch the true non-null total so the retained
        // probabilities are exact; the entropy computed from only the most
        // frequent values then underestimates the true entropy
        if truncated {
            let total_sql = format!(
                "SELECT COUNT({0}) as total FROM {table_name} WHERE {0} IS NOT NULL",
                self.column
            );
            let total_df = ctx.sql(&total_sql).await?;
            let total_batches = total_df.collect().await?;

            if let Some(batch) = total_batches.first() {
                if batch.num_rows() > 0 {
                    let total_array = batch
                        .column(0)
                        .as_any()
                        .downcast_ref::<arrow::array::Int64Array>()
                        .ok_or_else(|| {
                            AnalyzerError::invalid_data("Expected Int64 for total count")
                        })?;
                    total_count = total_array.value(0) as u64;
                }
            }
        }

        Ok(EntropyState {
            value_counts,
            total_count,
            truncated,
        })
    }

    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
        let mut metrics = HashMap::new();

        // Core entropy metrics
        metrics.insert("entropy".to_string(), MetricValue::Double(state.entropy()));
        metrics.insert(
            "normalized_entropy".to_string(),
            MetricValue::Double(state.normalized_entropy()),
        );
        metrics.insert(
            "gini_impurity".to_string(),
            MetricValue::Double(state.gini_impurity()),
        );
        metrics.insert(
            "effective_values".to_string(),
            MetricValue::Double(state.effective_values()),
        );

        // Additional statistics
        metrics.insert(
            "unique_values".to_string(),
            MetricValue::Long(state.value_counts.len() as i64),
        );
        metrics.insert(
            "total_count".to_string(),
            MetricValue::Long(state.total_count as i64),
        );
        metrics.insert(
            "truncated".to_string(),
            MetricValue::Boolean(state.truncated),
        );

        // Add top 10 most frequent values if not too many
        if state.value_counts.len() <= 100 {
            let mut sorted_values: Vec<_> = state.value_counts.iter().collect();
            sorted_values.sort_by(|a, b| b.1.cmp(a.1));

            let top_values: HashMap<String, MetricValue> = sorted_values
                .iter()
                .take(10)
                .map(|(value, &count)| {
                    let prob = count as f64 / state.total_count as f64;
                    (
                        value.to_string(),
                        MetricValue::Map(HashMap::from([
                            ("count".to_string(), MetricValue::Long(count as i64)),
                            ("probability".to_string(), MetricValue::Double(prob)),
                        ])),
                    )
                })
                .collect();

            metrics.insert("top_values".to_string(), MetricValue::Map(top_values));
        }

        Ok(MetricValue::Map(metrics))
    }

    fn name(&self) -> &str {
398        "entropy"
399    }
400
401    fn description(&self) -> &str {
402        "Computes Shannon entropy and information theory metrics"
403    }
404
405    fn columns(&self) -> Vec<&str> {
406        vec![&self.column]
407    }
408}
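
#[cfg(test)]
mod tests {
    // A minimal sketch of tests for the pure state math above. The expected
    // values follow directly from the entropy, Gini, and perplexity formulas;
    // the test names and fixtures are illustrative, not from the source.
    use super::*;

    // Hypothetical helper: builds a non-truncated state from literal counts.
    fn state(counts: &[(&str, u64)]) -> EntropyState {
        let value_counts: HashMap<String, u64> =
            counts.iter().map(|(v, c)| (v.to_string(), *c)).collect();
        let total_count = counts.iter().map(|(_, c)| c).sum();
        EntropyState {
            value_counts,
            total_count,
            truncated: false,
        }
    }

    #[test]
    fn fair_two_way_split_is_one_bit() {
        let s = state(&[("heads", 5), ("tails", 5)]);
        assert!((s.entropy() - 1.0).abs() < 1e-12);
        assert!((s.normalized_entropy() - 1.0).abs() < 1e-12);
        assert!((s.gini_impurity() - 0.5).abs() < 1e-12);
        assert!((s.effective_values() - 2.0).abs() < 1e-12);
    }

    #[test]
    fn single_value_has_zero_entropy() {
        let s = state(&[("only", 10)]);
        assert_eq!(s.entropy(), 0.0);
        assert_eq!(s.normalized_entropy(), 0.0);
        assert_eq!(s.gini_impurity(), 0.0);
        assert!((s.effective_values() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn merge_sums_per_value_counts() {
        let merged = EntropyState::merge(vec![state(&[("a", 3)]), state(&[("a", 2), ("b", 5)])])
            .expect("merge should succeed");
        assert_eq!(merged.total_count, 10);
        assert_eq!(merged.value_counts["a"], 5);
        assert!((merged.entropy() - 1.0).abs() < 1e-12);
        assert!(!merged.is_empty());
    }
}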