Skip to main content

scirs2_text/information_extraction/
dep_relations.rs

//! Dependency-parse based relation extraction.
//!
//! Extracts Subject–Verb–Object and custom relational triples from a
//! dependency parse tree represented as a flat list of arcs, and provides a
//! small recency-based pronoun → antecedent resolver ([`CorefResolver`]).
5
6use std::collections::HashMap;
7
8use crate::error::Result;
9
10// ---------------------------------------------------------------------------
11// Core types
12// ---------------------------------------------------------------------------
13
/// A single arc in a dependency tree.
///
/// Tokens are identified by their surface text only; the arc carries no token
/// index, so two occurrences of the same word cannot be told apart from the
/// arc alone.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DependencyRelation {
    /// Head token text.
    pub head: String,
    /// Dependency relation label (e.g. "nsubj", "obj", "dobj").
    pub relation: String,
    /// Dependent token text.
    pub dependent: String,
}
24
/// A pattern that describes what Subject-Predicate-Object triples to extract.
///
/// `predicate` and `label` are the fields that drive
/// `DependencyRelationExtractor::extract`. The two POS filters are carried
/// along for callers but are not consulted by the extractor: dependency arcs
/// hold token text and relation labels only, with no POS tags to compare
/// against.
///
/// `RelationPattern::default()` is a pattern with no filters and an empty
/// label, handy as a base for struct-update syntax.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RelationPattern {
    /// Optional POS filter on the subject (e.g. "NN").
    ///
    /// Not currently applied during extraction.
    pub subject_pos: Option<String>,
    /// Verbs / predicates that trigger extraction (case-insensitive).
    ///
    /// An empty list places no restriction on the head token.
    pub predicate: Vec<String>,
    /// Optional POS filter on the object.
    ///
    /// Not currently applied during extraction.
    pub object_pos: Option<String>,
    /// Label assigned to extracted triples.
    pub label: String,
}
37
// ---------------------------------------------------------------------------
// DependencyRelationExtractor
// ---------------------------------------------------------------------------
41
42/// Dependency-based relation extractor.
43pub struct DependencyRelationExtractor {
44    patterns: Vec<RelationPattern>,
45}
46
47impl Default for DependencyRelationExtractor {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl DependencyRelationExtractor {
54    /// Create an empty extractor.
55    pub fn new() -> DependencyRelationExtractor {
56        DependencyRelationExtractor {
57            patterns: Vec::new(),
58        }
59    }
60
61    /// Add a relation pattern.
62    pub fn add_pattern(&mut self, pattern: RelationPattern) {
63        self.patterns.push(pattern);
64    }
65
66    /// Build a simple default extractor that finds SVO triples.
67    pub fn with_svo_defaults() -> DependencyRelationExtractor {
68        let mut ext = DependencyRelationExtractor::new();
69        ext.add_pattern(RelationPattern {
70            subject_pos: None,
71            predicate: vec![], // any verb
72            object_pos: None,
73            label: "SVO".to_string(),
74        });
75        ext
76    }
77
78    /// Extract (subject, relation_type, object) triples from a dependency tree.
79    ///
80    /// # Arguments
81    /// * `_text`          – original sentence text (unused; available for future context)
82    /// * `dependency_tree` – flat list of dependency arcs
83    ///
84    /// # Returns
85    /// A `Vec<(subject, relation_type, object)>`.
86    pub fn extract(
87        &self,
88        _text: &str,
89        dependency_tree: &[DependencyRelation],
90    ) -> Result<Vec<(String, String, String)>> {
91        // Build head → {relation → [dependents]} index
92        let mut head_map: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
93        for arc in dependency_tree {
94            head_map
95                .entry(arc.head.as_str())
96                .or_default()
97                .push((arc.relation.as_str(), arc.dependent.as_str()));
98        }
99
100        let mut triples = Vec::new();
101
102        // Collect unique head words to avoid duplicate processing
103        let mut seen_heads = std::collections::HashSet::new();
104        let head_words: Vec<String> = dependency_tree
105            .iter()
106            .filter_map(|arc| {
107                if seen_heads.insert(arc.head.clone()) {
108                    Some(arc.head.clone())
109                } else {
110                    None
111                }
112            })
113            .collect();
114
115        for head in &head_words {
116            let Some(deps) = head_map.get(head.as_str()) else {
117                continue;
118            };
119
120            // Find nsubj / obj pairs under the same head
121            let subjects: Vec<&str> = deps
122                .iter()
123                .filter(|(rel, _)| *rel == "nsubj" || *rel == "nsubjpass")
124                .map(|(_, dep)| *dep)
125                .collect();
126
127            let objects: Vec<&str> = deps
128                .iter()
129                .filter(|(rel, _)| {
130                    *rel == "obj"
131                        || *rel == "dobj"
132                        || *rel == "iobj"
133                        || *rel == "obl"
134                        || *rel == "xobj"
135                })
136                .map(|(_, dep)| *dep)
137                .collect();
138
139            if subjects.is_empty() || objects.is_empty() {
140                continue;
141            }
142
143            // Check against patterns
144            for subj in &subjects {
145                for obj in &objects {
146                    for pattern in &self.patterns {
147                        // Check predicate filter
148                        if !pattern.predicate.is_empty() {
149                            let head_lower = head.to_lowercase();
150                            if !pattern
151                                .predicate
152                                .iter()
153                                .any(|p| p.to_lowercase() == head_lower)
154                            {
155                                continue;
156                            }
157                        }
158                        triples.push((subj.to_string(), pattern.label.clone(), obj.to_string()));
159                    }
160                }
161            }
162        }
163
164        Ok(triples)
165    }
166}
167
168// ---------------------------------------------------------------------------
169// Simple coreference resolver
170// ---------------------------------------------------------------------------
171
172/// A minimal pronoun → antecedent resolver using recency heuristics.
173///
174/// It maintains a short-term memory of recently seen noun phrases and
175/// replaces pronouns with their most likely antecedent based on gender/number
176/// agreement and distance.
177pub struct CorefResolver {
178    /// Recently seen noun phrases: `(text, is_plural, gender)`.
179    history: Vec<(String, bool, PronounGender)>,
180    /// How many candidates to keep in memory.
181    window: usize,
182}
183
/// Coarse gender category for pronoun resolution.
///
/// `Plural` is a number category rather than a gender; it is what the
/// resolver assigns to every plural pronoun.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PronounGender {
    /// Masculine pronoun (he, him, his).
    Masculine,
    /// Feminine pronoun (she, her, hers).
    Feminine,
    /// Neuter pronoun (it, its, itself). Singular "they" is not recognised:
    /// "they" is always classified as [`PronounGender::Plural`].
    Neutral,
    /// Plural pronoun (they, them, their; also we, us, our).
    Plural,
    /// Gender could not be determined.
    Unknown,
}
198
199impl CorefResolver {
200    /// Create a new resolver with a recency window of `window` phrases.
201    pub fn new(window: usize) -> CorefResolver {
202        CorefResolver {
203            history: Vec::new(),
204            window,
205        }
206    }
207
208    /// Register a noun phrase as a potential antecedent.
209    pub fn register(
210        &mut self,
211        noun_phrase: impl Into<String>,
212        is_plural: bool,
213        gender: PronounGender,
214    ) {
215        if self.history.len() >= self.window {
216            self.history.remove(0);
217        }
218        self.history.push((noun_phrase.into(), is_plural, gender));
219    }
220
221    /// Resolve a pronoun to its most recent compatible antecedent.
222    ///
223    /// Returns `None` if no compatible antecedent is found in the window.
224    pub fn resolve(&self, pronoun: &str) -> Option<&str> {
225        let (target_plural, target_gender) = pronoun_attributes(pronoun)?;
226        // Search from most recent backwards
227        for (np, is_plural, gender) in self.history.iter().rev() {
228            if *is_plural != target_plural {
229                continue;
230            }
231            if target_gender != PronounGender::Unknown
232                && *gender != PronounGender::Unknown
233                && *gender != target_gender
234            {
235                continue;
236            }
237            return Some(np.as_str());
238        }
239        None
240    }
241}
242
243/// Return (is_plural, gender) for common English pronouns.
244fn pronoun_attributes(pronoun: &str) -> Option<(bool, PronounGender)> {
245    match pronoun.to_lowercase().as_str() {
246        "he" | "him" | "his" | "himself" => Some((false, PronounGender::Masculine)),
247        "she" | "her" | "hers" | "herself" => Some((false, PronounGender::Feminine)),
248        "it" | "its" | "itself" => Some((false, PronounGender::Neutral)),
249        "they" | "them" | "their" | "theirs" | "themselves" => Some((true, PronounGender::Plural)),
250        "we" | "us" | "our" | "ours" | "ourselves" => Some((true, PronounGender::Plural)),
251        _ => None,
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one dependency arc from string slices.
    fn arc(head: &str, relation: &str, dependent: &str) -> DependencyRelation {
        DependencyRelation {
            head: head.to_string(),
            relation: relation.to_string(),
            dependent: dependent.to_string(),
        }
    }

    /// Build an owned triple for comparison against extractor output.
    fn triple(subject: &str, label: &str, object: &str) -> (String, String, String) {
        (subject.to_string(), label.to_string(), object.to_string())
    }

    /// Arcs for "John loves Mary".
    fn make_svo_tree() -> Vec<DependencyRelation> {
        vec![arc("loves", "nsubj", "John"), arc("loves", "obj", "Mary")]
    }

    #[test]
    fn test_svo_extraction() {
        let extractor = DependencyRelationExtractor::with_svo_defaults();
        let triples = extractor
            .extract("John loves Mary", &make_svo_tree())
            .expect("extract failed");
        // Check the whole triple, including the label, not just its ends.
        assert_eq!(triples, vec![triple("John", "SVO", "Mary")]);
    }

    #[test]
    fn test_no_triples_without_subject() {
        let extractor = DependencyRelationExtractor::with_svo_defaults();
        let tree = vec![arc("runs", "obj", "race")];
        let triples = extractor
            .extract("runs race", &tree)
            .expect("extract failed");
        assert!(triples.is_empty());
    }

    #[test]
    fn test_predicate_filter() {
        let mut extractor = DependencyRelationExtractor::new();
        extractor.add_pattern(RelationPattern {
            subject_pos: None,
            predicate: vec!["loves".to_string()],
            object_pos: None,
            label: "LOVE".to_string(),
        });

        // Matching predicate
        let triples = extractor
            .extract("John loves Mary", &make_svo_tree())
            .expect("extract failed");
        assert_eq!(triples, vec![triple("John", "LOVE", "Mary")]);

        // Non-matching predicate — extractor has only "loves" pattern
        let tree2 = vec![arc("hates", "nsubj", "John"), arc("hates", "obj", "Mary")];
        let triples2 = extractor
            .extract("John hates Mary", &tree2)
            .expect("extract failed");
        assert!(triples2.is_empty());
    }

    #[test]
    fn test_predicate_filter_is_case_insensitive() {
        let mut extractor = DependencyRelationExtractor::new();
        extractor.add_pattern(RelationPattern {
            subject_pos: None,
            predicate: vec!["Loves".to_string()],
            object_pos: None,
            label: "LOVE".to_string(),
        });
        let tree = vec![arc("LOVES", "nsubj", "John"), arc("LOVES", "dobj", "Mary")];
        let triples = extractor
            .extract("John LOVES Mary", &tree)
            .expect("extract failed");
        assert_eq!(triples, vec![triple("John", "LOVE", "Mary")]);
    }

    #[test]
    fn test_coref_resolver_basic() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("John Smith", false, PronounGender::Masculine);
        let antecedent = resolver.resolve("he");
        assert_eq!(antecedent, Some("John Smith"));
    }

    #[test]
    fn test_coref_resolver_gender_mismatch() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("Alice", false, PronounGender::Feminine);
        // "he" should NOT resolve to Alice
        let antecedent = resolver.resolve("he");
        assert!(antecedent.is_none());
    }

    #[test]
    fn test_coref_resolver_recency() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("Bob", false, PronounGender::Masculine);
        resolver.register("Alice", false, PronounGender::Feminine);
        // "he" skips the more recent feminine phrase and resolves to Bob
        let antecedent = resolver.resolve("he");
        assert_eq!(antecedent, Some("Bob"));
    }

    #[test]
    fn test_coref_resolver_number_agreement() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("the engineers", true, PronounGender::Plural);
        resolver.register("Alice", false, PronounGender::Feminine);
        // "they" skips the more recent singular phrase
        assert_eq!(resolver.resolve("they"), Some("the engineers"));
        assert_eq!(resolver.resolve("she"), Some("Alice"));
    }

    #[test]
    fn test_coref_resolver_unknown_pronoun() {
        let mut resolver = CorefResolver::new(5);
        resolver.register("Bob", false, PronounGender::Masculine);
        assert_eq!(resolver.resolve("banana"), None);
    }

    #[test]
    fn test_coref_resolver_window_eviction() {
        let mut resolver = CorefResolver::new(2);
        resolver.register("Old Guy", false, PronounGender::Masculine);
        resolver.register("Middle Person", false, PronounGender::Unknown);
        resolver.register("New Person", false, PronounGender::Unknown);
        // "Old Guy" should have been evicted (window=2), the rest kept in order
        let names: Vec<&str> = resolver
            .history
            .iter()
            .map(|(n, _, _)| n.as_str())
            .collect();
        assert_eq!(names, vec!["Middle Person", "New Person"]);
    }
}