// scirs2_text/information_extraction/extractors.rs

//! Core information extractors for named entities, key phrases, and patterns

use super::entities::{Entity, EntityType};
use super::patterns::*;
use crate::error::Result;
use crate::tokenize::Tokenizer;
use regex::Regex;
use std::collections::{HashMap, HashSet};

/// Simple rule-based named entity recognizer.
///
/// Combines three sources of evidence: built-in regex patterns for
/// structured entities (emails, URLs, phone numbers, …), user-supplied
/// custom regexes, and case-insensitive dictionary lookup for people,
/// organizations, and locations.
pub struct RuleBasedNER {
    /// Known person names, matched case-insensitively on word boundaries.
    person_names: HashSet<String>,
    /// Known organization names, matched the same way.
    organizations: HashSet<String>,
    /// Known location names, matched the same way.
    locations: HashSet<String>,
    /// Extra named regexes; each match is tagged `EntityType::Custom(name)`.
    custom_patterns: HashMap<String, Regex>,
}
17
18impl RuleBasedNER {
19    /// Create a new rule-based NER
20    pub fn new() -> Self {
21        Self {
22            person_names: HashSet::new(),
23            organizations: HashSet::new(),
24            locations: HashSet::new(),
25            custom_patterns: HashMap::new(),
26        }
27    }
28
29    /// Create a new rule-based NER with basic knowledge
30    pub fn with_basic_knowledge() -> Self {
31        let mut ner = Self::new();
32
33        // Add common person names and titles
34        ner.add_person_names(vec![
35            "Tim Cook".to_string(),
36            "Satya Nadella".to_string(),
37            "Elon Musk".to_string(),
38            "Jeff Bezos".to_string(),
39            "Mark Zuckerberg".to_string(),
40            "Bill Gates".to_string(),
41            "Sundar Pichai".to_string(),
42            "Andy Jassy".to_string(),
43            "Susan Wojcicki".to_string(),
44            "Reed Hastings".to_string(),
45            "Jensen Huang".to_string(),
46            "Lisa Su".to_string(),
47        ]);
48
49        // Add common organizations
50        ner.add_organizations(vec![
51            "Apple Inc.".to_string(),
52            "Apple".to_string(),
53            "Microsoft Corporation".to_string(),
54            "Microsoft".to_string(),
55            "Google".to_string(),
56            "Alphabet Inc.".to_string(),
57            "Amazon".to_string(),
58            "Meta".to_string(),
59            "Facebook".to_string(),
60            "Tesla".to_string(),
61            "Netflix".to_string(),
62            "NVIDIA".to_string(),
63            "AMD".to_string(),
64            "Intel".to_string(),
65            "IBM".to_string(),
66            "Oracle".to_string(),
67            "Salesforce".to_string(),
68        ]);
69
70        // Add common locations
71        ner.add_locations(vec![
72            "San Francisco".to_string(),
73            "New York".to_string(),
74            "London".to_string(),
75            "Tokyo".to_string(),
76            "Paris".to_string(),
77            "Berlin".to_string(),
78            "Sydney".to_string(),
79            "Toronto".to_string(),
80            "Singapore".to_string(),
81            "Hong Kong".to_string(),
82            "Los Angeles".to_string(),
83            "Chicago".to_string(),
84            "Boston".to_string(),
85            "Seattle".to_string(),
86            "Austin".to_string(),
87            "Denver".to_string(),
88            "California".to_string(),
89            "New York".to_string(),
90            "Texas".to_string(),
91            "Washington".to_string(),
92            "Florida".to_string(),
93        ]);
94
95        ner
96    }
97
98    /// Add person names to the recognizer
99    pub fn add_person_names<I: IntoIterator<Item = String>>(&mut self, names: I) {
100        self.person_names.extend(names);
101    }
102
103    /// Add organization names
104    pub fn add_organizations<I: IntoIterator<Item = String>>(&mut self, orgs: I) {
105        self.organizations.extend(orgs);
106    }
107
108    /// Add location names
109    pub fn add_locations<I: IntoIterator<Item = String>>(&mut self, locations: I) {
110        self.locations.extend(locations);
111    }
112
113    /// Add custom pattern for entity extraction
114    pub fn add_custom_pattern(&mut self, name: String, pattern: Regex) {
115        self.custom_patterns.insert(name, pattern);
116    }
117
118    /// Extract entities from text
119    pub fn extract_entities(&self, text: &str) -> Result<Vec<Entity>> {
120        let mut entities = Vec::new();
121
122        // Extract regex-based entities
123        entities.extend(self.extract_pattern_entities(text, &EMAIL_PATTERN, EntityType::Email)?);
124        entities.extend(self.extract_pattern_entities(text, &URL_PATTERN, EntityType::Url)?);
125        entities.extend(self.extract_pattern_entities(text, &PHONE_PATTERN, EntityType::Phone)?);
126        entities.extend(self.extract_pattern_entities(text, &DATE_PATTERN, EntityType::Date)?);
127        entities.extend(self.extract_pattern_entities(text, &TIME_PATTERN, EntityType::Time)?);
128        entities.extend(self.extract_pattern_entities(text, &MONEY_PATTERN, EntityType::Money)?);
129        entities.extend(self.extract_pattern_entities(
130            text,
131            &PERCENTAGE_PATTERN,
132            EntityType::Percentage,
133        )?);
134
135        // Extract custom patterns
136        for (name, pattern) in &self.custom_patterns {
137            entities.extend(self.extract_pattern_entities(
138                text,
139                pattern,
140                EntityType::Custom(name.clone()),
141            )?);
142        }
143
144        // Extract dictionary-based entities
145        entities.extend(self.extract_dictionary_entities(text)?);
146
147        // Sort by start position
148        entities.sort_by_key(|e| e.start);
149
150        Ok(entities)
151    }
152
153    /// Extract entities using regex patterns
154    fn extract_pattern_entities(
155        &self,
156        text: &str,
157        pattern: &Regex,
158        entity_type: EntityType,
159    ) -> Result<Vec<Entity>> {
160        let mut entities = Vec::new();
161
162        for mat in pattern.find_iter(text) {
163            entities.push(Entity {
164                text: mat.as_str().to_string(),
165                entity_type: entity_type.clone(),
166                start: mat.start(),
167                end: mat.end(),
168                confidence: 1.0, // High confidence for pattern matches
169            });
170        }
171
172        Ok(entities)
173    }
174
175    /// Extract dictionary-based entities
176    fn extract_dictionary_entities(&self, text: &str) -> Result<Vec<Entity>> {
177        let mut entities = Vec::new();
178        let text_lower = text.to_lowercase();
179
180        // Check for multi-word entities first (e.g., "Apple Inc.", "Tim Cook")
181        for entity_name in &self.person_names {
182            let entity_lower = entity_name.to_lowercase();
183            if let Some(start) = text_lower.find(&entity_lower) {
184                // Verify word boundaries
185                let at_word_start =
186                    start == 0 || !text.chars().nth(start - 1).unwrap_or(' ').is_alphanumeric();
187                let at_word_end = start + entity_name.len() >= text.len()
188                    || !text
189                        .chars()
190                        .nth(start + entity_name.len())
191                        .unwrap_or(' ')
192                        .is_alphanumeric();
193
194                if at_word_start && at_word_end {
195                    entities.push(Entity {
196                        text: text[start..start + entity_name.len()].to_string(),
197                        entity_type: EntityType::Person,
198                        start,
199                        end: start + entity_name.len(),
200                        confidence: 0.9,
201                    });
202                }
203            }
204        }
205
206        for entity_name in &self.organizations {
207            let entity_lower = entity_name.to_lowercase();
208            if let Some(start) = text_lower.find(&entity_lower) {
209                // Verify word boundaries
210                let at_word_start =
211                    start == 0 || !text.chars().nth(start - 1).unwrap_or(' ').is_alphanumeric();
212                let at_word_end = start + entity_name.len() >= text.len()
213                    || !text
214                        .chars()
215                        .nth(start + entity_name.len())
216                        .unwrap_or(' ')
217                        .is_alphanumeric();
218
219                if at_word_start && at_word_end {
220                    entities.push(Entity {
221                        text: text[start..start + entity_name.len()].to_string(),
222                        entity_type: EntityType::Organization,
223                        start,
224                        end: start + entity_name.len(),
225                        confidence: 0.9,
226                    });
227                }
228            }
229        }
230
231        for entity_name in &self.locations {
232            let entity_lower = entity_name.to_lowercase();
233            if let Some(start) = text_lower.find(&entity_lower) {
234                // Verify word boundaries
235                let at_word_start =
236                    start == 0 || !text.chars().nth(start - 1).unwrap_or(' ').is_alphanumeric();
237                let at_word_end = start + entity_name.len() >= text.len()
238                    || !text
239                        .chars()
240                        .nth(start + entity_name.len())
241                        .unwrap_or(' ')
242                        .is_alphanumeric();
243
244                if at_word_start && at_word_end {
245                    entities.push(Entity {
246                        text: text[start..start + entity_name.len()].to_string(),
247                        entity_type: EntityType::Location,
248                        start,
249                        end: start + entity_name.len(),
250                        confidence: 0.9,
251                    });
252                }
253            }
254        }
255
256        Ok(entities)
257    }
258}
259
260impl Default for RuleBasedNER {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
/// Key phrase extractor using statistical methods.
///
/// Counts n-gram frequencies over tokenized text and scores each phrase by
/// frequency weighted by the square root of its word count.
pub struct KeyPhraseExtractor {
    /// Smallest n-gram size (in tokens) to consider.
    min_phrase_length: usize,
    /// Largest n-gram size (in tokens) to consider.
    max_phrase_length: usize,
    /// Minimum number of occurrences for a phrase to be reported.
    min_frequency: usize,
}
272
273impl KeyPhraseExtractor {
274    /// Create a new key phrase extractor
275    pub fn new() -> Self {
276        Self {
277            min_phrase_length: 1,
278            max_phrase_length: 3,
279            min_frequency: 2,
280        }
281    }
282
283    /// Set minimum phrase length
284    pub fn with_min_length(mut self, length: usize) -> Self {
285        self.min_phrase_length = length;
286        self
287    }
288
289    /// Set maximum phrase length
290    pub fn with_max_length(mut self, length: usize) -> Self {
291        self.max_phrase_length = length;
292        self
293    }
294
295    /// Set minimum frequency threshold
296    pub fn with_min_frequency(mut self, freq: usize) -> Self {
297        self.min_frequency = freq;
298        self
299    }
300
301    /// Extract key phrases from text
302    pub fn extract(&self, text: &str, tokenizer: &dyn Tokenizer) -> Result<Vec<(String, f64)>> {
303        let tokens = tokenizer.tokenize(text)?;
304        let mut phrase_counts: HashMap<String, usize> = HashMap::new();
305
306        // Generate n-grams
307        for n in self.min_phrase_length..=self.max_phrase_length {
308            if tokens.len() >= n {
309                for i in 0..=tokens.len() - n {
310                    let phrase = tokens[i..i + n].join(" ");
311                    *phrase_counts.entry(phrase).or_insert(0) += 1;
312                }
313            }
314        }
315
316        // Filter by frequency and calculate scores
317        let mut phrases: Vec<(String, f64)> = phrase_counts
318            .into_iter()
319            .filter(|(_, count)| *count >= self.min_frequency)
320            .map(|(phrase, count)| {
321                // Simple scoring: frequency * length
322                let score = count as f64 * (phrase.split_whitespace().count() as f64).sqrt();
323                (phrase, score)
324            })
325            .collect();
326
327        // Sort by score descending
328        phrases.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
329
330        Ok(phrases)
331    }
332}
333
334impl Default for KeyPhraseExtractor {
335    fn default() -> Self {
336        Self::new()
337    }
338}
339
/// Pattern-based information extractor.
///
/// Holds a list of `(name, regex)` pairs; extraction applies each pattern
/// to the input text and groups matches under the pattern's name.
pub struct PatternExtractor {
    /// Registered patterns as (name, compiled regex), in insertion order.
    patterns: Vec<(String, Regex)>,
}
344
345impl PatternExtractor {
346    /// Create a new pattern extractor
347    pub fn new() -> Self {
348        Self {
349            patterns: Vec::new(),
350        }
351    }
352
353    /// Add a named pattern
354    pub fn add_pattern(&mut self, name: String, pattern: Regex) {
355        self.patterns.push((name, pattern));
356    }
357
358    /// Extract information matching patterns
359    pub fn extract(&self, text: &str) -> Result<HashMap<String, Vec<String>>> {
360        let mut results: HashMap<String, Vec<String>> = HashMap::new();
361
362        for (name, pattern) in &self.patterns {
363            let mut matches = Vec::new();
364
365            for mat in pattern.find_iter(text) {
366                matches.push(mat.as_str().to_string());
367            }
368
369            if !matches.is_empty() {
370                results.insert(name.clone(), matches);
371            }
372        }
373
374        Ok(results)
375    }
376
377    /// Extract with capture groups
378    pub fn extract_with_groups(
379        &self,
380        text: &str,
381    ) -> Result<HashMap<String, Vec<HashMap<String, String>>>> {
382        let mut results: HashMap<String, Vec<HashMap<String, String>>> = HashMap::new();
383
384        for (name, pattern) in &self.patterns {
385            let mut matches = Vec::new();
386
387            for caps in pattern.captures_iter(text) {
388                let mut groups = HashMap::new();
389
390                // Add full match
391                if let Some(full_match) = caps.get(0) {
392                    groups.insert("full".to_string(), full_match.as_str().to_string());
393                }
394
395                // Add numbered groups
396                for i in 1..caps.len() {
397                    if let Some(group) = caps.get(i) {
398                        groups.insert(format!("group{i}"), group.as_str().to_string());
399                    }
400                }
401
402                // Add named groups if any
403                for name in pattern.capture_names().flatten() {
404                    if let Some(group) = caps.name(name) {
405                        groups.insert(name.to_string(), group.as_str().to_string());
406                    }
407                }
408
409                matches.push(groups);
410            }
411
412            if !matches.is_empty() {
413                results.insert(name.clone(), matches);
414            }
415        }
416
417        Ok(results)
418    }
419}
420
421impl Default for PatternExtractor {
422    fn default() -> Self {
423        Self::new()
424    }
425}