scirs2_text/information_extraction/extractors.rs
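//! Rule-based information extraction: a dictionary- and regex-driven
//! named entity recognizer, a frequency-based key phrase extractor, and
//! a generic named-pattern extractor.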
use super::entities::{Entity, EntityType};
use super::patterns::*;
use crate::error::Result;
use crate::tokenize::Tokenizer;
use regex::Regex;
use std::collections::{HashMap, HashSet};

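/// A rule-based named entity recognizer that combines built-in regex
/// patterns (emails, URLs, phone numbers, dates, times, money,
/// percentages) with dictionary lookups for people, organizations, and
/// locations.
///
/// # Examples
///
/// A minimal sketch; the import path is assumed from the module layout:
///
/// ```ignore
/// use scirs2_text::information_extraction::RuleBasedNER;
///
/// let ner = RuleBasedNER::with_basic_knowledge();
/// let entities = ner.extract_entities("Tim Cook spoke in San Francisco.").unwrap();
/// assert!(!entities.is_empty());
/// ```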
pub struct RuleBasedNER {
    person_names: HashSet<String>,
    organizations: HashSet<String>,
    locations: HashSet<String>,
    custom_patterns: HashMap<String, Regex>,
}

impl RuleBasedNER {
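    /// Creates an empty recognizer with no dictionaries or custom patterns.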
    pub fn new() -> Self {
        Self {
            person_names: HashSet::new(),
            organizations: HashSet::new(),
            locations: HashSet::new(),
            custom_patterns: HashMap::new(),
        }
    }

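    /// Creates a recognizer pre-seeded with a small demonstration
    /// dictionary of well-known tech executives, companies, and
    /// cities/states. Intended as a starting point, not exhaustive coverage.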
    pub fn with_basic_knowledge() -> Self {
        let mut ner = Self::new();

        ner.add_person_names(vec![
            "Tim Cook".to_string(),
            "Satya Nadella".to_string(),
            "Elon Musk".to_string(),
            "Jeff Bezos".to_string(),
            "Mark Zuckerberg".to_string(),
            "Bill Gates".to_string(),
            "Sundar Pichai".to_string(),
            "Andy Jassy".to_string(),
            "Susan Wojcicki".to_string(),
            "Reed Hastings".to_string(),
            "Jensen Huang".to_string(),
            "Lisa Su".to_string(),
        ]);

        ner.add_organizations(vec![
            "Apple Inc.".to_string(),
            "Apple".to_string(),
            "Microsoft Corporation".to_string(),
            "Microsoft".to_string(),
            "Google".to_string(),
            "Alphabet Inc.".to_string(),
            "Amazon".to_string(),
            "Meta".to_string(),
            "Facebook".to_string(),
            "Tesla".to_string(),
            "Netflix".to_string(),
            "NVIDIA".to_string(),
            "AMD".to_string(),
            "Intel".to_string(),
            "IBM".to_string(),
            "Oracle".to_string(),
            "Salesforce".to_string(),
        ]);

        ner.add_locations(vec![
            "San Francisco".to_string(),
            "New York".to_string(),
            "London".to_string(),
            "Tokyo".to_string(),
            "Paris".to_string(),
            "Berlin".to_string(),
            "Sydney".to_string(),
            "Toronto".to_string(),
            "Singapore".to_string(),
            "Hong Kong".to_string(),
            "Los Angeles".to_string(),
            "Chicago".to_string(),
            "Boston".to_string(),
            "Seattle".to_string(),
            "Austin".to_string(),
            "Denver".to_string(),
            "California".to_string(),
89 "New York".to_string(),
90 "Texas".to_string(),
91 "Washington".to_string(),
92 "Florida".to_string(),
93 ]);
94
95 ner
96 }
97
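    /// Adds person names to the recognizer's dictionary.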
    pub fn add_person_names<I: IntoIterator<Item = String>>(&mut self, names: I) {
        self.person_names.extend(names);
    }

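    /// Adds organization names to the recognizer's dictionary.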
    pub fn add_organizations<I: IntoIterator<Item = String>>(&mut self, orgs: I) {
        self.organizations.extend(orgs);
    }

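    /// Adds location names to the recognizer's dictionary.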
    pub fn add_locations<I: IntoIterator<Item = String>>(&mut self, locations: I) {
        self.locations.extend(locations);
    }

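    /// Registers a custom regex; its matches are reported as
    /// `EntityType::Custom(name)` by `extract_entities`.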
    pub fn add_custom_pattern(&mut self, name: String, pattern: Regex) {
        self.custom_patterns.insert(name, pattern);
    }

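    /// Extracts all entities from `text`: built-in regex patterns first,
    /// then custom patterns, then dictionary lookups. The result is
    /// sorted by start offset; overlapping matches are not deduplicated.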
    pub fn extract_entities(&self, text: &str) -> Result<Vec<Entity>> {
        let mut entities = Vec::new();

        entities.extend(self.extract_pattern_entities(text, &EMAIL_PATTERN, EntityType::Email)?);
        entities.extend(self.extract_pattern_entities(text, &URL_PATTERN, EntityType::Url)?);
        entities.extend(self.extract_pattern_entities(text, &PHONE_PATTERN, EntityType::Phone)?);
        entities.extend(self.extract_pattern_entities(text, &DATE_PATTERN, EntityType::Date)?);
        entities.extend(self.extract_pattern_entities(text, &TIME_PATTERN, EntityType::Time)?);
        entities.extend(self.extract_pattern_entities(text, &MONEY_PATTERN, EntityType::Money)?);
        entities.extend(self.extract_pattern_entities(
            text,
            &PERCENTAGE_PATTERN,
            EntityType::Percentage,
        )?);

        for (name, pattern) in &self.custom_patterns {
            entities.extend(self.extract_pattern_entities(
                text,
                pattern,
                EntityType::Custom(name.clone()),
            )?);
        }

        entities.extend(self.extract_dictionary_entities(text)?);

        entities.sort_by_key(|e| e.start);

        Ok(entities)
    }

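    /// Collects every non-overlapping match of `pattern` in `text` as an
    /// entity of `entity_type` with full confidence.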
    fn extract_pattern_entities(
        &self,
        text: &str,
        pattern: &Regex,
        entity_type: EntityType,
    ) -> Result<Vec<Entity>> {
        let mut entities = Vec::new();

        for mat in pattern.find_iter(text) {
            entities.push(Entity {
                text: mat.as_str().to_string(),
                entity_type: entity_type.clone(),
                start: mat.start(),
                end: mat.end(),
                confidence: 1.0,
            });
        }

        Ok(entities)
    }

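    /// Finds dictionary entries (people, organizations, locations) in
    /// `text` via case-insensitive lookup. Only the first occurrence of
    /// each entry is reported, and matches must fall on word boundaries.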
    fn extract_dictionary_entities(&self, text: &str) -> Result<Vec<Entity>> {
        let mut entities = Vec::new();

        Self::find_dictionary_matches(
            text,
            &self.person_names,
            EntityType::Person,
            &mut entities,
        );
        Self::find_dictionary_matches(
            text,
            &self.organizations,
            EntityType::Organization,
            &mut entities,
        );
        Self::find_dictionary_matches(
            text,
            &self.locations,
            EntityType::Location,
            &mut entities,
        );

        Ok(entities)
    }

    /// Shared lookup for the dictionary-based entity types: scans `text`
    /// for the first case-insensitive occurrence of each dictionary entry
    /// and keeps it if it starts and ends on a word boundary.
    fn find_dictionary_matches(
        text: &str,
        dictionary: &HashSet<String>,
        entity_type: EntityType,
        entities: &mut Vec<Entity>,
    ) {
        let text_lower = text.to_lowercase();

        for entity_name in dictionary {
            let entity_lower = entity_name.to_lowercase();
            if let Some(start) = text_lower.find(&entity_lower) {
                let end = start + entity_lower.len();

                // Boundary checks on the lowercased text, where `start` and
                // `end` are guaranteed to be valid char boundaries.
                let at_word_start = start == 0
                    || text_lower[..start]
                        .chars()
                        .next_back()
                        .map_or(true, |c| !c.is_alphanumeric());
                let at_word_end = end >= text_lower.len()
                    || text_lower[end..]
                        .chars()
                        .next()
                        .map_or(true, |c| !c.is_alphanumeric());

                if at_word_start && at_word_end {
                    // `get` avoids panicking if lowercasing shifted byte
                    // offsets relative to the original text (possible for
                    // some non-ASCII input).
                    if let Some(matched) = text.get(start..end) {
                        entities.push(Entity {
                            text: matched.to_string(),
                            entity_type: entity_type.clone(),
                            start,
                            end,
                            confidence: 0.9,
                        });
                    }
                }
            }
        }
    }
}

impl Default for RuleBasedNER {
    fn default() -> Self {
        Self::new()
    }
}

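/// Extracts key phrases by counting n-grams of configurable length and
/// scoring them by frequency, weighted toward longer phrases.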
pub struct KeyPhraseExtractor {
    min_phrase_length: usize,
    max_phrase_length: usize,
    min_frequency: usize,
}

impl KeyPhraseExtractor {
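    /// Creates an extractor with default settings: phrases of 1–3 tokens
    /// that occur at least twice.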
    pub fn new() -> Self {
        Self {
            min_phrase_length: 1,
            max_phrase_length: 3,
            min_frequency: 2,
        }
    }

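    /// Sets the minimum phrase length in tokens.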
    pub fn with_min_length(mut self, length: usize) -> Self {
        self.min_phrase_length = length;
        self
    }

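    /// Sets the maximum phrase length in tokens.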
    pub fn with_max_length(mut self, length: usize) -> Self {
        self.max_phrase_length = length;
        self
    }

    pub fn with_min_frequency(mut self, freq: usize) -> Self {
        self.min_frequency = freq;
        self
    }

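    /// Tokenizes `text`, counts all n-grams within the configured length
    /// range, and returns phrases meeting the frequency threshold, scored
    /// as `count * sqrt(phrase_length)` and sorted from highest to lowest.
    ///
    /// A minimal sketch, assuming a `WordTokenizer` exists in this crate:
    ///
    /// ```ignore
    /// let extractor = KeyPhraseExtractor::new().with_max_length(2);
    /// let tokenizer = WordTokenizer::default();
    /// let phrases = extractor.extract("big data and big data tools", &tokenizer).unwrap();
    /// // "big data" occurs twice and outranks single tokens of equal count.
    /// ```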
    pub fn extract(&self, text: &str, tokenizer: &dyn Tokenizer) -> Result<Vec<(String, f64)>> {
        let tokens = tokenizer.tokenize(text)?;
        let mut phrase_counts: HashMap<String, usize> = HashMap::new();

        for n in self.min_phrase_length..=self.max_phrase_length {
            if tokens.len() >= n {
                for i in 0..=tokens.len() - n {
                    let phrase = tokens[i..i + n].join(" ");
                    *phrase_counts.entry(phrase).or_insert(0) += 1;
                }
            }
        }

        let mut phrases: Vec<(String, f64)> = phrase_counts
            .into_iter()
            .filter(|(_, count)| *count >= self.min_frequency)
            .map(|(phrase, count)| {
                let score = count as f64 * (phrase.split_whitespace().count() as f64).sqrt();
                (phrase, score)
            })
            .collect();

        // total_cmp gives a total order over f64 and avoids the panic that
        // partial_cmp().unwrap() would hit on NaN.
        phrases.sort_by(|a, b| b.1.total_cmp(&a.1));

        Ok(phrases)
    }
}

impl Default for KeyPhraseExtractor {
    fn default() -> Self {
        Self::new()
    }
}

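/// Runs a set of named regex patterns over text and groups the matches
/// by pattern name.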
pub struct PatternExtractor {
    patterns: Vec<(String, Regex)>,
}

impl PatternExtractor {
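    /// Creates an extractor with no patterns registered.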
    pub fn new() -> Self {
        Self {
            patterns: Vec::new(),
        }
    }

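    /// Registers a named pattern. Patterns are applied in insertion order.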
    pub fn add_pattern(&mut self, name: String, pattern: Regex) {
        self.patterns.push((name, pattern));
    }

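    /// Returns the matched strings for each pattern, keyed by pattern
    /// name. Patterns with no matches are omitted from the map.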
    pub fn extract(&self, text: &str) -> Result<HashMap<String, Vec<String>>> {
        let mut results: HashMap<String, Vec<String>> = HashMap::new();

        for (name, pattern) in &self.patterns {
            let mut matches = Vec::new();

            for mat in pattern.find_iter(text) {
                matches.push(mat.as_str().to_string());
            }

            if !matches.is_empty() {
                results.insert(name.clone(), matches);
            }
        }

        Ok(results)
    }

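    /// Like `extract`, but returns each match's capture groups as a map:
    /// the full match under "full", positional groups under "group1",
    /// "group2", ..., and named groups under their own names.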
    pub fn extract_with_groups(
        &self,
        text: &str,
    ) -> Result<HashMap<String, Vec<HashMap<String, String>>>> {
        let mut results: HashMap<String, Vec<HashMap<String, String>>> = HashMap::new();

        for (name, pattern) in &self.patterns {
            let mut matches = Vec::new();

            for caps in pattern.captures_iter(text) {
                let mut groups = HashMap::new();

                if let Some(full_match) = caps.get(0) {
                    groups.insert("full".to_string(), full_match.as_str().to_string());
                }

                for i in 1..caps.len() {
                    if let Some(group) = caps.get(i) {
                        groups.insert(format!("group{i}"), group.as_str().to_string());
                    }
                }

                // Use a distinct binding so the pattern name above is not
                // shadowed by the capture-group name.
                for group_name in pattern.capture_names().flatten() {
                    if let Some(group) = caps.name(group_name) {
                        groups.insert(group_name.to_string(), group.as_str().to_string());
                    }
                }

                matches.push(groups);
            }

            if !matches.is_empty() {
                results.insert(name.clone(), matches);
            }
        }

        Ok(results)
    }
}

impl Default for PatternExtractor {
    fn default() -> Self {
        Self::new()
    }
}