three_dcf_core/
metrics.rs1use once_cell::sync::Lazy;
2use regex::Regex;
3use serde::{Deserialize, Serialize};
4use strsim::levenshtein;
5
6#[derive(Debug, Clone, Default, Serialize, Deserialize)]
7pub struct Metrics {
8 pub pages: u32,
9 pub lines_total: u32,
10 pub cells_total: u32,
11 pub cells_kept: u32,
12 pub dedup_ratio: f32,
13 pub numguard_count: u32,
14 pub raw_tokens_estimate: Option<u32>,
15 pub compressed_tokens_estimate: Option<u32>,
16 pub compression_factor: Option<f32>,
17}
18
19impl Metrics {
20 pub fn with_token_metrics(mut self, raw: Option<u32>, compressed: Option<u32>) -> Self {
21 self.record_tokens(raw, compressed);
22 self
23 }
24
25 pub fn record_tokens(&mut self, raw: Option<u32>, compressed: Option<u32>) {
26 if let Some(raw) = raw {
27 self.raw_tokens_estimate = Some(raw);
28 }
29 if let Some(compressed) = compressed {
30 self.compressed_tokens_estimate = Some(compressed);
31 }
32 self.compression_factor = match (self.raw_tokens_estimate, self.compressed_tokens_estimate)
33 {
34 (Some(raw), Some(comp)) if comp > 0 => Some(raw as f32 / comp as f32),
35 _ => None,
36 };
37 }
38}
39
40#[derive(Debug, Clone, Default, Serialize, Deserialize)]
41pub struct TokenMetrics {
42 pub raw: u32,
43 pub compressed: u32,
44 pub factor: f32,
45}
46
/// Scores produced by `numeric_stats` for the numbers found in a prediction
/// versus a reference text. Every field lies in `[0.0, 1.0]`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct NumStats {
    /// Matched numbers divided by numbers in the prediction.
    pub precision: f64,
    /// Matched numbers divided by numbers in the reference.
    pub recall: f64,
    /// Harmonic mean of `precision` and `recall` (`0.0` if either is zero).
    pub f1: f64,
    /// Fraction of matched number pairs whose units agree.
    pub units_ok: f64,
}

55pub fn cer(pred: &str, gold: &str) -> f64 {
56 normalized_distance(pred, gold)
57}
58
59pub fn wer(pred: &str, gold: &str) -> f64 {
60 let pred_words: Vec<&str> = pred.split_whitespace().collect();
61 let gold_words: Vec<&str> = gold.split_whitespace().collect();
62 if gold_words.is_empty() {
63 return if pred_words.is_empty() { 0.0 } else { 1.0 };
64 }
65 let dist = levenshtein_words(&pred_words, &gold_words);
66 dist as f64 / gold_words.len() as f64
67}
68
69pub fn numeric_stats(pred: &str, gold: &str) -> NumStats {
70 let pred_vals = extract_numbers(pred);
71 let gold_vals = extract_numbers(gold);
72 if gold_vals.is_empty() && pred_vals.is_empty() {
73 return NumStats {
74 precision: 1.0,
75 recall: 1.0,
76 f1: 1.0,
77 units_ok: 1.0,
78 };
79 }
80 let mut gold_used = vec![false; gold_vals.len()];
81 let mut matches = 0usize;
82 let mut units_match = 0usize;
83 for pred in &pred_vals {
84 if let Some((idx, gold_item)) = gold_vals
85 .iter()
86 .enumerate()
87 .find(|(i, g)| !gold_used[*i] && g.value == pred.value)
88 {
89 gold_used[idx] = true;
90 matches += 1;
91 if gold_item
92 .unit
93 .as_ref()
94 .map(|u| pred.unit.as_ref() == Some(u))
95 .unwrap_or(true)
96 {
97 units_match += 1;
98 }
99 }
100 }
101 let precision = if pred_vals.is_empty() {
102 1.0
103 } else {
104 matches as f64 / pred_vals.len() as f64
105 };
106 let recall = if gold_vals.is_empty() {
107 1.0
108 } else {
109 matches as f64 / gold_vals.len() as f64
110 };
111 let f1 = if precision == 0.0 || recall == 0.0 {
112 0.0
113 } else {
114 2.0 * precision * recall / (precision + recall)
115 };
116 let units_ok = if matches == 0 {
117 1.0
118 } else {
119 units_match as f64 / matches as f64
120 };
121 NumStats {
122 precision,
123 recall,
124 f1,
125 units_ok,
126 }
127}
128
129fn normalized_distance(pred: &str, gold: &str) -> f64 {
130 let pred_norm = normalize(pred);
131 let gold_norm = normalize(gold);
132 if gold_norm.is_empty() {
133 return if pred_norm.is_empty() { 0.0 } else { 1.0 };
134 }
135 let dist = levenshtein(&pred_norm, &gold_norm);
136 dist as f64 / gold_norm.chars().count() as f64
137}
138
/// Word-level Levenshtein distance between `pred` and `gold`.
///
/// Insertion, deletion and substitution each cost 1, and tokens are compared
/// for exact equality.
///
/// Only one DP row of `pred.len() + 1` entries is kept instead of the full
/// `(gold.len() + 1) x (pred.len() + 1)` table. Memory is therefore linear in
/// the prediction length rather than quadratic in document size. Time is still
/// O(gold.len() * pred.len()).
fn levenshtein_words(pred: &[&str], gold: &[&str]) -> usize {
    if gold.is_empty() {
        return pred.len();
    }
    if pred.is_empty() {
        return gold.len();
    }

    // Before processing gold[i], row[j] holds the distance between
    // gold[..i] and pred[..j]; the initial row is the all-insertions base case.
    let mut row: Vec<usize> = (0..=pred.len()).collect();
    for (i, gold_word) in gold.iter().enumerate() {
        // `diag` carries the previous row's value at column j (the substitution source).
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, pred_word) in pred.iter().enumerate() {
            let cost = usize::from(gold_word != pred_word);
            let next = (row[j + 1] + 1) // delete gold_word
                .min(row[j] + 1) // insert pred_word
                .min(diag + cost); // substitute / keep
            diag = row[j + 1];
            row[j + 1] = next;
        }
    }
    row[pred.len()]
}

