// term_guard/analyzers/advanced/entropy.rs
use arrow::array::{Array, StringViewArray};
4use async_trait::async_trait;
5use datafusion::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use tracing::instrument;
9
10use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
11
12use crate::core::current_validation_context;
/// Analyzer that measures the information content of a single column by
/// building a value-frequency distribution and deriving Shannon entropy,
/// normalized entropy, Gini impurity, and related metrics from it.
///
/// Construct with [`EntropyAnalyzer::new`] (default cap of 10,000 distinct
/// values) or [`EntropyAnalyzer::with_max_unique_values`].
#[derive(Debug, Clone)]
pub struct EntropyAnalyzer {
    // Name of the column to analyze. NOTE(review): interpolated directly into
    // SQL, so it must be a trusted identifier, not untrusted user input.
    column: String,
    // Upper bound on distinct values tracked; when the column's cardinality
    // exceeds this, only the most frequent values are kept and the resulting
    // state is marked truncated.
    max_unique_values: usize,
}
55
56impl EntropyAnalyzer {
57 pub fn new(column: impl Into<String>) -> Self {
59 Self {
60 column: column.into(),
61 max_unique_values: 10_000,
62 }
63 }
64
65 pub fn with_max_unique_values(column: impl Into<String>, max_unique_values: usize) -> Self {
67 Self {
68 column: column.into(),
69 max_unique_values: max_unique_values.max(10),
70 }
71 }
72
73 pub fn column(&self) -> &str {
75 &self.column
76 }
77
78 pub fn max_unique_values(&self) -> usize {
80 self.max_unique_values
81 }
82}
83
/// Serializable aggregation state for [`EntropyAnalyzer`]: a frequency table
/// over the column's (stringified) non-null values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntropyState {
    /// Observed value -> number of occurrences. When `truncated` is true this
    /// holds only the most frequent values, not the full distribution.
    pub value_counts: HashMap<String, u64>,
    /// Total number of non-null rows observed. Kept exact even when the value
    /// table is truncated (recomputed via a separate COUNT query).
    pub total_count: u64,
    /// True when the column had more distinct values than the configured cap
    /// and `value_counts` was limited to the top entries.
    pub truncated: bool,
}
94
95impl EntropyState {
96 pub fn entropy(&self) -> f64 {
98 if self.total_count == 0 {
99 return 0.0;
100 }
101
102 let total = self.total_count as f64;
103 self.value_counts
104 .values()
105 .map(|&count| {
106 let p = count as f64 / total;
107 if p > 0.0 {
108 -p * p.log2()
109 } else {
110 0.0
111 }
112 })
113 .sum()
114 }
115
116 pub fn normalized_entropy(&self) -> f64 {
118 let num_unique = self.value_counts.len();
119 if num_unique <= 1 {
120 0.0
121 } else {
122 let max_entropy = (num_unique as f64).log2();
123 if max_entropy > 0.0 {
124 self.entropy() / max_entropy
125 } else {
126 0.0
127 }
128 }
129 }
130
131 pub fn gini_impurity(&self) -> f64 {
133 if self.total_count == 0 {
134 return 0.0;
135 }
136
137 let total = self.total_count as f64;
138 let sum_squared_probs: f64 = self
139 .value_counts
140 .values()
141 .map(|&count| {
142 let p = count as f64 / total;
143 p * p
144 })
145 .sum();
146
147 1.0 - sum_squared_probs
148 }
149
150 pub fn effective_values(&self) -> f64 {
152 2.0_f64.powf(self.entropy())
153 }
154
155 pub fn probability_distribution(&self) -> HashMap<String, f64> {
157 let total = self.total_count as f64;
158 self.value_counts
159 .iter()
160 .map(|(value, &count)| (value.clone(), count as f64 / total))
161 .collect()
162 }
163}
164
165impl AnalyzerState for EntropyState {
166 fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
167 let mut merged_counts = HashMap::new();
168 let mut total_count = 0;
169 let mut truncated = false;
170
171 for state in states {
172 total_count += state.total_count;
173 truncated |= state.truncated;
174
175 for (value, count) in state.value_counts {
176 *merged_counts.entry(value).or_insert(0) += count;
177 }
178 }
179
180 Ok(EntropyState {
181 value_counts: merged_counts,
182 total_count,
183 truncated,
184 })
185 }
186
187 fn is_empty(&self) -> bool {
188 self.total_count == 0
189 }
190}
191
192#[async_trait]
193impl Analyzer for EntropyAnalyzer {
194 type State = EntropyState;
195 type Metric = MetricValue;
196
197 #[instrument(skip(ctx), fields(analyzer = "entropy", column = %self.column, max_unique = %self.max_unique_values))]
198 async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
199 let validation_ctx = current_validation_context();
201 let table_name = validation_ctx.table_name();
202
203 let count_distinct_sql = format!(
205 "SELECT COUNT(DISTINCT {0}) as unique_count FROM {table_name} WHERE {0} IS NOT NULL",
206 self.column
207 );
208
209 let count_df = ctx.sql(&count_distinct_sql).await?;
210 let count_batches = count_df.collect().await?;
211
212 let unique_count = if let Some(batch) = count_batches.first() {
213 if batch.num_rows() > 0 {
214 let count_array = batch
215 .column(0)
216 .as_any()
217 .downcast_ref::<arrow::array::Int64Array>()
218 .ok_or_else(|| {
219 AnalyzerError::invalid_data("Expected Int64 for unique count")
220 })?;
221 count_array.value(0) as usize
222 } else {
223 0
224 }
225 } else {
226 0
227 };
228
229 let (sql, truncated) = if unique_count > self.max_unique_values {
231 let validation_ctx = current_validation_context();
235
236 let table_name = validation_ctx.table_name();
237
238 let sql = format!(
239 "SELECT
240 CAST({0} AS VARCHAR) as value,
241 COUNT(*) as count
242 FROM {table_name}
243 WHERE {0} IS NOT NULL
244 GROUP BY CAST({0} AS VARCHAR)
245 ORDER BY count DESC
246 LIMIT {1}",
247 self.column, self.max_unique_values
248 );
249 (sql, true)
250 } else {
251 let sql = format!(
253 "SELECT
254 CAST({0} AS VARCHAR) as value,
255 COUNT(*) as count
256 FROM {table_name}
257 WHERE {0} IS NOT NULL
258 GROUP BY CAST({0} AS VARCHAR)",
259 self.column
260 );
261 (sql, false)
262 };
263
264 let df = ctx.sql(&sql).await?;
266 let batches = df.collect().await?;
267
268 let mut value_counts = HashMap::new();
270 let mut total_count = 0;
271
272 for batch in &batches {
273 let value_array = batch.column(0).as_any();
274
275 let values: Vec<(String, bool)> =
277 if let Some(arr) = value_array.downcast_ref::<arrow::array::StringArray>() {
278 (0..arr.len())
279 .map(|i| (arr.value(i).to_string(), arr.is_null(i)))
280 .collect()
281 } else if let Some(arr) = value_array.downcast_ref::<StringViewArray>() {
282 (0..arr.len())
283 .map(|i| (arr.value(i).to_string(), arr.is_null(i)))
284 .collect()
285 } else {
286 return Err(AnalyzerError::invalid_data(format!(
287 "Expected String array for values, got {:?}",
288 batch.column(0).data_type()
289 )));
290 };
291
292 let count_array = batch
293 .column(1)
294 .as_any()
295 .downcast_ref::<arrow::array::Int64Array>()
296 .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 array for counts"))?;
297
298 for (i, (value, is_null)) in values.iter().enumerate() {
299 if !is_null {
300 let count = count_array.value(i) as u64;
301 value_counts.insert(value.clone(), count);
302 total_count += count;
303 }
304 }
305 }
306
307 if truncated {
309 let total_sql = format!(
310 "SELECT COUNT({0}) as total FROM {table_name} WHERE {0} IS NOT NULL",
311 self.column
312 );
313 let total_df = ctx.sql(&total_sql).await?;
314 let total_batches = total_df.collect().await?;
315
316 if let Some(batch) = total_batches.first() {
317 if batch.num_rows() > 0 {
318 let total_array = batch
319 .column(0)
320 .as_any()
321 .downcast_ref::<arrow::array::Int64Array>()
322 .ok_or_else(|| {
323 AnalyzerError::invalid_data("Expected Int64 for total count")
324 })?;
325 total_count = total_array.value(0) as u64;
326 }
327 }
328 }
329
330 Ok(EntropyState {
331 value_counts,
332 total_count,
333 truncated,
334 })
335 }
336
337 fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
338 use std::collections::HashMap;
339
340 let mut metrics = HashMap::new();
341
342 metrics.insert("entropy".to_string(), MetricValue::Double(state.entropy()));
344 metrics.insert(
345 "normalized_entropy".to_string(),
346 MetricValue::Double(state.normalized_entropy()),
347 );
348 metrics.insert(
349 "gini_impurity".to_string(),
350 MetricValue::Double(state.gini_impurity()),
351 );
352 metrics.insert(
353 "effective_values".to_string(),
354 MetricValue::Double(state.effective_values()),
355 );
356
357 metrics.insert(
359 "unique_values".to_string(),
360 MetricValue::Long(state.value_counts.len() as i64),
361 );
362 metrics.insert(
363 "total_count".to_string(),
364 MetricValue::Long(state.total_count as i64),
365 );
366 metrics.insert(
367 "truncated".to_string(),
368 MetricValue::Boolean(state.truncated),
369 );
370
371 if state.value_counts.len() <= 100 {
373 let mut sorted_values: Vec<_> = state.value_counts.iter().collect();
374 sorted_values.sort_by(|a, b| b.1.cmp(a.1));
375
376 let top_values: HashMap<String, MetricValue> = sorted_values
377 .iter()
378 .take(10)
379 .map(|(value, &count)| {
380 let prob = count as f64 / state.total_count as f64;
381 (
382 value.to_string(),
383 MetricValue::Map(HashMap::from([
384 ("count".to_string(), MetricValue::Long(count as i64)),
385 ("probability".to_string(), MetricValue::Double(prob)),
386 ])),
387 )
388 })
389 .collect();
390
391 metrics.insert("top_values".to_string(), MetricValue::Map(top_values));
392 }
393
394 Ok(MetricValue::Map(metrics))
395 }
396
397 fn name(&self) -> &str {
398 "entropy"
399 }
400
401 fn description(&self) -> &str {
402 "Computes Shannon entropy and information theory metrics"
403 }
404
405 fn columns(&self) -> Vec<&str> {
406 vec![&self.column]
407 }
408}