Skip to main content

scirs2_text/evaluation/
sequence_labeler.rs

1//! CRF-style Viterbi decoder for neural sequence labeling (NER etc.)
2//! with BIO tagging scheme and span-level evaluation metrics.
3
4use crate::error::{Result, TextError};
5use std::collections::HashMap;
6
7// ---------------------------------------------------------------------------
8// BIO tagging
9// ---------------------------------------------------------------------------
10
/// BIO (Begin-Inside-Outside) tagging scheme.
///
/// Each token carries exactly one tag: `B(type)` opens an entity span,
/// `I(type)` continues one, and `O` marks a token outside any entity.
///
/// `Eq` and `Hash` are derived alongside `PartialEq`, so tags can be used as
/// `HashMap` keys or `HashSet` members.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BioTag {
    /// Begin of an entity of the given type.
    B(String),
    /// Inside (continuation of) an entity of the given type.
    I(String),
    /// Outside — not part of any entity.
    O,
}
22
23impl BioTag {
24    /// Returns the entity type string if the tag is B or I.
25    pub fn entity_type(&self) -> Option<&str> {
26        match self {
27            BioTag::B(t) | BioTag::I(t) => Some(t.as_str()),
28            BioTag::O => None,
29        }
30    }
31
32    /// True if this is a B tag.
33    pub fn is_begin(&self) -> bool {
34        matches!(self, BioTag::B(_))
35    }
36
37    /// True if this is an I tag.
38    pub fn is_inside(&self) -> bool {
39        matches!(self, BioTag::I(_))
40    }
41}
42
43// ---------------------------------------------------------------------------
44// Viterbi decoder
45// ---------------------------------------------------------------------------
46
/// CRF-style Viterbi decoder operating over emission and transition log-probabilities.
///
/// The decoder itself holds only the tag inventory; emission and transition
/// scores are supplied per call.
#[derive(Debug, Clone)]
pub struct ViterbiDecoder {
    /// Total number of output tags.
    ///
    /// `ViterbiDecoder::new` sets this to `tag_names.len()`; because both
    /// fields are public, code that mutates them must keep the two in sync.
    pub n_tags: usize,
    /// Human-readable tag names in index order.
    pub tag_names: Vec<String>,
}
54
55impl ViterbiDecoder {
56    /// Construct a decoder from an ordered list of tag names.
57    pub fn new(tag_names: Vec<String>) -> Self {
58        let n_tags = tag_names.len();
59        Self { n_tags, tag_names }
60    }
61
62    /// Viterbi decoding over emission scores and a transition matrix.
63    ///
64    /// `emissions`: \[seq_len\]\[n_tags\] log-probabilities of each tag at each position.
65    /// `transitions`: \[n_tags\]\[n_tags\] log-probability of transitioning from tag *i* to tag *j*.
66    ///
67    /// Returns the most likely tag index sequence.
68    pub fn decode(&self, emissions: &[Vec<f64>], transitions: &[Vec<f64>]) -> Result<Vec<usize>> {
69        let seq_len = emissions.len();
70        if seq_len == 0 {
71            return Err(TextError::InvalidInput(
72                "Viterbi: empty emission sequence".into(),
73            ));
74        }
75        if transitions.len() != self.n_tags {
76            return Err(TextError::InvalidInput(format!(
77                "transitions rows {} != n_tags {}",
78                transitions.len(),
79                self.n_tags
80            )));
81        }
82        for row in emissions {
83            if row.len() != self.n_tags {
84                return Err(TextError::InvalidInput(format!(
85                    "emission width {} != n_tags {}",
86                    row.len(),
87                    self.n_tags
88                )));
89            }
90        }
91
92        let n = self.n_tags;
93        // dp[t][k] = best log-prob of tagging position t with tag k
94        let mut dp = vec![vec![f64::NEG_INFINITY; n]; seq_len];
95        // bp[t][k] = argmax predecessor tag at t-1
96        let mut bp = vec![vec![0_usize; n]; seq_len];
97
98        // Initialise with emissions at t=0 (uniform start)
99        for k in 0..n {
100            dp[0][k] = emissions[0][k];
101        }
102
103        // Forward
104        for t in 1..seq_len {
105            for k in 0..n {
106                let mut best_score = f64::NEG_INFINITY;
107                let mut best_prev = 0;
108                for j in 0..n {
109                    let score = dp[t - 1][j] + transitions[j][k] + emissions[t][k];
110                    if score > best_score {
111                        best_score = score;
112                        best_prev = j;
113                    }
114                }
115                dp[t][k] = best_score;
116                bp[t][k] = best_prev;
117            }
118        }
119
120        // Find best final tag
121        let mut best_last = 0;
122        let mut best_last_score = f64::NEG_INFINITY;
123        for k in 0..n {
124            if dp[seq_len - 1][k] > best_last_score {
125                best_last_score = dp[seq_len - 1][k];
126                best_last = k;
127            }
128        }
129
130        // Backtrack
131        let mut path = vec![0_usize; seq_len];
132        path[seq_len - 1] = best_last;
133        for t in (1..seq_len).rev() {
134            path[t - 1] = bp[t][path[t]];
135        }
136
137        Ok(path)
138    }
139
140    /// Convert a sequence of tag indices to BIO tags.
141    ///
142    /// Tags whose name starts with `B-` are parsed as `BioTag::B(type)`, `I-` → `BioTag::I(type)`,
143    /// `O` → `BioTag::O`.  Unknown names are treated as `O`.
144    pub fn indices_to_bio(&self, indices: &[usize]) -> Result<Vec<BioTag>> {
145        indices
146            .iter()
147            .map(|&idx| {
148                if idx >= self.n_tags {
149                    return Err(TextError::InvalidInput(format!(
150                        "tag index {} out of range {}",
151                        idx, self.n_tags
152                    )));
153                }
154                let name = &self.tag_names[idx];
155                let bio = if name.starts_with("B-") {
156                    BioTag::B(name[2..].to_owned())
157                } else if name.starts_with("I-") {
158                    BioTag::I(name[2..].to_owned())
159                } else {
160                    BioTag::O
161                };
162                Ok(bio)
163            })
164            .collect()
165    }
166
167    /// Extract named entities from a BIO-tagged sequence.
168    ///
169    /// Returns `(entity_type, start_index, end_index_exclusive)` triples.
170    pub fn extract_entities(bio_tags: &[BioTag]) -> Vec<(String, usize, usize)> {
171        let mut entities = Vec::new();
172        let mut i = 0;
173        while i < bio_tags.len() {
174            if let BioTag::B(etype) = &bio_tags[i] {
175                let start = i;
176                let entity_type = etype.clone();
177                i += 1;
178                while i < bio_tags.len() {
179                    match &bio_tags[i] {
180                        BioTag::I(t) if t == &entity_type => {
181                            i += 1;
182                        }
183                        _ => break,
184                    }
185                }
186                entities.push((entity_type, start, i));
187            } else {
188                i += 1;
189            }
190        }
191        entities
192    }
193}
194
195// ---------------------------------------------------------------------------
196// Evaluation metrics
197// ---------------------------------------------------------------------------
198
/// Span-level scores for a sequence-labeling evaluation: overall precision,
/// recall and F1, plus the raw per-entity-type counts behind them.
#[derive(Clone, Debug)]
pub struct SequenceLabelMetrics {
    /// Precision aggregated over all entity types.
    pub precision: f64,
    /// Recall aggregated over all entity types.
    pub recall: f64,
    /// F1 score: the harmonic mean of `precision` and `recall`.
    pub f1: f64,
    /// Counts keyed by entity type, stored as
    /// `(true positives, false positives, false negatives)`.
    pub entity_counts: HashMap<String, (usize, usize, usize)>,
}
211
212/// Evaluate sequence labeling by comparing predicted to gold BIO sequences.
213///
214/// Entities are compared at the span level (type + start + end must match).
215pub fn evaluate_sequence_labeling(
216    predicted: &[Vec<BioTag>],
217    gold: &[Vec<BioTag>],
218) -> Result<SequenceLabelMetrics> {
219    if predicted.len() != gold.len() {
220        return Err(TextError::InvalidInput(format!(
221            "predicted {} sequences != gold {}",
222            predicted.len(),
223            gold.len()
224        )));
225    }
226
227    // Collect (type, start, end) spans from a BIO sequence with a sentence offset.
228    let collect_spans = |seq: &Vec<BioTag>, offset: usize| -> Vec<(String, usize, usize)> {
229        ViterbiDecoder::extract_entities(seq)
230            .into_iter()
231            .map(|(t, s, e)| (t, s + offset, e + offset))
232            .collect()
233    };
234
235    let mut all_pred: Vec<(String, usize, usize)> = Vec::new();
236    let mut all_gold: Vec<(String, usize, usize)> = Vec::new();
237    let mut offset = 0;
238    for (pred_seq, gold_seq) in predicted.iter().zip(gold) {
239        all_pred.extend(collect_spans(pred_seq, offset));
240        all_gold.extend(collect_spans(gold_seq, offset));
241        offset += pred_seq.len().max(gold_seq.len());
242    }
243
244    // Compute per-type tp/fp/fn
245    let mut counts: HashMap<String, (usize, usize, usize)> = HashMap::new();
246
247    for span in &all_gold {
248        counts.entry(span.0.clone()).or_insert((0, 0, 0));
249    }
250    for span in &all_pred {
251        counts.entry(span.0.clone()).or_insert((0, 0, 0));
252    }
253
254    for span in &all_pred {
255        let entry = counts.entry(span.0.clone()).or_insert((0, 0, 0));
256        if all_gold.contains(span) {
257            entry.0 += 1; // tp
258        } else {
259            entry.1 += 1; // fp
260        }
261    }
262    for span in &all_gold {
263        let entry = counts.entry(span.0.clone()).or_insert((0, 0, 0));
264        if !all_pred.contains(span) {
265            entry.2 += 1; // fn
266        }
267    }
268
269    // Micro-average
270    let (total_tp, total_fp, total_fn) = counts.values().fold((0, 0, 0), |(tp, fp, fnn), v| {
271        (tp + v.0, fp + v.1, fnn + v.2)
272    });
273
274    let precision = if total_tp + total_fp == 0 {
275        0.0
276    } else {
277        total_tp as f64 / (total_tp + total_fp) as f64
278    };
279    let recall = if total_tp + total_fn == 0 {
280        0.0
281    } else {
282        total_tp as f64 / (total_tp + total_fn) as f64
283    };
284    let f1 = if precision + recall < 1e-12 {
285        0.0
286    } else {
287        2.0 * precision * recall / (precision + recall)
288    };
289
290    Ok(SequenceLabelMetrics {
291        precision,
292        recall,
293        f1,
294        entity_counts: counts,
295    })
296}
297
298// ---------------------------------------------------------------------------
299// Tests
300// ---------------------------------------------------------------------------
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoder with the five-tag PER/ORG inventory used by several tests.
    fn make_decoder() -> ViterbiDecoder {
        ViterbiDecoder::new(vec![
            "O".into(),
            "B-PER".into(),
            "I-PER".into(),
            "B-ORG".into(),
            "I-ORG".into(),
        ])
    }

    #[test]
    fn test_viterbi_simple_chain() {
        // 3 positions, 2 tags (0 and 1)
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-PER".into()]);
        // emissions strongly prefer 0, 1, 0
        let emissions = vec![vec![-0.1, -10.0], vec![-10.0, -0.1], vec![-0.1, -10.0]];
        // uniform transitions
        let transitions = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let path = decoder.decode(&emissions, &transitions).unwrap();
        assert_eq!(path, vec![0, 1, 0]);
    }

    #[test]
    fn test_viterbi_all_same() {
        // All emissions identical — transitions govern
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-LOC".into()]);
        let emissions = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        // transitions: prefer staying in tag 1
        let transitions = vec![vec![-1.0, 0.0], vec![0.0, 1.0]];
        let path = decoder.decode(&emissions, &transitions).unwrap();
        // At t=0 both tags score 0; at t=1 the best score is 1.0, reached by
        // tag 1 via the 1 → 1 self-loop, so the whole path stays on tag 1.
        assert_eq!(path, vec![1, 1]);
    }

    #[test]
    fn test_viterbi_single_position() {
        // With one position the result is simply the argmax of the emissions.
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-PER".into()]);
        let emissions = vec![vec![-3.0, -0.5]];
        let transitions = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let path = decoder.decode(&emissions, &transitions).unwrap();
        assert_eq!(path, vec![1]);
    }

    #[test]
    fn test_emission_width_mismatch_returns_error() {
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-PER".into()]);
        let emissions = vec![vec![0.0, 0.0], vec![0.0]];
        let transitions = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        assert!(decoder.decode(&emissions, &transitions).is_err());
    }

    #[test]
    fn test_transition_row_count_mismatch_returns_error() {
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-PER".into()]);
        let emissions = vec![vec![0.0, 0.0]];
        let transitions = vec![vec![0.0, 0.0]];
        assert!(decoder.decode(&emissions, &transitions).is_err());
    }

    #[test]
    fn test_indices_to_bio() {
        let decoder = make_decoder();
        // indices: O B-PER I-PER O B-ORG
        let indices = vec![0, 1, 2, 0, 3];
        let bio = decoder.indices_to_bio(&indices).unwrap();
        assert_eq!(bio[0], BioTag::O);
        assert_eq!(bio[1], BioTag::B("PER".into()));
        assert_eq!(bio[2], BioTag::I("PER".into()));
        assert_eq!(bio[3], BioTag::O);
        assert_eq!(bio[4], BioTag::B("ORG".into()));
    }

    #[test]
    fn test_indices_to_bio_out_of_range() {
        let decoder = make_decoder();
        // Valid indices are 0..5.
        assert!(decoder.indices_to_bio(&[0, 5]).is_err());
    }

    #[test]
    fn test_indices_to_bio_unknown_name_is_outside() {
        let decoder = ViterbiDecoder::new(vec!["MISC".into()]);
        let bio = decoder.indices_to_bio(&[0]).unwrap();
        assert_eq!(bio, vec![BioTag::O]);
    }

    #[test]
    fn test_bio_tag_accessors() {
        let begin = BioTag::B("PER".into());
        let inside = BioTag::I("ORG".into());
        assert_eq!(begin.entity_type(), Some("PER"));
        assert_eq!(inside.entity_type(), Some("ORG"));
        assert_eq!(BioTag::O.entity_type(), None);
        assert!(begin.is_begin() && !begin.is_inside());
        assert!(inside.is_inside() && !inside.is_begin());
        assert!(!BioTag::O.is_begin() && !BioTag::O.is_inside());
    }

    #[test]
    fn test_extract_entities_basic() {
        // B-PER I-PER O = one PER entity at positions 0..2
        let tags = vec![BioTag::B("PER".into()), BioTag::I("PER".into()), BioTag::O];
        let entities = ViterbiDecoder::extract_entities(&tags);
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0], ("PER".to_owned(), 0, 2));
    }

    #[test]
    fn test_extract_entities_two_entities() {
        let tags = vec![
            BioTag::B("PER".into()),
            BioTag::O,
            BioTag::B("ORG".into()),
            BioTag::I("ORG".into()),
        ];
        let entities = ViterbiDecoder::extract_entities(&tags);
        assert_eq!(entities.len(), 2);
        assert_eq!(entities[0], ("PER".to_owned(), 0, 1));
        assert_eq!(entities[1], ("ORG".to_owned(), 2, 4));
    }

    #[test]
    fn test_extract_entities_orphan_inside_is_skipped() {
        // An I tag with no matching B before it does not start an entity, and
        // an I tag of a different type ends the current span.
        let tags = vec![
            BioTag::I("PER".into()),
            BioTag::O,
            BioTag::B("PER".into()),
            BioTag::I("ORG".into()),
        ];
        let entities = ViterbiDecoder::extract_entities(&tags);
        assert_eq!(entities, vec![("PER".to_owned(), 2, 3)]);
    }

    #[test]
    fn test_sequence_labeling_perfect_f1() {
        let gold = vec![vec![
            BioTag::B("PER".into()),
            BioTag::I("PER".into()),
            BioTag::O,
        ]];
        let pred = gold.clone();
        let metrics = evaluate_sequence_labeling(&pred, &gold).unwrap();
        assert!((metrics.f1 - 1.0).abs() < 1e-9, "perfect pred → F1 = 1.0");
        assert!((metrics.precision - 1.0).abs() < 1e-9);
        assert!((metrics.recall - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_sequence_labeling_no_overlap() {
        let gold = vec![vec![BioTag::B("PER".into()), BioTag::O]];
        let pred = vec![vec![BioTag::O, BioTag::B("ORG".into())]];
        let metrics = evaluate_sequence_labeling(&pred, &gold).unwrap();
        assert_eq!(metrics.f1, 0.0, "no overlap → F1 = 0.0");
        assert_eq!(metrics.precision, 0.0);
        assert_eq!(metrics.recall, 0.0);
        assert_eq!(metrics.entity_counts.get("PER"), Some(&(0, 0, 1)));
        assert_eq!(metrics.entity_counts.get("ORG"), Some(&(0, 1, 0)));
    }

    #[test]
    fn test_sequence_labeling_partial_match() {
        // Sentence 1: boundary error on PER (fp + fn). Sentence 2: ORG matches.
        let gold = vec![
            vec![BioTag::B("PER".into()), BioTag::I("PER".into()), BioTag::O],
            vec![BioTag::B("ORG".into()), BioTag::O],
        ];
        let pred = vec![
            vec![BioTag::B("PER".into()), BioTag::O, BioTag::O],
            vec![BioTag::B("ORG".into()), BioTag::O],
        ];
        let metrics = evaluate_sequence_labeling(&pred, &gold).unwrap();
        assert!((metrics.precision - 0.5).abs() < 1e-9);
        assert!((metrics.recall - 0.5).abs() < 1e-9);
        assert!((metrics.f1 - 0.5).abs() < 1e-9);
        assert_eq!(metrics.entity_counts.get("PER"), Some(&(0, 1, 1)));
        assert_eq!(metrics.entity_counts.get("ORG"), Some(&(1, 0, 0)));
    }

    #[test]
    fn test_sequence_labeling_batch_size_mismatch() {
        let gold = vec![vec![BioTag::O]];
        let pred: Vec<Vec<BioTag>> = Vec::new();
        assert!(evaluate_sequence_labeling(&pred, &gold).is_err());
    }

    #[test]
    fn test_empty_sequence_returns_error() {
        let decoder = make_decoder();
        let result = decoder.decode(&[], &[]);
        assert!(result.is_err(), "empty emissions should fail");
    }
}