/// Returns `text` with every control character removed (this includes `\n`,
/// `\t` and `\r`, so adjacent lines are joined without a separator).
fn normalize(text: &str) -> String {
    let mut kept = String::with_capacity(text.len());
    for ch in text.chars() {
        if ch.is_control() {
            continue;
        }
        kept.push(ch);
    }
    kept
}

/// A number found in text by `extract_numbers`.
#[derive(Debug, Clone)]
struct NumberToken {
    /// The matched number with commas and `+` removed; ASCII digits, `.` and
    /// `-` are kept. Compared as a string, not parsed.
    value: String,
    /// Lowercased unit text that immediately followed the number, if any.
    unit: Option<String>,
}

176static NUM_RE: Lazy<Regex> = Lazy::new(|| {
177 Regex::new(r"(?i)(?P<num>[+-]?\d{1,3}(?:[\d,.]*))(?:\s*(?P<unit>[a-z%$]{1,8}))?")
178 .expect("valid regex")
179});
180
181fn extract_numbers(text: &str) -> Vec<NumberToken> {
182 let mut out = Vec::new();
183 for caps in NUM_RE.captures_iter(text) {
184 let raw = caps.name("num").map(|m| m.as_str()).unwrap_or("");
185 let mut digits = raw.replace(',', "");
186 digits.retain(|c| c.is_ascii_digit() || c == '.' || c == '-');
187 if digits.is_empty() {
188 continue;
189 }
190 let unit = caps.name("unit").map(|m| m.as_str().trim().to_lowercase());
191 out.push(NumberToken {
192 value: digits,
193 unit,
194 });
195 }
196 out
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cer_handles_exact_match() {
        assert_eq!(cer("hello", "hello"), 0.0);
    }

    #[test]
    fn cer_handles_empty_reference() {
        assert_eq!(cer("", ""), 0.0);
        assert_eq!(cer("abc", ""), 1.0);
    }

    #[test]
    fn wer_counts_word_substitutions() {
        assert_eq!(wer("the cat sat", "the cat sat"), 0.0);
        assert_eq!(wer("", ""), 0.0);
        let rate = wer("the dog sat", "the cat sat");
        assert!((rate - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn word_levenshtein_handles_swaps_and_deletions() {
        assert_eq!(levenshtein_words(&["a", "b"], &["b", "a"]), 2);
        assert_eq!(levenshtein_words(&["the", "cat", "sat"], &["the", "sat"]), 1);
    }

    #[test]
    fn numeric_stats_counts_matches() {
        let stats = numeric_stats("Revenue $123", "Revenue $123");
        assert_eq!(stats.precision, 1.0);
        assert_eq!(stats.recall, 1.0);
    }

    #[test]
    fn numeric_stats_without_numbers_is_perfect() {
        let stats = numeric_stats("no numbers here", "none here either");
        assert_eq!(stats.precision, 1.0);
        assert_eq!(stats.recall, 1.0);
        assert_eq!(stats.f1, 1.0);
        assert_eq!(stats.units_ok, 1.0);
    }

    #[test]
    fn numeric_stats_flags_unit_mismatch() {
        let stats = numeric_stats("5 kg", "5 m");
        assert_eq!(stats.precision, 1.0);
        assert_eq!(stats.recall, 1.0);
        assert_eq!(stats.units_ok, 0.0);
    }

    #[test]
    fn numeric_stats_ignores_thousands_separators() {
        let stats = numeric_stats("Total 1,000 units", "Total 1000 units");
        assert_eq!(stats.f1, 1.0);
        assert_eq!(stats.units_ok, 1.0);
    }

    #[test]
    fn compression_factor_requires_nonzero_compressed_estimate() {
        let mut metrics = Metrics::default().with_token_metrics(Some(100), Some(25));
        assert_eq!(metrics.compression_factor, Some(4.0));
        metrics.record_tokens(None, Some(0));
        assert_eq!(metrics.raw_tokens_estimate, Some(100));
        assert_eq!(metrics.compression_factor, None);
    }
}