Skip to main content

scirs2_text/
coreference.rs

1//! Coreference resolution: pronoun / nominal / proper-name clustering.
2//!
3//! This module provides a self-contained, rule-based coreference resolution
4//! pipeline that does not require external ML weights.  It implements a
5//! simplified version of the **Hobbs algorithm** for pronoun antecedent
6//! selection, extended with nominal and proper-name clustering.
7//!
8//! # Example
9//!
10//! ```rust
11//! use scirs2_text::coreference::{resolve_pronouns, replace_pronouns};
12//!
13//! let text = "Alice is a doctor. She works at the hospital. \
14//!             Bob is an engineer. He builds bridges.";
15//! let chains = resolve_pronouns(text);
16//! assert!(!chains.is_empty());
17//!
18//! let resolved = replace_pronouns(text, &chains);
19//! // "She" should be replaced with "Alice" (or similar) in the output.
20//! assert!(resolved.contains("Alice") || resolved.contains("She"));
21//! ```
22
23use crate::error::{Result, TextError};
24use std::cmp::Reverse;
25use std::collections::HashMap;
26
27// ---------------------------------------------------------------------------
28// Public types
29// ---------------------------------------------------------------------------
30
/// Coarse-grained morphosyntactic category of a mention.
///
/// The category also feeds antecedent scoring: proper-name candidates are
/// weighted above nominal ones, and pronouns are never used as antecedent
/// candidates.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MentionType {
    /// A proper name made of one or more consecutive capitalised tokens
    /// (e.g. "Curie", "Marie Curie", "Google Inc").
    Proper,
    /// A determiner-headed nominal: "the", "a" or "an" plus the token that
    /// follows it (e.g. "the researcher", "a scientist").
    Nominal,
    /// A third-person personal, possessive, or reflexive pronoun.
    Pronominal,
}
42
/// Gender / number feature used for agreement checking.
///
/// Values are compared for equality by [`gender_number_agreement`], with
/// [`GenderNumber::Unknown`] acting as a wildcard on either side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GenderNumber {
    /// Masculine singular (e.g. "he", or a name from the masculine list).
    MasculineSingular,
    /// Feminine singular (e.g. "she", or a name from the feminine list).
    FeminineSingular,
    /// Neuter singular (e.g. "it").
    NeuterSingular,
    /// Plural, with no gender distinction (e.g. "they").
    Plural,
    /// Unknown or unresolved; treated as compatible with every other value.
    Unknown,
}
57
/// A single textual reference to an entity.
#[derive(Debug, Clone)]
pub struct Mention {
    /// Byte offsets `(start, end)` into the source text; `end` is exclusive.
    ///
    /// These are byte positions, not character counts: they come from
    /// `char_indices` during tokenisation and are used directly as a byte
    /// range when pronouns are replaced.
    pub span: (usize, usize),
    /// The surface string covered by `span`.
    pub text: String,
    /// Morpho-syntactic type.
    pub mention_type: MentionType,
    /// Gender/number for agreement.
    pub gender_number: GenderNumber,
}

impl Mention {
    /// Byte start position (first element of `span`).
    pub fn start(&self) -> usize {
        self.span.0
    }

