Skip to main content

scirs2_text/information_extraction/
seq_patterns.rs

1//! Pattern-based information extraction using token sequences.
2//!
3//! Provides a flexible rule engine operating on tokenized input.  Each pattern
4//! is a sequence of [`PatternElement`]s that are matched left-to-right against
5//! a slice of [`Token`]s.  Gap elements allow for skip-regions of bounded
6//! length.
7
8use regex::Regex;
9
10use crate::error::{Result, TextError};
11
12// ---------------------------------------------------------------------------
13// Token type
14// ---------------------------------------------------------------------------
15
/// A single token with optional part-of-speech and lemma annotations.
///
/// Only [`Token::text`] is required; `pos` and `lemma` stay `None` until set
/// through the builder methods or by direct field assignment.  Tokens compare
/// and hash by value over all three fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Token {
    /// Surface form of the token.
    pub text: String,
    /// Optional part-of-speech tag (e.g. "NN", "VBZ").
    pub pos: Option<String>,
    /// Optional base/lemma form.
    pub lemma: Option<String>,
}

impl Token {
    /// Create a token from its surface form, with no POS tag and no lemma.
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            pos: None,
            lemma: None,
        }
    }

    /// Builder: return the token with its POS tag set to `pos`
    /// (replacing any previous tag).
    #[must_use = "builder methods return the updated token instead of mutating in place"]
    pub fn with_pos(mut self, pos: impl Into<String>) -> Self {
        self.pos = Some(pos.into());
        self
    }

    /// Builder: return the token with its lemma set to `lemma`
    /// (replacing any previous lemma).
    #[must_use = "builder methods return the updated token instead of mutating in place"]
    pub fn with_lemma(mut self, lemma: impl Into<String>) -> Self {
        self.lemma = Some(lemma.into());
        self
    }
}
49
50// ---------------------------------------------------------------------------
51// Pattern element
52// ---------------------------------------------------------------------------
53
/// A single element in a named extraction pattern.
///
/// Every variant except [`PatternElement::Gap`] consumes exactly one token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PatternElement {
    /// Match a token whose text equals the given string, compared
    /// case-insensitively (both sides are lowercased before comparison).
    Literal(String),
    /// Match a token whose POS tag equals the given string exactly
    /// (the comparison is case-sensitive).
    PoS(String),
    /// Match a token whose text matches the given regex string.
    ///
    /// The regex is searched for anywhere inside the token text; add `^` and
    /// `$` to the expression if the whole token must match.
    Regex(String),
    /// Match any single token.
    Any,
    /// Match between `min` and `max` tokens (inclusive), skipping them.
    ///
    /// A gap with `min > max` describes an empty range and therefore never
    /// matches.
    Gap {
        /// Minimum number of tokens to skip.
        min: usize,
        /// Maximum number of tokens to skip.
        max: usize,
    },
}
73
74// ---------------------------------------------------------------------------
75// Pattern and result types
76// ---------------------------------------------------------------------------
77
78/// A sequence of [`PatternElement`]s representing one extraction rule.
79#[derive(Debug, Clone)]
80pub struct Pattern {
81    /// Ordered list of elements that must match in sequence.
82    pub template: Vec<PatternElement>,
83}
84
85impl Pattern {
86    /// Construct a pattern from a template.
87    pub fn new(template: Vec<PatternElement>) -> Pattern {
88        Pattern { template }
89    }
90}
91
/// A single extraction match.
///
/// `start..end` is the half-open range of token indices covered by the match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
    /// Name of the pattern that produced this match.
    pub pattern_name: String,
    /// Index of the first token in the match.
    pub start: usize,
    /// Index one past the last token in the match.
    pub end: usize,
    /// Captured text fragments: the surface text of the token consumed by
    /// each non-Gap element (`Literal`, `PoS`, `Regex` and `Any`), in
    /// template order.  Tokens skipped by a `Gap` are not captured.
    pub groups: Vec<String>,
}
104
105// ---------------------------------------------------------------------------
106// PatternMatcher
107// ---------------------------------------------------------------------------
108
109/// A collection of named extraction patterns applied to token sequences.
110#[derive(Default)]
111pub struct PatternMatcher {
112    patterns: Vec<(String, Pattern)>,
113    /// Compiled regexes cached by pattern string.
114    regex_cache: std::collections::HashMap<String, Regex>,
115}
116
117impl PatternMatcher {
118    /// Create an empty matcher.
119    pub fn new() -> PatternMatcher {
120        PatternMatcher::default()
121    }
122
123    /// Add a named pattern.
124    pub fn add_pattern(&mut self, name: impl Into<String>, pattern: Pattern) -> Result<()> {
125        // Pre-compile any Regex elements
126        for elem in &pattern.template {
127            if let PatternElement::Regex(re_str) = elem {
128                if !self.regex_cache.contains_key(re_str) {
129                    let compiled = Regex::new(re_str).map_err(|e| {
130                        TextError::InvalidInput(format!("Bad regex '{}': {}", re_str, e))
131                    })?;
132                    self.regex_cache.insert(re_str.clone(), compiled);
133                }
134            }
135        }
136        self.patterns.push((name.into(), pattern));
137        Ok(())
138    }
139
140    /// Find all pattern matches in the token sequence.
141    ///
142    /// Tries every starting position for every pattern; returns all matches
143    /// (possibly overlapping).
144    pub fn match_all(&self, tokens: &[Token]) -> Vec<Match> {
145        let mut results = Vec::new();
146        for (name, pattern) in &self.patterns {
147            for start in 0..tokens.len() {
148                if let Some((end, groups)) = self.try_match(pattern, tokens, start) {
149                    results.push(Match {
150                        pattern_name: name.clone(),
151                        start,
152                        end,
153                        groups,
154                    });
155                }
156            }
157        }
158        results
159    }
160
161    /// Attempt to match `pattern` starting at `start` in `tokens`.
162    ///
163    /// Returns `Some((end_exclusive, captured_groups))` on success.
164    fn try_match(
165        &self,
166        pattern: &Pattern,
167        tokens: &[Token],
168        start: usize,
169    ) -> Option<(usize, Vec<String>)> {
170        self.try_match_from(pattern, tokens, start, 0, Vec::new())
171    }
172
173    /// Recursive helper that matches pattern elements from `elem_idx` onwards,
174    /// starting at token position `pos`, with previously captured `groups`.
175    fn try_match_from(
176        &self,
177        pattern: &Pattern,
178        tokens: &[Token],
179        pos: usize,
180        elem_idx: usize,
181        groups: Vec<String>,
182    ) -> Option<(usize, Vec<String>)> {
183        // All elements matched
184        if elem_idx >= pattern.template.len() {
185            return Some((pos, groups));
186        }
187
188        let elem = &pattern.template[elem_idx];
189
190        match elem {
191            PatternElement::Literal(s) => {
192                if pos >= tokens.len() {
193                    return None;
194                }
195                if tokens[pos].text.to_lowercase() != s.to_lowercase() {
196                    return None;
197                }
198                let mut new_groups = groups;
199                new_groups.push(tokens[pos].text.clone());
200                self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
201            }
202            PatternElement::PoS(tag) => {
203                if pos >= tokens.len() {
204                    return None;
205                }
206                let tok_pos = tokens[pos].pos.as_deref().unwrap_or("");
207                if tok_pos != tag.as_str() {
208                    return None;
209                }
210                let mut new_groups = groups;
211                new_groups.push(tokens[pos].text.clone());
212                self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
213            }
214            PatternElement::Regex(re_str) => {
215                if pos >= tokens.len() {
216                    return None;
217                }
218                let re = self.regex_cache.get(re_str)?;
219                if !re.is_match(&tokens[pos].text) {
220                    return None;
221                }
222                let mut new_groups = groups;
223                new_groups.push(tokens[pos].text.clone());
224                self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
225            }
226            PatternElement::Any => {
227                if pos >= tokens.len() {
228                    return None;
229                }
230                let mut new_groups = groups;
231                new_groups.push(tokens[pos].text.clone());
232                self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
233            }
234            PatternElement::Gap { min, max } => {
235                // Try each skip length from min to max; return the first
236                // that allows the rest of the pattern to match.
237                for skip in *min..=*max {
238                    let new_pos = pos + skip;
239                    if new_pos > tokens.len() {
240                        break;
241                    }
242                    if let Some(result) =
243                        self.try_match_from(pattern, tokens, new_pos, elem_idx + 1, groups.clone())
244                    {
245                        return Some(result);
246                    }
247                }
248                None
249            }
250        }
251    }
252}
253
254// ---------------------------------------------------------------------------
255// Built-in rule-based NER patterns
256// ---------------------------------------------------------------------------
257
258/// Build a `PatternMatcher` pre-loaded with common NER patterns.
259///
260/// Includes patterns for dates, money, email, URL, and phone numbers.
261pub fn build_ner_pattern_matcher() -> Result<PatternMatcher> {
262    let mut matcher = PatternMatcher::new();
263
264    // Date: MM/DD/YYYY or YYYY-MM-DD
265    matcher.add_pattern(
266        "DATE",
267        Pattern::new(vec![PatternElement::Regex(
268            r"(?:(?:0?[1-9]|1[0-2])[\/\-](?:0?[1-9]|[12][0-9]|3[01])[\/\-](?:19|20)?\d{2}|(?:19|20)\d{2}[\/\-](?:0?[1-9]|1[0-2])[\/\-](?:0?[1-9]|[12][0-9]|3[01]))".to_string(),
269        )]),
270    )?;
271
272    // Money: $1234.56 or similar
273    matcher.add_pattern(
274        "MONEY",
275        Pattern::new(vec![PatternElement::Regex(
276            r"\$[0-9]+(?:\.[0-9]+)?".to_string(),
277        )]),
278    )?;
279
280    // Email
281    matcher.add_pattern(
282        "EMAIL",
283        Pattern::new(vec![PatternElement::Regex(
284            r"[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}".to_string(),
285        )]),
286    )?;
287
288    // URL
289    matcher.add_pattern(
290        "URL",
291        Pattern::new(vec![PatternElement::Regex(r"https?://[^\s]+".to_string())]),
292    )?;
293
294    // Phone: (NNN) NNN-NNNN or NNN-NNN-NNNN
295    matcher.add_pattern(
296        "PHONE",
297        Pattern::new(vec![PatternElement::Regex(
298            r"(?:\+?1[\-.\s]?)?\(?\d{3}\)?[\-.\s]\d{3}[\-.\s]\d{4}".to_string(),
299        )]),
300    )?;
301
302    Ok(matcher)
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Split on whitespace into tokens without POS or lemma annotations.
    fn tokenize_simple(text: &str) -> Vec<Token> {
        text.split_whitespace().map(Token::new).collect()
    }

    /// Build a matcher that holds exactly one pattern.
    fn single_pattern_matcher(name: &str, template: Vec<PatternElement>) -> PatternMatcher {
        let mut matcher = PatternMatcher::new();
        matcher
            .add_pattern(name, Pattern::new(template))
            .expect("add_pattern failed");
        matcher
    }

    #[test]
    fn test_literal_match() {
        let matcher = single_pattern_matcher(
            "greeting",
            vec![
                PatternElement::Literal("hello".to_string()),
                PatternElement::Literal("world".to_string()),
            ],
        );

        let tokens = tokenize_simple("hello world foo bar");
        let matches = matcher.match_all(&tokens);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pattern_name, "greeting");
        assert_eq!(matches[0].start, 0);
        assert_eq!(matches[0].end, 2);
        assert_eq!(matches[0].groups, vec!["hello", "world"]);
    }

    #[test]
    fn test_literal_match_is_case_insensitive() {
        let matcher = single_pattern_matcher(
            "greeting",
            vec![PatternElement::Literal("Hello".to_string())],
        );

        let tokens = tokenize_simple("say HELLO");
        let matches = matcher.match_all(&tokens);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 1);
        // The captured group keeps the token's original casing.
        assert_eq!(matches[0].groups, vec!["HELLO"]);
    }

    #[test]
    fn test_pos_match() {
        let matcher = single_pattern_matcher(
            "dt_nn",
            vec![
                PatternElement::PoS("DT".to_string()),
                PatternElement::PoS("NN".to_string()),
            ],
        );

        let tokens = vec![
            Token::new("the").with_pos("DT"),
            Token::new("dog").with_pos("NN"),
            Token::new("runs").with_pos("VBZ"),
        ];
        let matches = matcher.match_all(&tokens);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 0);
        assert_eq!(matches[0].end, 2);
        assert_eq!(matches[0].groups, vec!["the", "dog"]);
    }

    #[test]
    fn test_regex_match() {
        let matcher = single_pattern_matcher(
            "money",
            vec![PatternElement::Regex(r"\$[0-9]+(?:\.[0-9]+)?".to_string())],
        );

        let tokens = tokenize_simple("costs $29.99 shipping $5");
        let matches = matcher.match_all(&tokens);
        assert_eq!(matches.len(), 2);
        assert_eq!((matches[0].start, matches[0].end), (1, 2));
        assert_eq!(matches[0].groups, vec!["$29.99"]);
        assert_eq!((matches[1].start, matches[1].end), (3, 4));
        assert_eq!(matches[1].groups, vec!["$5"]);
    }

    #[test]
    fn test_any_match() {
        let matcher = single_pattern_matcher(
            "any_word",
            vec![
                PatternElement::Literal("the".to_string()),
                PatternElement::Any,
            ],
        );

        let tokens = tokenize_simple("the cat sat on the mat");
        let matches = matcher.match_all(&tokens);
        // Exactly "the cat" and "the mat"; the `Any` token is captured too.
        assert_eq!(matches.len(), 2);
        assert_eq!((matches[0].start, matches[0].end), (0, 2));
        assert_eq!(matches[0].groups, vec!["the", "cat"]);
        assert_eq!((matches[1].start, matches[1].end), (4, 6));
        assert_eq!(matches[1].groups, vec!["the", "mat"]);
    }

    #[test]
    fn test_gap_match() {
        let matcher = single_pattern_matcher(
            "verb_phrase",
            vec![
                PatternElement::Literal("john".to_string()),
                PatternElement::Gap { min: 0, max: 2 },
                PatternElement::Literal("mary".to_string()),
            ],
        );

        let tokens = tokenize_simple("john loves mary");
        let matches = matcher.match_all(&tokens);
        assert_eq!(matches.len(), 1);
        assert_eq!((matches[0].start, matches[0].end), (0, 3));
        // Tokens skipped by the gap are not captured.
        assert_eq!(matches[0].groups, vec!["john", "mary"]);
    }

    #[test]
    fn test_gap_bounds() {
        let matcher = single_pattern_matcher(
            "verb_phrase",
            vec![
                PatternElement::Literal("john".to_string()),
                PatternElement::Gap { min: 0, max: 2 },
                PatternElement::Literal("mary".to_string()),
            ],
        );

        // A zero-length gap is allowed by `min: 0`.
        let adjacent = matcher.match_all(&tokenize_simple("john mary"));
        assert_eq!(adjacent.len(), 1);
        assert_eq!((adjacent[0].start, adjacent[0].end), (0, 2));

        // Three tokens in between exceed `max: 2`.
        let too_far = matcher.match_all(&tokenize_simple("john really truly loves mary"));
        assert!(too_far.is_empty());
    }

    #[test]
    fn test_empty_input_yields_no_matches() {
        let matcher = single_pattern_matcher("any", vec![PatternElement::Any]);
        assert!(matcher.match_all(&[]).is_empty());
    }

    #[test]
    fn test_ner_patterns_email() {
        let matcher = build_ner_pattern_matcher().expect("build failed");
        let tokens = tokenize_simple("contact user@example.com for info");
        let matches = matcher.match_all(&tokens);
        assert!(matches
            .iter()
            .any(|m| m.pattern_name == "EMAIL" && m.start == 1 && m.end == 2));
    }

    #[test]
    fn test_ner_patterns_money() {
        let matcher = build_ner_pattern_matcher().expect("build failed");
        let tokens = tokenize_simple("costs $100 today");
        let matches = matcher.match_all(&tokens);
        assert!(matches
            .iter()
            .any(|m| m.pattern_name == "MONEY" && m.start == 1 && m.end == 2));
    }

    #[test]
    fn test_bad_regex_error() {
        let mut matcher = PatternMatcher::new();
        let result = matcher.add_pattern(
            "bad",
            Pattern::new(vec![PatternElement::Regex("[invalid".to_string())]),
        );
        assert!(result.is_err());
        // A rejected pattern must not be registered.
        assert!(matcher.match_all(&tokenize_simple("anything")).is_empty());
    }
}