1use crate::error::{Result, TextError};
5use std::collections::HashMap;
6
/// A single token label in the BIO (begin / inside / outside) tagging scheme.
///
/// `B` and `I` carry the entity type (for example `"PER"`), i.e. the part of a
/// tag name that follows the `B-` / `I-` prefix.
///
/// The type is `Eq + Hash` so tags can be used as set members or map keys.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BioTag {
    /// First token of an entity of the given type.
    B(String),
    /// Continuation token of an entity of the given type.
    I(String),
    /// Token that is not part of any entity.
    O,
}
22
23impl BioTag {
24 pub fn entity_type(&self) -> Option<&str> {
26 match self {
27 BioTag::B(t) | BioTag::I(t) => Some(t.as_str()),
28 BioTag::O => None,
29 }
30 }
31
32 pub fn is_begin(&self) -> bool {
34 matches!(self, BioTag::B(_))
35 }
36
37 pub fn is_inside(&self) -> bool {
39 matches!(self, BioTag::I(_))
40 }
41}
42
/// Viterbi decoder over a fixed tag inventory.
///
/// The decoder itself holds no scores; emission and transition scores are
/// supplied per call to `decode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ViterbiDecoder {
    /// Number of tags; `ViterbiDecoder::new` sets this to `tag_names.len()`.
    pub n_tags: usize,
    /// Tag names indexed by tag id, e.g. `"O"`, `"B-PER"`, `"I-PER"`.
    pub tag_names: Vec<String>,
}
54
55impl ViterbiDecoder {
56 pub fn new(tag_names: Vec<String>) -> Self {
58 let n_tags = tag_names.len();
59 Self { n_tags, tag_names }
60 }
61
62 pub fn decode(&self, emissions: &[Vec<f64>], transitions: &[Vec<f64>]) -> Result<Vec<usize>> {
69 let seq_len = emissions.len();
70 if seq_len == 0 {
71 return Err(TextError::InvalidInput(
72 "Viterbi: empty emission sequence".into(),
73 ));
74 }
75 if transitions.len() != self.n_tags {
76 return Err(TextError::InvalidInput(format!(
77 "transitions rows {} != n_tags {}",
78 transitions.len(),
79 self.n_tags
80 )));
81 }
82 for row in emissions {
83 if row.len() != self.n_tags {
84 return Err(TextError::InvalidInput(format!(
85 "emission width {} != n_tags {}",
86 row.len(),
87 self.n_tags
88 )));
89 }
90 }
91
92 let n = self.n_tags;
93 let mut dp = vec![vec![f64::NEG_INFINITY; n]; seq_len];
95 let mut bp = vec![vec![0_usize; n]; seq_len];
97
98 for k in 0..n {
100 dp[0][k] = emissions[0][k];
101 }
102
103 for t in 1..seq_len {
105 for k in 0..n {
106 let mut best_score = f64::NEG_INFINITY;
107 let mut best_prev = 0;
108 for j in 0..n {
109 let score = dp[t - 1][j] + transitions[j][k] + emissions[t][k];
110 if score > best_score {
111 best_score = score;
112 best_prev = j;
113 }
114 }
115 dp[t][k] = best_score;
116 bp[t][k] = best_prev;
117 }
118 }
119
120 let mut best_last = 0;
122 let mut best_last_score = f64::NEG_INFINITY;
123 for k in 0..n {
124 if dp[seq_len - 1][k] > best_last_score {
125 best_last_score = dp[seq_len - 1][k];
126 best_last = k;
127 }
128 }
129
130 let mut path = vec![0_usize; seq_len];
132 path[seq_len - 1] = best_last;
133 for t in (1..seq_len).rev() {
134 path[t - 1] = bp[t][path[t]];
135 }
136
137 Ok(path)
138 }
139
140 pub fn indices_to_bio(&self, indices: &[usize]) -> Result<Vec<BioTag>> {
145 indices
146 .iter()
147 .map(|&idx| {
148 if idx >= self.n_tags {
149 return Err(TextError::InvalidInput(format!(
150 "tag index {} out of range {}",
151 idx, self.n_tags
152 )));
153 }
154 let name = &self.tag_names[idx];
155 let bio = if name.starts_with("B-") {
156 BioTag::B(name[2..].to_owned())
157 } else if name.starts_with("I-") {
158 BioTag::I(name[2..].to_owned())
159 } else {
160 BioTag::O
161 };
162 Ok(bio)
163 })
164 .collect()
165 }
166
167 pub fn extract_entities(bio_tags: &[BioTag]) -> Vec<(String, usize, usize)> {
171 let mut entities = Vec::new();
172 let mut i = 0;
173 while i < bio_tags.len() {
174 if let BioTag::B(etype) = &bio_tags[i] {
175 let start = i;
176 let entity_type = etype.clone();
177 i += 1;
178 while i < bio_tags.len() {
179 match &bio_tags[i] {
180 BioTag::I(t) if t == &entity_type => {
181 i += 1;
182 }
183 _ => break,
184 }
185 }
186 entities.push((entity_type, start, i));
187 } else {
188 i += 1;
189 }
190 }
191 entities
192 }
193}
194
/// Span-level scores produced by `evaluate_sequence_labeling`.
///
/// `Default` yields all-zero scores with no per-type counts, and `PartialEq`
/// lets results be compared directly.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SequenceLabelMetrics {
    /// `tp / (tp + fp)` summed over all entity types; `0.0` when nothing was predicted.
    pub precision: f64,
    /// `tp / (tp + fn)` summed over all entity types; `0.0` when there are no gold spans.
    pub recall: f64,
    /// Harmonic mean of `precision` and `recall`; `0.0` when both are zero.
    pub f1: f64,
    /// Per entity type: `(true positives, false positives, false negatives)`.
    pub entity_counts: HashMap<String, (usize, usize, usize)>,
}
211
212pub fn evaluate_sequence_labeling(
216 predicted: &[Vec<BioTag>],
217 gold: &[Vec<BioTag>],
218) -> Result<SequenceLabelMetrics> {
219 if predicted.len() != gold.len() {
220 return Err(TextError::InvalidInput(format!(
221 "predicted {} sequences != gold {}",
222 predicted.len(),
223 gold.len()
224 )));
225 }
226
227 let collect_spans = |seq: &Vec<BioTag>, offset: usize| -> Vec<(String, usize, usize)> {
229 ViterbiDecoder::extract_entities(seq)
230 .into_iter()
231 .map(|(t, s, e)| (t, s + offset, e + offset))
232 .collect()
233 };
234
235 let mut all_pred: Vec<(String, usize, usize)> = Vec::new();
236 let mut all_gold: Vec<(String, usize, usize)> = Vec::new();
237 let mut offset = 0;
238 for (pred_seq, gold_seq) in predicted.iter().zip(gold) {
239 all_pred.extend(collect_spans(pred_seq, offset));
240 all_gold.extend(collect_spans(gold_seq, offset));
241 offset += pred_seq.len().max(gold_seq.len());
242 }
243
244 let mut counts: HashMap<String, (usize, usize, usize)> = HashMap::new();
246
247 for span in &all_gold {
248 counts.entry(span.0.clone()).or_insert((0, 0, 0));
249 }
250 for span in &all_pred {
251 counts.entry(span.0.clone()).or_insert((0, 0, 0));
252 }
253
254 for span in &all_pred {
255 let entry = counts.entry(span.0.clone()).or_insert((0, 0, 0));
256 if all_gold.contains(span) {
257 entry.0 += 1; } else {
259 entry.1 += 1; }
261 }
262 for span in &all_gold {
263 let entry = counts.entry(span.0.clone()).or_insert((0, 0, 0));
264 if !all_pred.contains(span) {
265 entry.2 += 1; }
267 }
268
269 let (total_tp, total_fp, total_fn) = counts.values().fold((0, 0, 0), |(tp, fp, fnn), v| {
271 (tp + v.0, fp + v.1, fnn + v.2)
272 });
273
274 let precision = if total_tp + total_fp == 0 {
275 0.0
276 } else {
277 total_tp as f64 / (total_tp + total_fp) as f64
278 };
279 let recall = if total_tp + total_fn == 0 {
280 0.0
281 } else {
282 total_tp as f64 / (total_tp + total_fn) as f64
283 };
284 let f1 = if precision + recall < 1e-12 {
285 0.0
286 } else {
287 2.0 * precision * recall / (precision + recall)
288 };
289
290 Ok(SequenceLabelMetrics {
291 precision,
292 recall,
293 f1,
294 entity_counts: counts,
295 })
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoder over the five-tag PER / ORG inventory shared by several tests.
    fn make_decoder() -> ViterbiDecoder {
        ViterbiDecoder::new(vec![
            "O".into(),
            "B-PER".into(),
            "I-PER".into(),
            "B-ORG".into(),
            "I-ORG".into(),
        ])
    }

    #[test]
    fn test_viterbi_simple_chain() {
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-PER".into()]);
        // Emissions dominate; transitions are neutral.
        let emissions = vec![vec![-0.1, -10.0], vec![-10.0, -0.1], vec![-0.1, -10.0]];
        let transitions = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let path = decoder.decode(&emissions, &transitions).unwrap();
        assert_eq!(path, vec![0, 1, 0]);
    }

    #[test]
    fn test_viterbi_all_same() {
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-LOC".into()]);
        // Flat emissions: the path is decided by transitions alone, and the
        // 1 -> 1 transition (+1.0) is the single best step.
        let emissions = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let transitions = vec![vec![-1.0, 0.0], vec![0.0, 1.0]];
        let path = decoder.decode(&emissions, &transitions).unwrap();
        assert_eq!(path, vec![1, 1]);
    }

    #[test]
    fn test_decode_rejects_shape_mismatch() {
        let decoder = ViterbiDecoder::new(vec!["O".into(), "B-PER".into()]);
        let transitions = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        // Emission row narrower than the tag inventory.
        assert!(decoder.decode(&[vec![0.0]], &transitions).is_err());
        // Too few transition rows.
        assert!(decoder
            .decode(&[vec![0.0, 0.0]], &transitions[..1])
            .is_err());
    }

    #[test]
    fn test_indices_to_bio() {
        let decoder = make_decoder();
        let indices = vec![0, 1, 2, 0, 3];
        let bio = decoder.indices_to_bio(&indices).unwrap();
        assert_eq!(bio[0], BioTag::O);
        assert_eq!(bio[1], BioTag::B("PER".into()));
        assert_eq!(bio[2], BioTag::I("PER".into()));
        assert_eq!(bio[3], BioTag::O);
        assert_eq!(bio[4], BioTag::B("ORG".into()));
    }

    #[test]
    fn test_indices_to_bio_out_of_range() {
        let decoder = make_decoder();
        assert!(decoder.indices_to_bio(&[0, 5]).is_err());
    }

    #[test]
    fn test_extract_entities_basic() {
        let tags = vec![BioTag::B("PER".into()), BioTag::I("PER".into()), BioTag::O];
        let entities = ViterbiDecoder::extract_entities(&tags);
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0], ("PER".to_owned(), 0, 2));
    }

    #[test]
    fn test_extract_entities_two_entities() {
        let tags = vec![
            BioTag::B("PER".into()),
            BioTag::O,
            BioTag::B("ORG".into()),
            BioTag::I("ORG".into()),
        ];
        let entities = ViterbiDecoder::extract_entities(&tags);
        assert_eq!(entities.len(), 2);
        assert_eq!(entities[0], ("PER".to_owned(), 0, 1));
        assert_eq!(entities[1], ("ORG".to_owned(), 2, 4));
    }

    #[test]
    fn test_extract_entities_type_mismatch_ends_span() {
        // An `I` tag of a different type does not extend the span, and `I`
        // tags that follow no matching `B` span are dropped.
        let tags = vec![
            BioTag::B("PER".into()),
            BioTag::I("ORG".into()),
            BioTag::I("PER".into()),
        ];
        let entities = ViterbiDecoder::extract_entities(&tags);
        assert_eq!(entities, vec![("PER".to_owned(), 0, 1)]);
    }

    #[test]
    fn test_sequence_labeling_perfect_f1() {
        let gold = vec![vec![
            BioTag::B("PER".into()),
            BioTag::I("PER".into()),
            BioTag::O,
        ]];
        let pred = gold.clone();
        let metrics = evaluate_sequence_labeling(&pred, &gold).unwrap();
        assert!((metrics.f1 - 1.0).abs() < 1e-9, "perfect pred → F1 = 1.0");
        assert!((metrics.precision - 1.0).abs() < 1e-9);
        assert!((metrics.recall - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_sequence_labeling_no_overlap() {
        let gold = vec![vec![BioTag::B("PER".into()), BioTag::O]];
        let pred = vec![vec![BioTag::O, BioTag::B("ORG".into())]];
        let metrics = evaluate_sequence_labeling(&pred, &gold).unwrap();
        assert_eq!(metrics.f1, 0.0, "no overlap → F1 = 0.0");
    }

    #[test]
    fn test_sequence_labeling_partial_match() {
        let gold = vec![vec![
            BioTag::B("PER".into()),
            BioTag::I("PER".into()),
            BioTag::O,
            BioTag::B("ORG".into()),
        ]];
        let pred = vec![vec![
            BioTag::B("PER".into()),
            BioTag::I("PER".into()),
            BioTag::O,
            BioTag::O,
        ]];
        let metrics = evaluate_sequence_labeling(&pred, &gold).unwrap();
        assert!((metrics.precision - 1.0).abs() < 1e-9);
        assert!((metrics.recall - 0.5).abs() < 1e-9);
        assert!((metrics.f1 - 2.0 / 3.0).abs() < 1e-9);
        assert_eq!(metrics.entity_counts.get("PER"), Some(&(1, 0, 0)));
        assert_eq!(metrics.entity_counts.get("ORG"), Some(&(0, 0, 1)));
    }

    #[test]
    fn test_sequence_labeling_length_mismatch() {
        let gold = vec![vec![BioTag::O]];
        let result = evaluate_sequence_labeling(&[], &gold);
        assert!(result.is_err(), "different sequence counts should fail");
    }

    #[test]
    fn test_empty_sequence_returns_error() {
        let decoder = make_decoder();
        let result = decoder.decode(&[], &[]);
        assert!(result.is_err(), "empty emissions should fail");
    }
}