    /// Byte end position, exclusive (second element of `span`).
    pub fn end(&self) -> usize {
        self.span.1
    }
}
82
83/// A list of co-referring [`Mention`]s.
84#[derive(Debug, Clone)]
85pub struct CoreferenceChain {
86    /// The canonical (most informative) mention – usually the first proper or
87    /// nominal mention in document order.
88    pub canonical: String,
89    /// All mentions in document order.
90    pub mentions: Vec<Mention>,
91    /// Aggregate confidence score.
92    pub confidence: f64,
93}
94
95impl CoreferenceChain {
96    /// Create a new chain seeded with one mention.
97    fn new(seed: Mention, confidence: f64) -> Self {
98        let canonical = seed.text.clone();
99        Self {
100            canonical,
101            mentions: vec![seed],
102            confidence,
103        }
104    }
105
106    /// Add a mention and update the canonical form.
107    fn add(&mut self, mention: Mention, score: f64) {
108        // Prefer proper > nominal > pronominal as canonical.
109        if mention.mention_type == MentionType::Proper
110            || (mention.mention_type == MentionType::Nominal
111                && self.canonical_type() == MentionType::Pronominal)
112        {
113            self.canonical = mention.text.clone();
114        }
115        self.confidence = self.confidence.max(score);
116        self.mentions.push(mention);
117    }
118
119    /// The mention type of the current canonical form.
120    fn canonical_type(&self) -> MentionType {
121        // Determine type of the canonical mention by scanning.
122        for m in &self.mentions {
123            if m.text == self.canonical {
124                return m.mention_type.clone();
125            }
126        }
127        MentionType::Pronominal
128    }
129}
130
131// ---------------------------------------------------------------------------
132// Feature-scoring helpers
133// ---------------------------------------------------------------------------
134
135/// Classify the gender/number of a surface token.
136pub fn infer_gender_number(text: &str) -> GenderNumber {
137    let lower = text.to_lowercase();
138    match lower.as_str() {
139        "he" | "him" | "his" | "himself" => GenderNumber::MasculineSingular,
140        "she" | "her" | "hers" | "herself" => GenderNumber::FeminineSingular,
141        "it" | "its" | "itself" => GenderNumber::NeuterSingular,
142        "they" | "them" | "their" | "theirs" | "themselves" => GenderNumber::Plural,
143        _ => {
144            // Heuristic for proper names
145            if is_likely_masculine_name(&lower) {
146                GenderNumber::MasculineSingular
147            } else if is_likely_feminine_name(&lower) {
148                GenderNumber::FeminineSingular
149            } else {
150                GenderNumber::Unknown
151            }
152        }
153    }
154}
155
/// Report whether `name` is one of the built-in masculine first names.
///
/// The comparison is exact, so callers must pass an already lower-cased
/// token.
fn is_likely_masculine_name(name: &str) -> bool {
    const KNOWN_MASCULINE: &[&str] = &[
        "john", "james", "michael", "william", "david", "richard", "joseph",
        "thomas", "charles", "christopher", "daniel", "matthew", "anthony",
        "mark", "donald", "steven", "paul", "andrew", "kenneth", "george",
        "joshua", "kevin", "brian", "tim", "bob", "bill", "frank", "larry",
        "scott", "jeffrey", "eric", "robert", "peter", "henry", "edward",
    ];
    KNOWN_MASCULINE.iter().any(|&known| known == name)
}
196
/// Report whether `name` is one of the built-in feminine first names.
///
/// The comparison is exact, so callers must pass an already lower-cased
/// token.
fn is_likely_feminine_name(name: &str) -> bool {
    const KNOWN_FEMININE: &[&str] = &[
        "mary", "patricia", "linda", "barbara", "elizabeth", "jennifer",
        "maria", "susan", "margaret", "dorothy", "lisa", "nancy", "karen",
        "betty", "helen", "sandra", "donna", "carol", "ruth", "sharon",
        "michelle", "laura", "sarah", "kimberly", "deborah", "jessica",
        "shirley", "cynthia", "angela", "melissa", "brenda", "amy", "anna",
        "rebecca", "virginia", "kathleen", "pamela", "martha", "debra",
        "amanda", "stephanie", "carolyn", "christine", "alice",
    ];
    KNOWN_FEMININE.iter().any(|&known| known == name)
}
246
247/// Check morpho-syntactic agreement between a pronoun mention and a candidate
248/// antecedent mention.
249pub fn gender_number_agreement(mention: &Mention, candidate: &Mention) -> bool {
250    match (&mention.gender_number, &candidate.gender_number) {
251        (GenderNumber::Unknown, _) | (_, GenderNumber::Unknown) => true,
252        (a, b) => a == b,
253    }
254}
255
256/// Score a pronoun `mention` against an `antecedent` candidate.
257///
258/// Features:
259/// - Gender/number agreement: +0.4
260/// - Sentence recency: +0.3 for same sentence, +0.2 for 1 back, decaying
261/// - Mention type: +0.2 for Proper, +0.1 for Nominal antecedent
262/// - Salience: +0.1 if antecedent is a subject-position mention (starts sentence)
263pub fn antecedent_score(
264    mention: &Mention,
265    candidate: &Mention,
266    mention_sentence: usize,
267    candidate_sentence: usize,
268) -> f64 {
269    let mut score = 0.0f64;
270
271    // Agreement
272    if gender_number_agreement(mention, candidate) {
273        score += 0.4;
274    } else {
275        return 0.0; // Hard constraint: incompatible agreement → skip
276    }
277
278    // Recency
279    let dist = mention_sentence.saturating_sub(candidate_sentence);
280    score += match dist {
281        0 => 0.30,
282        1 => 0.25,
283        2 => 0.15,
284        3 => 0.10,
285        _ => 0.05f64 / dist as f64,
286    };
287
288    // Mention type of candidate
289    score += match candidate.mention_type {
290        MentionType::Proper => 0.20,
291        MentionType::Nominal => 0.10,
292        MentionType::Pronominal => 0.0,
293    };
294
295    score.min(1.0)
296}
297
298// ---------------------------------------------------------------------------
299// Pronoun list
300// ---------------------------------------------------------------------------
301
/// Report whether `word` is one of the third-person pronouns this module
/// resolves (personal, possessive, or reflexive), ignoring case.
fn is_pronoun(word: &str) -> bool {
    const THIRD_PERSON: [&str; 16] = [
        "he", "him", "his", "himself",
        "she", "her", "hers", "herself",
        "it", "its", "itself",
        "they", "them", "their", "theirs", "themselves",
    ];
    let lowered = word.to_lowercase();
    THIRD_PERSON.contains(&lowered.as_str())
}
322
323// ---------------------------------------------------------------------------
324// Mention detection
325// ---------------------------------------------------------------------------
326
/// Split `text` into word tokens, returning `(start, end, word)` triples
/// with end-exclusive byte offsets.
///
/// A token is a maximal run of alphanumeric characters and ASCII
/// apostrophes; every other character acts as a separator and is dropped.
fn tokenize_with_offsets(text: &str) -> Vec<(usize, usize, String)> {
    let is_word_char = |ch: char| ch.is_alphanumeric() || ch == '\'';

    let mut out = Vec::new();
    // Byte offset where the token currently being scanned began, if any.
    let mut open: Option<usize> = None;

    for (pos, ch) in text.char_indices() {
        match (is_word_char(ch), open) {
            (true, None) => open = Some(pos),
            (false, Some(begin)) => {
                out.push((begin, pos, text[begin..pos].to_string()));
                open = None;
            }
            _ => {}
        }
    }

    // A token that runs to the end of the input has no closing separator.
    if let Some(begin) = open {
        out.push((begin, text.len(), text[begin..].to_string()));
    }
    out
}
345
/// Split `text` into trimmed sentence strings paired with the byte offset
/// at which each trimmed string starts in `text`.
///
/// A sentence ends at a run of `.`, `?` or `!`; the whole run ("...", "?!")
/// stays with the sentence it closes. Whitespace of any kind between
/// sentences is dropped, and whitespace-only chunks produce no entry.
///
/// For every returned `(offset, sentence)` pair,
/// `&text[offset..offset + sentence.len()] == sentence`, so offsets computed
/// inside a sentence can be mapped back to `text` by adding `offset`.
///
/// This is a punctuation heuristic only: abbreviations ("Dr.") and decimals
/// ("3.14") are split too.
fn split_sentences_with_offsets(text: &str) -> Vec<(usize, String)> {
    let is_terminator = |b: u8| matches!(b, b'.' | b'?' | b'!');

    let bytes = text.as_bytes();
    let mut sentences: Vec<(usize, String)> = Vec::new();
    let mut start = 0usize;

    while start < bytes.len() {
        // `end` lands just past the terminator run, or at the end of input.
        // It always follows an ASCII byte, so it is a valid char boundary.
        let end = match bytes[start..].iter().position(|&b| is_terminator(b)) {
            Some(rel) => {
                let mut after = start + rel + 1;
                while after < bytes.len() && is_terminator(bytes[after]) {
                    after += 1;
                }
                after
            }
            None => bytes.len(),
        };

        let chunk = &text[start..end];
        let trimmed = chunk.trim();
        if !trimmed.is_empty() {
            // Account for the whitespace `trim` removed from the front so
            // the offset points at the first byte of the trimmed sentence.
            let leading = chunk.len() - chunk.trim_start().len();
            sentences.push((start + leading, trimmed.to_string()));
        }
        start = end;
    }
    sentences
}
373
374/// Detect pronoun, nominal, and proper-name mentions in `text`.
375fn detect_mentions(text: &str) -> Vec<(usize, Mention)> {
376    // (sentence_index, Mention)
377    let sentences = split_sentences_with_offsets(text);
378    let mut result: Vec<(usize, Mention)> = Vec::new();
379
380    for (sent_idx, (sent_start, sent_text)) in sentences.iter().enumerate() {
381        let tokens = tokenize_with_offsets(sent_text);
382        let mut i = 0usize;
383        while i < tokens.len() {
384            let (tok_start, tok_end, word) = &tokens[i];
385            let abs_start = sent_start + tok_start;
386            let abs_end = sent_start + tok_end;
387
388            // ---- Pronouns ----
389            if is_pronoun(word) {
390                let gn = infer_gender_number(word);
391                result.push((
392                    sent_idx,
393                    Mention {
394                        span: (abs_start, abs_end),
395                        text: word.clone(),
396                        mention_type: MentionType::Pronominal,
397                        gender_number: gn,
398                    },
399                ));
400                i += 1;
401                continue;
402            }
403
404            // ---- Proper names: consecutive capitalised tokens ----
405            if word.starts_with(|c: char| c.is_uppercase()) && abs_start > *sent_start {
406                // Skip if it is the first word of the sentence (always capitalised).
407                let mut j = i;
408                while j < tokens.len() && tokens[j].2.starts_with(|c: char| c.is_uppercase()) {
409                    j += 1;
410                }
411                if j > i {
412                    // j - i tokens form the proper name span
413                    let name_start = sent_start + tokens[i].0;
414                    let name_end = sent_start + tokens[j - 1].1;
415                    let name_text = sent_text[tokens[i].0..tokens[j - 1].1].to_string();
416                    let first_word = name_text.split_whitespace().next().unwrap_or("");
417                    let gn = infer_gender_number(first_word);
418                    result.push((
419                        sent_idx,
420                        Mention {
421                            span: (name_start, name_end),
422                            text: name_text,
423                            mention_type: MentionType::Proper,
424                            gender_number: gn,
425                        },
426                    ));
427                    i = j;
428                    continue;
429                }
430            }
431
432            // ---- Nominal mentions: "the X", "a X", "an X" ----
433            let lower = word.to_lowercase();
434            if (lower == "the" || lower == "a" || lower == "an") && i + 1 < tokens.len() {
435                let head_start = sent_start + tokens[i + 1].0;
436                let head_end = sent_start + tokens[i + 1].1;
437                let det_text = sent_text[*tok_start..tokens[i + 1].1].to_string();
438                result.push((
439                    sent_idx,
440                    Mention {
441                        span: (abs_start, head_end),
442                        text: det_text,
443                        mention_type: MentionType::Nominal,
444                        gender_number: GenderNumber::Unknown,
445                    },
446                ));
447                // Do NOT skip ahead – the head noun may also be a proper name.
448                let _ = (head_start, head_end);
449            }
450
451            i += 1;
452        }
453    }
454
455    result
456}
457
458// ---------------------------------------------------------------------------
459// Simplified Hobbs algorithm
460// ---------------------------------------------------------------------------
461
462/// Resolve pronouns in `text` and return coreference chains.
463///
464/// The implementation follows a simplified Hobbs-style search: for each
465/// pronoun, scan backwards through the preceding mentions in the document
466/// and pick the candidate that maximises [`antecedent_score`].
467pub fn resolve_pronouns(text: &str) -> Vec<CoreferenceChain> {
468    let mentions_with_sent = detect_mentions(text);
469
470    // Collect non-pronominal antecedent candidates in order.
471    let candidates: Vec<(usize, usize, &Mention)> = mentions_with_sent
472        .iter()
473        .enumerate()
474        .filter(|(_, (_, m))| m.mention_type != MentionType::Pronominal)
475        .map(|(idx, (sent_idx, m))| (idx, *sent_idx, m))
476        .collect();
477
478    // Map from mention index → chain index.
479    let mut mention_to_chain: HashMap<usize, usize> = HashMap::new();
480    let mut chains: Vec<CoreferenceChain> = Vec::new();
481
482    // First, register all proper/nominal mentions as potential chain seeds.
483    for (idx, sent_idx, mention) in &candidates {
484        // Check if an existing chain already contains this text.
485        let existing = chains.iter().position(|c| {
486            c.mentions.iter().any(|m| {
487                m.text.to_lowercase() == mention.text.to_lowercase()
488                    || mention.text.to_lowercase().contains(&m.text.to_lowercase())
489                    || m.text.to_lowercase().contains(&mention.text.to_lowercase())
490            })
491        });
492
493        if let Some(chain_idx) = existing {
494            mention_to_chain.insert(*idx, chain_idx);
495            let m_clone = (*mention).clone();
496            let score = 0.7 + 0.1 * (mention.mention_type == MentionType::Proper) as u8 as f64;
497            chains[chain_idx].add(m_clone, score);
498        } else {
499            let chain_idx = chains.len();
500            mention_to_chain.insert(*idx, chain_idx);
501            let confidence = if mention.mention_type == MentionType::Proper {
502                0.8
503            } else {
504                0.6
505            };
506            chains.push(CoreferenceChain::new((*mention).clone(), confidence));
507        }
508        let _ = sent_idx;
509    }
510
511    // Now resolve pronouns.
512    for (pron_idx, (pron_sent, pron_mention)) in mentions_with_sent
513        .iter()
514        .enumerate()
515        .filter(|(_, (_, m))| m.mention_type == MentionType::Pronominal)
516    {
517        // Scan candidates that appear BEFORE this pronoun.
518        let mut best_score = 0.0f64;
519        let mut best_cand_idx: Option<usize> = None;
520
521        for &(cand_mention_idx, cand_sent, cand_mention) in &candidates {
522            // The candidate must precede the pronoun in the document.
523            if mentions_with_sent[cand_mention_idx].0 > *pron_sent {
524                continue;
525            }
526            // Within the same sentence, the candidate must come first.
527            if cand_mention.span.0 >= pron_mention.span.0 && cand_sent == *pron_sent {
528                continue;
529            }
530
531            let score = antecedent_score(pron_mention, cand_mention, *pron_sent, cand_sent);
532            if score > best_score {
533                best_score = score;
534                best_cand_idx = Some(cand_mention_idx);
535            }
536        }
537
538        if best_score > 0.3 {
539            if let Some(cand_idx) = best_cand_idx {
540                if let Some(&chain_idx) = mention_to_chain.get(&cand_idx) {
541                    let pron_clone = pron_mention.clone();
542                    chains[chain_idx].add(pron_clone, best_score);
543                    mention_to_chain.insert(pron_idx, chain_idx);
544                }
545            }
546        }
547    }
548
549    // Prune chains that only contain one mention (no actual coreference).
550    chains.retain(|c| c.mentions.len() >= 2);
551    chains
552}
553
554/// Substitute all pronominal mentions in `text` with their canonical
555/// antecedent from the supplied chains.
556///
557/// Pronouns that appear in multiple overlapping chains are resolved using the
558/// highest-confidence chain.  The replacement is done in reverse document
559/// order to preserve byte offsets.
560pub fn replace_pronouns(text: &str, chains: &[CoreferenceChain]) -> String {
561    // Build a map from span → replacement string, keeping the highest-
562    // confidence replacement if multiple chains cover the same pronoun.
563    let mut replacements: HashMap<(usize, usize), (String, f64)> = HashMap::new();
564
565    for chain in chains {
566        for mention in &chain.mentions {
567            if mention.mention_type == MentionType::Pronominal {
568                let entry = replacements
569                    .entry(mention.span)
570                    .or_insert_with(|| (chain.canonical.clone(), 0.0));
571                if chain.confidence > entry.1 {
572                    *entry = (chain.canonical.clone(), chain.confidence);
573                }
574            }
575        }
576    }
577
578    // Sort spans in reverse order so replacements do not shift later offsets.
579    let mut spans: Vec<(usize, usize, String)> = replacements
580        .into_iter()
581        .map(|(span, (repl, _))| (span.0, span.1, repl))
582        .collect();
583    spans.sort_by_key(|(start, _, _)| std::cmp::Reverse(*start));
584
585    let mut result = text.to_string();
586    for (start, end, replacement) in spans {
587        if start <= end && end <= result.len() {
588            result.replace_range(start..end, &replacement);
589        }
590    }
591
592    result
593}
594
595/// Resolve coreferences and return chains – a convenience wrapper with error
596/// propagation for pipeline use.
597pub fn resolve_coreferences(text: &str) -> Result<Vec<CoreferenceChain>> {
598    if text.is_empty() {
599        return Err(TextError::InvalidInput(
600            "Input text must not be empty".to_string(),
601        ));
602    }
603    Ok(resolve_pronouns(text))
604}
605
606// ---------------------------------------------------------------------------
607// Tests
608// ---------------------------------------------------------------------------
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Pronouns map to their fixed feature regardless of case; unknown words
    /// map to `Unknown`.
    #[test]
    fn test_infer_gender_number() {
        assert_eq!(infer_gender_number("he"), GenderNumber::MasculineSingular);
        assert_eq!(infer_gender_number("She"), GenderNumber::FeminineSingular);
        assert_eq!(infer_gender_number("it"), GenderNumber::NeuterSingular);
        assert_eq!(infer_gender_number("they"), GenderNumber::Plural);
        assert_eq!(infer_gender_number("random"), GenderNumber::Unknown);
    }

    /// Equal features agree; different known features do not.
    #[test]
    fn test_gender_number_agreement() {
        let he = Mention {
            span: (0, 2),
            text: "he".to_string(),
            mention_type: MentionType::Pronominal,
            gender_number: GenderNumber::MasculineSingular,
        };
        let john = Mention {
            span: (10, 14),
            text: "John".to_string(),
            mention_type: MentionType::Proper,
            gender_number: GenderNumber::MasculineSingular,
        };
        let alice = Mention {
            span: (20, 25),
            text: "Alice".to_string(),
            mention_type: MentionType::Proper,
            gender_number: GenderNumber::FeminineSingular,
        };
        assert!(gender_number_agreement(&he, &john));
        assert!(!gender_number_agreement(&he, &alice));
    }

    /// Disagreement is a hard constraint and zeroes the score.
    #[test]
    fn test_antecedent_score_agreement_constraint() {
        let she = Mention {
            span: (0, 3),
            text: "she".to_string(),
            mention_type: MentionType::Pronominal,
            gender_number: GenderNumber::FeminineSingular,
        };
        let he_candidate = Mention {
            span: (10, 12),
            text: "John".to_string(),
            mention_type: MentionType::Proper,
            gender_number: GenderNumber::MasculineSingular,
        };
        // Disagreement → 0.0
        assert_eq!(antecedent_score(&she, &he_candidate, 1, 0), 0.0);
    }

    /// Every returned chain links at least two mentions.
    #[test]
    fn test_resolve_pronouns_basic() {
        let text = "Alice is a scientist. She won a prize. Bob is an engineer. He built a bridge.";
        let chains = resolve_pronouns(text);
        // Should find at least one chain linking a pronoun back to an antecedent.
        assert!(!chains.is_empty());
        for chain in &chains {
            assert!(chain.mentions.len() >= 2);
        }
    }

    /// The resolved pronoun is actually substituted and the text before it
    /// is left untouched.
    #[test]
    fn test_replace_pronouns() {
        let text = "Alice is a doctor. She works at the hospital.";
        let chains = resolve_pronouns(text);
        let replaced = replace_pronouns(text, &chains);
        assert!(!replaced.is_empty());
        assert!(replaced.starts_with("Alice is a doctor. "));
        assert!(!replaced.contains("She"));
    }

    #[test]
    fn test_resolve_coreferences_error_on_empty() {
        let result = resolve_coreferences("");
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_coreferences_nonempty() {
        let text = "Marie Curie discovered radium. She was brilliant.";
        let chains = resolve_coreferences(text).expect("should succeed");
        // There may or may not be a chain depending on heuristics, but
        // the function must not panic.
        let _ = chains;
    }

    #[test]
    fn test_detect_pronouns_in_isolation() {
        assert!(is_pronoun("she"));
        assert!(is_pronoun("He"));
        assert!(is_pronoun("THEY"));
        assert!(!is_pronoun("Alice"));
        assert!(!is_pronoun("the"));
    }

    /// "She" and "He" must end up in two separate chains.
    #[test]
    fn test_multiple_chains() {
        let text = "Alice is a doctor. She treated patients. \
                    Bob is a lawyer. He argued cases.";
        let chains = resolve_pronouns(text);
        // One chain per pronoun: the two pronouns have different antecedents.
        assert!(chains.len() >= 2);
    }
}