// three_dcf_core/metrics.rs

1use once_cell::sync::Lazy;
2use regex::Regex;
3use serde::{Deserialize, Serialize};
4use strsim::levenshtein;
5
6#[derive(Debug, Clone, Default, Serialize, Deserialize)]
7pub struct Metrics {
8    pub pages: u32,
9    pub lines_total: u32,
10    pub cells_total: u32,
11    pub cells_kept: u32,
12    pub dedup_ratio: f32,
13    pub numguard_count: u32,
14    pub raw_tokens_estimate: Option<u32>,
15    pub compressed_tokens_estimate: Option<u32>,
16    pub compression_factor: Option<f32>,
17}
18
19impl Metrics {
20    pub fn with_token_metrics(mut self, raw: Option<u32>, compressed: Option<u32>) -> Self {
21        self.record_tokens(raw, compressed);
22        self
23    }
24
25    pub fn record_tokens(&mut self, raw: Option<u32>, compressed: Option<u32>) {
26        if let Some(raw) = raw {
27            self.raw_tokens_estimate = Some(raw);
28        }
29        if let Some(compressed) = compressed {
30            self.compressed_tokens_estimate = Some(compressed);
31        }
32        self.compression_factor = match (self.raw_tokens_estimate, self.compressed_tokens_estimate)
33        {
34            (Some(raw), Some(comp)) if comp > 0 => Some(raw as f32 / comp as f32),
35            _ => None,
36        };
37    }
38}
39
40#[derive(Debug, Clone, Default, Serialize, Deserialize)]
41pub struct TokenMetrics {
42    pub raw: u32,
43    pub compressed: u32,
44    pub factor: f32,
45}
46
/// Agreement between the numbers found in a prediction and in a reference,
/// as produced by `numeric_stats`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct NumStats {
    /// Share of predicted numbers paired with an equal gold number
    /// (1.0 when the prediction contains no numbers).
    pub precision: f64,
    /// Share of gold numbers paired with an equal predicted number
    /// (1.0 when the reference contains no numbers).
    pub recall: f64,
    /// Harmonic mean of `precision` and `recall` (0.0 when either is 0.0).
    pub f1: f64,
    /// Share of paired numbers whose units agree (1.0 when nothing was paired).
    pub units_ok: f64,
}
54
55pub fn cer(pred: &str, gold: &str) -> f64 {
56    normalized_distance(pred, gold)
57}
58
59pub fn wer(pred: &str, gold: &str) -> f64 {
60    let pred_words: Vec<&str> = pred.split_whitespace().collect();
61    let gold_words: Vec<&str> = gold.split_whitespace().collect();
62    if gold_words.is_empty() {
63        return if pred_words.is_empty() { 0.0 } else { 1.0 };
64    }
65    let dist = levenshtein_words(&pred_words, &gold_words);
66    dist as f64 / gold_words.len() as f64
67}
68
69pub fn numeric_stats(pred: &str, gold: &str) -> NumStats {
70    let pred_vals = extract_numbers(pred);
71    let gold_vals = extract_numbers(gold);
72    if gold_vals.is_empty() && pred_vals.is_empty() {
73        return NumStats {
74            precision: 1.0,
75            recall: 1.0,
76            f1: 1.0,
77            units_ok: 1.0,
78        };
79    }
80    let mut gold_used = vec![false; gold_vals.len()];
81    let mut matches = 0usize;
82    let mut units_match = 0usize;
83    for pred in &pred_vals {
84        if let Some((idx, gold_item)) = gold_vals
85            .iter()
86            .enumerate()
87            .find(|(i, g)| !gold_used[*i] && g.value == pred.value)
88        {
89            gold_used[idx] = true;
90            matches += 1;
91            if gold_item
92                .unit
93                .as_ref()
94                .map(|u| pred.unit.as_ref() == Some(u))
95                .unwrap_or(true)
96            {
97                units_match += 1;
98            }
99        }
100    }
101    let precision = if pred_vals.is_empty() {
102        1.0
103    } else {
104        matches as f64 / pred_vals.len() as f64
105    };
106    let recall = if gold_vals.is_empty() {
107        1.0
108    } else {
109        matches as f64 / gold_vals.len() as f64
110    };
111    let f1 = if precision == 0.0 || recall == 0.0 {
112        0.0
113    } else {
114        2.0 * precision * recall / (precision + recall)
115    };
116    let units_ok = if matches == 0 {
117        1.0
118    } else {
119        units_match as f64 / matches as f64
120    };
121    NumStats {
122        precision,
123        recall,
124        f1,
125        units_ok,
126    }
127}
128
129fn normalized_distance(pred: &str, gold: &str) -> f64 {
130    let pred_norm = normalize(pred);
131    let gold_norm = normalize(gold);
132    if gold_norm.is_empty() {
133        return if pred_norm.is_empty() { 0.0 } else { 1.0 };
134    }
135    let dist = levenshtein(&pred_norm, &gold_norm);
136    dist as f64 / gold_norm.chars().count() as f64
137}
138
/// Word-level Levenshtein distance: the minimum number of word insertions,
/// deletions and substitutions that turn `pred` into `gold`.
///
/// Runs in O(m·n) time but only O(min(m, n)) extra memory. It keeps a single
/// DP row instead of the full table, which for page-length inputs would hold
/// m·n `usize` cells.
fn levenshtein_words(pred: &[&str], gold: &[&str]) -> usize {
    // The distance is symmetric, so roll the DP row over the shorter sequence.
    let (outer, inner) = if pred.len() >= gold.len() {
        (pred, gold)
    } else {
        (gold, pred)
    };
    if inner.is_empty() {
        return outer.len();
    }
    // `row[j]` = distance between the outer prefix processed so far and
    // `inner[..j]`; initially the outer prefix is empty, so it is just `j`.
    let mut row: Vec<usize> = (0..=inner.len()).collect();
    for (i, outer_word) in outer.iter().enumerate() {
        // `diag` holds the previous row's value at column `j` (substitution cell).
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, inner_word) in inner.iter().enumerate() {
            let above = row[j + 1];
            let cost = usize::from(outer_word != inner_word);
            row[j + 1] = (above + 1).min(row[j] + 1).min(diag + cost);
            diag = above;
        }
    }
    row[inner.len()]
}
165
/// Returns a copy of `text` with every control character removed, as judged by
/// [`char::is_control`]. This includes `\n`, `\r` and `\t`, which are dropped
/// rather than replaced.
fn normalize(text: &str) -> String {
    let mut kept = String::with_capacity(text.len());
    for ch in text.chars() {
        if !ch.is_control() {
            kept.push(ch);
        }
    }
    kept
}
169
/// A number found by `extract_numbers`, together with the unit that followed it.
#[derive(Debug, Clone, PartialEq, Eq)]
struct NumberToken {
    /// The matched number with `,` separators (and a `+` sign) removed,
    /// e.g. `"1000"` for `"1,000"`.
    value: String,
    /// The lowercased unit-like suffix captured after the number, if any.
    unit: Option<String>,
}
175
176static NUM_RE: Lazy<Regex> = Lazy::new(|| {
177    Regex::new(r"(?i)(?P<num>[+-]?\d{1,3}(?:[\d,.]*))(?:\s*(?P<unit>[a-z%$]{1,8}))?")
178        .expect("valid regex")
179});
180
181fn extract_numbers(text: &str) -> Vec<NumberToken> {
182    let mut out = Vec::new();
183    for caps in NUM_RE.captures_iter(text) {
184        let raw = caps.name("num").map(|m| m.as_str()).unwrap_or("");
185        let mut digits = raw.replace(',', "");
186        digits.retain(|c| c.is_ascii_digit() || c == '.' || c == '-');
187        if digits.is_empty() {
188            continue;
189        }
190        let unit = caps.name("unit").map(|m| m.as_str().trim().to_lowercase());
191        out.push(NumberToken {
192            value: digits,
193            unit,
194        });
195    }
196    out
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn cer_handles_exact_match() {
        assert_eq!(cer("hello", "hello"), 0.0);
    }

    #[test]
    fn cer_handles_empty_reference() {
        assert_eq!(cer("", ""), 0.0);
        assert_eq!(cer("abc", ""), 1.0);
    }

    #[test]
    fn cer_counts_substitutions() {
        assert!(approx(cer("abd", "abc"), 1.0 / 3.0));
    }

    #[test]
    fn wer_counts_word_edits() {
        assert_eq!(wer("the cat sat", "the cat sat"), 0.0);
        assert!(approx(wer("the sat", "the cat sat"), 1.0 / 3.0));
        assert_eq!(wer("", ""), 0.0);
        assert_eq!(wer("extra", ""), 1.0);
    }

    #[test]
    fn numeric_stats_counts_matches() {
        let stats = numeric_stats("Revenue $123", "Revenue $123");
        assert_eq!(stats.precision, 1.0);
        assert_eq!(stats.recall, 1.0);
    }

    #[test]
    fn numeric_stats_without_numbers_is_perfect() {
        let stats = numeric_stats("no digits", "none here");
        assert_eq!(stats.precision, 1.0);
        assert_eq!(stats.recall, 1.0);
        assert_eq!(stats.f1, 1.0);
        assert_eq!(stats.units_ok, 1.0);
    }

    #[test]
    fn numeric_stats_ignores_thousands_separators() {
        let stats = numeric_stats("1,000 kg", "1000 kg");
        assert_eq!(stats.f1, 1.0);
        assert_eq!(stats.units_ok, 1.0);
    }

    #[test]
    fn numeric_stats_reports_missing_numbers() {
        let stats = numeric_stats("nothing", "42");
        assert_eq!(stats.precision, 1.0);
        assert_eq!(stats.recall, 0.0);
        assert_eq!(stats.f1, 0.0);
    }

    #[test]
    fn metrics_compression_factor_tracks_estimates() {
        let m = Metrics::default().with_token_metrics(Some(100), Some(25));
        assert_eq!(m.compression_factor, Some(4.0));
        let m = m.with_token_metrics(None, Some(0));
        assert_eq!(m.raw_tokens_estimate, Some(100));
        assert_eq!(m.compression_factor, None);
    }
}