1use crate::error::{Result, TextError};
24use std::cmp::Reverse;
25use std::collections::HashMap;
26
/// Syntactic category assigned to a detected mention.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum MentionType {
    /// A run of capitalized tokens, such as a personal name.
    Proper,
    /// A determiner (`the`, `a`, `an`) together with the token that follows it.
    Nominal,
    /// One of the third-person pronouns this module recognizes.
    Pronominal,
}
42
/// Coarse gender/number feature used to filter antecedent candidates.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum GenderNumber {
    /// Singular masculine (`he`, `him`, `his`, `himself`, or a listed male first name).
    MasculineSingular,
    /// Singular feminine (`she`, `her`, `hers`, `herself`, or a listed female first name).
    FeminineSingular,
    /// Singular neuter (`it`, `its`, `itself`).
    NeuterSingular,
    /// Plural (`they`, `them`, `their`, `theirs`, `themselves`).
    Plural,
    /// No cue available; the agreement check treats this as compatible with any value.
    Unknown,
}
57
58#[derive(Debug, Clone)]
60pub struct Mention {
61 pub span: (usize, usize),
63 pub text: String,
65 pub mention_type: MentionType,
67 pub gender_number: GenderNumber,
69}
70
71impl Mention {
72 pub fn start(&self) -> usize {
74 self.span.0
75 }
76
77 pub fn end(&self) -> usize {
79 self.span.1
80 }
81}
82
83#[derive(Debug, Clone)]
85pub struct CoreferenceChain {
86 pub canonical: String,
89 pub mentions: Vec<Mention>,
91 pub confidence: f64,
93}
94
95impl CoreferenceChain {
96 fn new(seed: Mention, confidence: f64) -> Self {
98 let canonical = seed.text.clone();
99 Self {
100 canonical,
101 mentions: vec![seed],
102 confidence,
103 }
104 }
105
106 fn add(&mut self, mention: Mention, score: f64) {
108 if mention.mention_type == MentionType::Proper
110 || (mention.mention_type == MentionType::Nominal
111 && self.canonical_type() == MentionType::Pronominal)
112 {
113 self.canonical = mention.text.clone();
114 }
115 self.confidence = self.confidence.max(score);
116 self.mentions.push(mention);
117 }
118
119 fn canonical_type(&self) -> MentionType {
121 for m in &self.mentions {
123 if m.text == self.canonical {
124 return m.mention_type.clone();
125 }
126 }
127 MentionType::Pronominal
128 }
129}
130
131pub fn infer_gender_number(text: &str) -> GenderNumber {
137 let lower = text.to_lowercase();
138 match lower.as_str() {
139 "he" | "him" | "his" | "himself" => GenderNumber::MasculineSingular,
140 "she" | "her" | "hers" | "herself" => GenderNumber::FeminineSingular,
141 "it" | "its" | "itself" => GenderNumber::NeuterSingular,
142 "they" | "them" | "their" | "theirs" | "themselves" => GenderNumber::Plural,
143 _ => {
144 if is_likely_masculine_name(&lower) {
146 GenderNumber::MasculineSingular
147 } else if is_likely_feminine_name(&lower) {
148 GenderNumber::FeminineSingular
149 } else {
150 GenderNumber::Unknown
151 }
152 }
153 }
154}
155
/// Returns `true` when `name` exactly equals one of a fixed list of male first names.
///
/// The comparison is case-sensitive against lowercase entries, so callers are
/// expected to pass an already-lowercased word.
fn is_likely_masculine_name(name: &str) -> bool {
    matches!(
        name,
        "john" | "james" | "michael" | "william" | "david" | "richard" | "joseph"
            | "thomas" | "charles" | "christopher" | "daniel" | "matthew" | "anthony"
            | "mark" | "donald" | "steven" | "paul" | "andrew" | "kenneth" | "george"
            | "joshua" | "kevin" | "brian" | "tim" | "bob" | "bill" | "frank" | "larry"
            | "scott" | "jeffrey" | "eric" | "robert" | "peter" | "henry" | "edward"
    )
}
196
/// Returns `true` when `name` exactly equals one of a fixed list of female first names.
///
/// The comparison is case-sensitive against lowercase entries, so callers are
/// expected to pass an already-lowercased word.
fn is_likely_feminine_name(name: &str) -> bool {
    matches!(
        name,
        "mary" | "patricia" | "linda" | "barbara" | "elizabeth" | "jennifer" | "maria"
            | "susan" | "margaret" | "dorothy" | "lisa" | "nancy" | "karen" | "betty"
            | "helen" | "sandra" | "donna" | "carol" | "ruth" | "sharon" | "michelle"
            | "laura" | "sarah" | "kimberly" | "deborah" | "jessica" | "shirley"
            | "cynthia" | "angela" | "melissa" | "brenda" | "amy" | "anna" | "rebecca"
            | "virginia" | "kathleen" | "pamela" | "martha" | "debra" | "amanda"
            | "stephanie" | "carolyn" | "christine" | "alice"
    )
}
246
247pub fn gender_number_agreement(mention: &Mention, candidate: &Mention) -> bool {
250 match (&mention.gender_number, &candidate.gender_number) {
251 (GenderNumber::Unknown, _) | (_, GenderNumber::Unknown) => true,
252 (a, b) => a == b,
253 }
254}
255
256pub fn antecedent_score(
264 mention: &Mention,
265 candidate: &Mention,
266 mention_sentence: usize,
267 candidate_sentence: usize,
268) -> f64 {
269 let mut score = 0.0f64;
270
271 if gender_number_agreement(mention, candidate) {
273 score += 0.4;
274 } else {
275 return 0.0; }
277
278 let dist = mention_sentence.saturating_sub(candidate_sentence);
280 score += match dist {
281 0 => 0.30,
282 1 => 0.25,
283 2 => 0.15,
284 3 => 0.10,
285 _ => 0.05f64 / dist as f64,
286 };
287
288 score += match candidate.mention_type {
290 MentionType::Proper => 0.20,
291 MentionType::Nominal => 0.10,
292 MentionType::Pronominal => 0.0,
293 };
294
295 score.min(1.0)
296}
297
/// Returns `true` when `word`, compared case-insensitively, is one of the
/// third-person pronouns this module resolves.
fn is_pronoun(word: &str) -> bool {
    const THIRD_PERSON: &[&str] = &[
        "he", "him", "his", "himself",
        "she", "her", "hers", "herself",
        "it", "its", "itself",
        "they", "them", "their", "theirs", "themselves",
    ];
    let folded = word.to_lowercase();
    THIRD_PERSON.contains(&folded.as_str())
}
322
/// Splits `text` into tokens, returning `(start, end, token)` triples.
///
/// A token is a maximal run of alphanumeric characters and apostrophes (`'`).
/// `start` and `end` are byte offsets into `text`, with `end` exclusive.
fn tokenize_with_offsets(text: &str) -> Vec<(usize, usize, String)> {
    fn is_word_char(c: char) -> bool {
        c.is_alphanumeric() || c == '\''
    }

    let mut out = Vec::new();
    let mut chars = text.char_indices().peekable();

    while let Some((begin, first)) = chars.next() {
        if !is_word_char(first) {
            continue;
        }
        // Extend the token until the next separator, or to the end of the text.
        let mut finish = text.len();
        while let Some(&(pos, next)) = chars.peek() {
            if !is_word_char(next) {
                finish = pos;
                break;
            }
            chars.next();
        }
        out.push((begin, finish, text[begin..finish].to_string()));
    }

    out
}
345
/// Splits `text` into sentences, returning `(offset, sentence)` pairs.
///
/// A sentence ends at the first ASCII `.`, `?` or `!`; ASCII spaces right after
/// the terminator are consumed with it. Each sentence is returned trimmed, and
/// `offset` is the byte index in `text` of the first character of that trimmed
/// sentence, so `&text[offset..offset + sentence.len()] == sentence`.
/// Segments that are empty after trimming are dropped.
fn split_sentences_with_offsets(text: &str) -> Vec<(usize, String)> {
    let bytes = text.as_bytes();
    let len = bytes.len();
    let mut sentences: Vec<(usize, String)> = Vec::new();
    let mut start = 0usize;

    while start < len {
        // The terminators and the space are ASCII, so `start` and `end` always
        // land on UTF-8 character boundaries and the slice below cannot panic.
        let mut end = match bytes[start..]
            .iter()
            .position(|&b| matches!(b, b'.' | b'?' | b'!'))
        {
            Some(rel) => start + rel + 1,
            None => len,
        };
        while end < len && bytes[end] == b' ' {
            end += 1;
        }

        let segment = &text[start..end];
        let body = segment.trim_start();
        // Leading whitespace (e.g. a newline after the previous terminator) is
        // trimmed from the sentence, so the offset must skip it too; otherwise
        // every span derived from this offset is shifted.
        let leading = segment.len() - body.len();
        let sentence = body.trim_end();
        if !sentence.is_empty() {
            sentences.push((start + leading, sentence.to_string()));
        }
        start = end;
    }

    sentences
}
373
374fn detect_mentions(text: &str) -> Vec<(usize, Mention)> {
376 let sentences = split_sentences_with_offsets(text);
378 let mut result: Vec<(usize, Mention)> = Vec::new();
379
380 for (sent_idx, (sent_start, sent_text)) in sentences.iter().enumerate() {
381 let tokens = tokenize_with_offsets(sent_text);
382 let mut i = 0usize;
383 while i < tokens.len() {
384 let (tok_start, tok_end, word) = &tokens[i];
385 let abs_start = sent_start + tok_start;
386 let abs_end = sent_start + tok_end;
387
388 if is_pronoun(word) {
390 let gn = infer_gender_number(word);
391 result.push((
392 sent_idx,
393 Mention {
394 span: (abs_start, abs_end),
395 text: word.clone(),
396 mention_type: MentionType::Pronominal,
397 gender_number: gn,
398 },
399 ));
400 i += 1;
401 continue;
402 }
403
404 if word.starts_with(|c: char| c.is_uppercase()) && abs_start > *sent_start {
406 let mut j = i;
408 while j < tokens.len() && tokens[j].2.starts_with(|c: char| c.is_uppercase()) {
409 j += 1;
410 }
411 if j > i {
412 let name_start = sent_start + tokens[i].0;
414 let name_end = sent_start + tokens[j - 1].1;
415 let name_text = sent_text[tokens[i].0..tokens[j - 1].1].to_string();
416 let first_word = name_text.split_whitespace().next().unwrap_or("");
417 let gn = infer_gender_number(first_word);
418 result.push((
419 sent_idx,
420 Mention {
421 span: (name_start, name_end),
422 text: name_text,
423 mention_type: MentionType::Proper,
424 gender_number: gn,
425 },
426 ));
427 i = j;
428 continue;
429 }
430 }
431
432 let lower = word.to_lowercase();
434 if (lower == "the" || lower == "a" || lower == "an") && i + 1 < tokens.len() {
435 let head_start = sent_start + tokens[i + 1].0;
436 let head_end = sent_start + tokens[i + 1].1;
437 let det_text = sent_text[*tok_start..tokens[i + 1].1].to_string();
438 result.push((
439 sent_idx,
440 Mention {
441 span: (abs_start, head_end),
442 text: det_text,
443 mention_type: MentionType::Nominal,
444 gender_number: GenderNumber::Unknown,
445 },
446 ));
447 let _ = (head_start, head_end);
449 }
450
451 i += 1;
452 }
453 }
454
455 result
456}
457
458pub fn resolve_pronouns(text: &str) -> Vec<CoreferenceChain> {
468 let mentions_with_sent = detect_mentions(text);
469
470 let candidates: Vec<(usize, usize, &Mention)> = mentions_with_sent
472 .iter()
473 .enumerate()
474 .filter(|(_, (_, m))| m.mention_type != MentionType::Pronominal)
475 .map(|(idx, (sent_idx, m))| (idx, *sent_idx, m))
476 .collect();
477
478 let mut mention_to_chain: HashMap<usize, usize> = HashMap::new();
480 let mut chains: Vec<CoreferenceChain> = Vec::new();
481
482 for (idx, sent_idx, mention) in &candidates {
484 let existing = chains.iter().position(|c| {
486 c.mentions.iter().any(|m| {
487 m.text.to_lowercase() == mention.text.to_lowercase()
488 || mention.text.to_lowercase().contains(&m.text.to_lowercase())
489 || m.text.to_lowercase().contains(&mention.text.to_lowercase())
490 })
491 });
492
493 if let Some(chain_idx) = existing {
494 mention_to_chain.insert(*idx, chain_idx);
495 let m_clone = (*mention).clone();
496 let score = 0.7 + 0.1 * (mention.mention_type == MentionType::Proper) as u8 as f64;
497 chains[chain_idx].add(m_clone, score);
498 } else {
499 let chain_idx = chains.len();
500 mention_to_chain.insert(*idx, chain_idx);
501 let confidence = if mention.mention_type == MentionType::Proper {
502 0.8
503 } else {
504 0.6
505 };
506 chains.push(CoreferenceChain::new((*mention).clone(), confidence));
507 }
508 let _ = sent_idx;
509 }
510
511 for (pron_idx, (pron_sent, pron_mention)) in mentions_with_sent
513 .iter()
514 .enumerate()
515 .filter(|(_, (_, m))| m.mention_type == MentionType::Pronominal)
516 {
517 let mut best_score = 0.0f64;
519 let mut best_cand_idx: Option<usize> = None;
520
521 for &(cand_mention_idx, cand_sent, cand_mention) in &candidates {
522 if mentions_with_sent[cand_mention_idx].0 > *pron_sent {
524 continue;
525 }
526 if cand_mention.span.0 >= pron_mention.span.0 && cand_sent == *pron_sent {
528 continue;
529 }
530
531 let score = antecedent_score(pron_mention, cand_mention, *pron_sent, cand_sent);
532 if score > best_score {
533 best_score = score;
534 best_cand_idx = Some(cand_mention_idx);
535 }
536 }
537
538 if best_score > 0.3 {
539 if let Some(cand_idx) = best_cand_idx {
540 if let Some(&chain_idx) = mention_to_chain.get(&cand_idx) {
541 let pron_clone = pron_mention.clone();
542 chains[chain_idx].add(pron_clone, best_score);
543 mention_to_chain.insert(pron_idx, chain_idx);
544 }
545 }
546 }
547 }
548
549 chains.retain(|c| c.mentions.len() >= 2);
551 chains
552}
553
554pub fn replace_pronouns(text: &str, chains: &[CoreferenceChain]) -> String {
561 let mut replacements: HashMap<(usize, usize), (String, f64)> = HashMap::new();
564
565 for chain in chains {
566 for mention in &chain.mentions {
567 if mention.mention_type == MentionType::Pronominal {
568 let entry = replacements
569 .entry(mention.span)
570 .or_insert_with(|| (chain.canonical.clone(), 0.0));
571 if chain.confidence > entry.1 {
572 *entry = (chain.canonical.clone(), chain.confidence);
573 }
574 }
575 }
576 }
577
578 let mut spans: Vec<(usize, usize, String)> = replacements
580 .into_iter()
581 .map(|(span, (repl, _))| (span.0, span.1, repl))
582 .collect();
583 spans.sort_by_key(|(start, _, _)| std::cmp::Reverse(*start));
584
585 let mut result = text.to_string();
586 for (start, end, replacement) in spans {
587 if start <= end && end <= result.len() {
588 result.replace_range(start..end, &replacement);
589 }
590 }
591
592 result
593}
594
595pub fn resolve_coreferences(text: &str) -> Result<Vec<CoreferenceChain>> {
598 if text.is_empty() {
599 return Err(TextError::InvalidInput(
600 "Input text must not be empty".to_string(),
601 ));
602 }
603 Ok(resolve_pronouns(text))
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Mention`] fixture from its four fields.
    fn fixture(
        text: &str,
        span: (usize, usize),
        mention_type: MentionType,
        gender_number: GenderNumber,
    ) -> Mention {
        Mention {
            span,
            text: text.to_string(),
            mention_type,
            gender_number,
        }
    }

    #[test]
    fn test_infer_gender_number() {
        let cases = [
            ("he", GenderNumber::MasculineSingular),
            ("She", GenderNumber::FeminineSingular),
            ("it", GenderNumber::NeuterSingular),
            ("they", GenderNumber::Plural),
            ("random", GenderNumber::Unknown),
        ];
        for (word, expected) in &cases {
            assert_eq!(infer_gender_number(word), *expected);
        }
    }

    #[test]
    fn test_gender_number_agreement() {
        let he = fixture(
            "he",
            (0, 2),
            MentionType::Pronominal,
            GenderNumber::MasculineSingular,
        );
        let john = fixture(
            "John",
            (10, 14),
            MentionType::Proper,
            GenderNumber::MasculineSingular,
        );
        let alice = fixture(
            "Alice",
            (20, 25),
            MentionType::Proper,
            GenderNumber::FeminineSingular,
        );
        assert!(gender_number_agreement(&he, &john));
        assert!(!gender_number_agreement(&he, &alice));
    }

    #[test]
    fn test_antecedent_score_agreement_constraint() {
        let she = fixture(
            "she",
            (0, 3),
            MentionType::Pronominal,
            GenderNumber::FeminineSingular,
        );
        let he_candidate = fixture(
            "John",
            (10, 12),
            MentionType::Proper,
            GenderNumber::MasculineSingular,
        );
        // Gender disagreement is a hard constraint, so the score must be zero.
        assert_eq!(antecedent_score(&she, &he_candidate, 1, 0), 0.0);
    }

    #[test]
    fn test_resolve_pronouns_basic() {
        let text = "Alice is a scientist. She won a prize. Bob is an engineer. He built a bridge.";
        let chains = resolve_pronouns(text);
        assert!(!chains.is_empty());
        // Singleton chains are filtered out by `resolve_pronouns`.
        assert!(chains.iter().all(|chain| chain.mentions.len() >= 2));
    }

    #[test]
    fn test_replace_pronouns() {
        let text = "Alice is a doctor. She works at the hospital.";
        let chains = resolve_pronouns(text);
        let replaced = replace_pronouns(text, &chains);
        assert!(!replaced.is_empty());
    }

    #[test]
    fn test_resolve_coreferences_error_on_empty() {
        assert!(resolve_coreferences("").is_err());
    }

    #[test]
    fn test_resolve_coreferences_nonempty() {
        let text = "Marie Curie discovered radium. She was brilliant.";
        // Only success is checked here; the chain contents are not inspected.
        let _chains = resolve_coreferences(text).expect("should succeed");
    }

    #[test]
    fn test_detect_pronouns_in_isolation() {
        for word in ["she", "He", "THEY"].iter() {
            assert!(is_pronoun(word));
        }
        for word in ["Alice", "the"].iter() {
            assert!(!is_pronoun(word));
        }
    }

    #[test]
    fn test_multiple_chains() {
        let text = "Alice is a doctor. She treated patients. \
                    Bob is a lawyer. He argued cases.";
        let chains = resolve_pronouns(text);
        assert!(!chains.is_empty());
    }
}