1use super::entities::{Entity, EntityType};
8use super::patterns::{
9 DATE_PATTERN, EMAIL_PATTERN, MONEY_PATTERN, PERCENTAGE_PATTERN, PHONE_PATTERN, TIME_PATTERN,
10 URL_PATTERN,
11};
12use crate::error::{Result, TextError};
13use lazy_static::lazy_static;
14use regex::Regex;
15use std::collections::{HashMap, HashSet};
16
17lazy_static! {
22 static ref NUMBER_PATTERN: Regex = Regex::new(
23 r"(?x)
24 (?:
25 [+-]? # optional sign
26 (?:
27 \d{1,3}(?:,\d{3})+ # thousands-separated integer
28 | \d+ # plain integer
29 )
30 (?:\.\d+)? # optional decimal
31 (?:[eE][+-]?\d+)? # optional scientific exponent
32 )
33 \b"
34 )
35 .expect("NUMBER_PATTERN is valid");
36}
37
/// A group of text spans that are taken to refer to the same entity.
#[derive(Debug, Clone)]
pub struct CoreferenceCluster {
    /// Surface form used to name the cluster (the antecedent's text).
    pub canonical: String,
    /// `(start, end)` spans of every mention. `simple_coreference` fills these
    /// with regex match offsets into the analysed text.
    pub mentions: Vec<(usize, usize)>,
}
50
51pub struct AdvancedNerExtractor {
72 custom_patterns: Vec<(EntityType, Regex)>,
74}
75
76impl Default for AdvancedNerExtractor {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl AdvancedNerExtractor {
83 pub fn new() -> Self {
90 Self {
91 custom_patterns: Vec::new(),
92 }
93 }
94
95 pub fn add_pattern(&mut self, entity_type: EntityType, pattern: &str) -> Result<()> {
102 let re = Regex::new(pattern)
103 .map_err(|e| TextError::InvalidInput(format!("Invalid regex '{}': {}", pattern, e)))?;
104 self.custom_patterns.push((entity_type, re));
105 Ok(())
106 }
107
108 pub fn extract(&self, text: &str) -> Vec<Entity> {
114 let mut entities = Vec::new();
115
116 entities.extend(extract_with_pattern(
118 text,
119 &EMAIL_PATTERN,
120 EntityType::Email,
121 1.0,
122 ));
123 entities.extend(extract_with_pattern(
124 text,
125 &URL_PATTERN,
126 EntityType::Url,
127 1.0,
128 ));
129 entities.extend(extract_with_pattern(
130 text,
131 &DATE_PATTERN,
132 EntityType::Date,
133 0.95,
134 ));
135 entities.extend(extract_with_pattern(
136 text,
137 &TIME_PATTERN,
138 EntityType::Time,
139 0.95,
140 ));
141 entities.extend(extract_with_pattern(
142 text,
143 &PHONE_PATTERN,
144 EntityType::Phone,
145 0.90,
146 ));
147 entities.extend(extract_with_pattern(
148 text,
149 &MONEY_PATTERN,
150 EntityType::Money,
151 0.95,
152 ));
153 entities.extend(extract_with_pattern(
154 text,
155 &PERCENTAGE_PATTERN,
156 EntityType::Percentage,
157 0.95,
158 ));
159 entities.extend(extract_with_pattern(
160 text,
161 &NUMBER_PATTERN,
162 EntityType::Custom("number".to_string()),
163 0.85,
164 ));
165
166 for (et, re) in &self.custom_patterns {
168 entities.extend(extract_with_pattern(text, re, et.clone(), 0.80));
169 }
170
171 entities.sort_by_key(|e| e.start);
173 dedup_overlapping(entities)
174 }
175
176 pub fn extract_emails(text: &str) -> Vec<Entity> {
182 extract_with_pattern(text, &EMAIL_PATTERN, EntityType::Email, 1.0)
183 }
184
185 pub fn extract_urls(text: &str) -> Vec<Entity> {
187 extract_with_pattern(text, &URL_PATTERN, EntityType::Url, 1.0)
188 }
189
190 pub fn extract_dates(text: &str) -> Vec<Entity> {
192 extract_with_pattern(text, &DATE_PATTERN, EntityType::Date, 0.95)
193 }
194
195 pub fn extract_numbers(text: &str) -> Vec<Entity> {
198 let mut out = Vec::new();
199 out.extend(extract_with_pattern(
200 text,
201 &MONEY_PATTERN,
202 EntityType::Money,
203 0.95,
204 ));
205 out.extend(extract_with_pattern(
206 text,
207 &PERCENTAGE_PATTERN,
208 EntityType::Percentage,
209 0.95,
210 ));
211 out.extend(extract_with_pattern(
212 text,
213 &NUMBER_PATTERN,
214 EntityType::Custom("number".to_string()),
215 0.85,
216 ));
217 out.sort_by_key(|e| e.start);
218 out
219 }
220}
221
222fn extract_with_pattern(
228 text: &str,
229 pattern: &Regex,
230 entity_type: EntityType,
231 confidence: f64,
232) -> Vec<Entity> {
233 pattern
234 .find_iter(text)
235 .map(|m| Entity {
236 text: m.as_str().to_string(),
237 entity_type: entity_type.clone(),
238 start: m.start(),
239 end: m.end(),
240 confidence,
241 })
242 .collect()
243}
244
245fn dedup_overlapping(mut entities: Vec<Entity>) -> Vec<Entity> {
248 entities.sort_by(|a, b| {
249 a.start.cmp(&b.start).then_with(|| {
250 b.confidence
251 .partial_cmp(&a.confidence)
252 .unwrap_or(std::cmp::Ordering::Equal)
253 })
254 });
255
256 let mut result: Vec<Entity> = Vec::new();
257 let mut cursor: usize = 0;
258
259 for entity in entities {
260 if entity.start >= cursor {
261 cursor = entity.end;
262 result.push(entity);
263 }
264 }
266
267 result
268}
269
/// RAKE-style keyphrase extractor: candidate phrases are runs of non-stop-words,
/// scored by the summed degree/frequency ratio of their words.
pub struct RakeExtractor {
    /// Words that terminate a candidate phrase. Lookups use the lowercased,
    /// punctuation-trimmed form of each input word.
    pub stopwords: HashSet<String>,
    /// Minimum number of tokens a candidate needs in order to be scored.
    pub min_phrase_len: usize,
    /// Maximum number of tokens a candidate may have in order to be scored.
    pub max_phrase_len: usize,
}
300
301impl Default for RakeExtractor {
302 fn default() -> Self {
303 Self::new()
304 }
305}
306
307impl RakeExtractor {
308 pub fn new() -> Self {
310 Self {
311 stopwords: default_stop_words(),
312 min_phrase_len: 1,
313 max_phrase_len: 4,
314 }
315 }
316
317 pub fn with_stopwords(words: Vec<String>) -> Self {
319 Self {
320 stopwords: words.into_iter().collect(),
321 min_phrase_len: 1,
322 max_phrase_len: 4,
323 }
324 }
325
326 pub fn extract(&self, text: &str) -> Vec<(String, f64)> {
332 let candidates = self.extract_candidates(text);
335
336 if candidates.is_empty() {
337 return Vec::new();
338 }
339
340 let mut word_freq: HashMap<String, f64> = HashMap::new();
342 let mut word_degree: HashMap<String, f64> = HashMap::new();
343
344 for phrase in &candidates {
345 let words = tokenize_phrase(phrase);
346 let phrase_len = words.len() as f64;
347 for word in &words {
348 *word_freq.entry(word.clone()).or_insert(0.0) += 1.0;
349 *word_degree.entry(word.clone()).or_insert(0.0) += phrase_len;
350 }
351 }
352
353 let word_score: HashMap<String, f64> = word_freq
355 .iter()
356 .map(|(w, &freq)| {
357 let deg = word_degree.get(w).copied().unwrap_or(freq);
358 (w.clone(), deg / freq)
359 })
360 .collect();
361
362 let mut phrase_scores: HashMap<String, f64> = HashMap::new();
364 for phrase in &candidates {
365 let words = tokenize_phrase(phrase);
366 let len = words.len();
367 if len < self.min_phrase_len || len > self.max_phrase_len {
368 continue;
369 }
370 let score: f64 = words
371 .iter()
372 .map(|w| word_score.get(w).copied().unwrap_or(0.0))
373 .sum();
374 phrase_scores
375 .entry(phrase.clone())
376 .and_modify(|s| {
377 if score > *s {
378 *s = score;
379 }
380 })
381 .or_insert(score);
382 }
383
384 let mut result: Vec<(String, f64)> = phrase_scores.into_iter().collect();
386 result.sort_by(|a, b| {
387 b.1.partial_cmp(&a.1)
388 .unwrap_or(std::cmp::Ordering::Equal)
389 .then_with(|| a.0.cmp(&b.0))
390 });
391
392 result
393 }
394
395 fn extract_candidates(&self, text: &str) -> Vec<String> {
400 let mut candidates = Vec::new();
403 let sentences = split_sentences(text);
404
405 for sentence in &sentences {
406 let words: Vec<&str> = sentence.split_whitespace().collect();
407 let mut current_phrase: Vec<&str> = Vec::new();
408
409 for word in &words {
410 let clean = word
411 .trim_matches(|c: char| !c.is_alphanumeric())
412 .to_lowercase();
413
414 if clean.is_empty() || self.stopwords.contains(&clean) {
415 if !current_phrase.is_empty() {
416 let phrase = current_phrase.join(" ");
417 let phrase_words = tokenize_phrase(&phrase);
418 if !phrase_words.is_empty() {
419 candidates.push(phrase);
420 }
421 current_phrase.clear();
422 }
423 } else {
424 current_phrase.push(word);
425 }
426 }
427
428 if !current_phrase.is_empty() {
429 let phrase = current_phrase.join(" ");
430 let phrase_words = tokenize_phrase(&phrase);
431 if !phrase_words.is_empty() {
432 candidates.push(phrase);
433 }
434 }
435 }
436
437 candidates
438 }
439}
440
/// A subject–verb–object relation found in text.
#[derive(Debug, Clone)]
pub struct SvoTriple {
    /// Text captured as the subject.
    pub subject: String,
    /// Text captured as the verb phrase linking subject and object.
    pub predicate: String,
    /// Text captured as the object.
    pub object: String,
    /// Confidence attached to the triple by the extractor.
    pub confidence: f64,
}
457
458pub struct SvoRelationExtractor {
464 verb_patterns: Vec<Regex>,
466}
467
468impl Default for SvoRelationExtractor {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
474impl SvoRelationExtractor {
475 pub fn new() -> Self {
477 let verb_strs = [
480 r"(?P<subj>[A-Z][A-Za-z]+(?: [A-Z][A-Za-z]+)*)\s+(?P<verb>(?:is|are|was|were|will be|has been|have been)\s+(?:\w+\s+)?(?:the\s+)?(?:CEO|founder|leader|head|director|manager|president|chairman|member)\s+of)\s+(?P<obj>[A-Z][A-Za-z]+(?: [A-Za-z&]+)*)",
482 r"(?P<subj>[A-Z][A-Za-z]+(?: [A-Z][A-Za-z]+)*)\s+(?P<verb>(?:acquired|merged with|partnered with|invested in|founded|launched|released|announced|created|developed|built|designed|invented|discovered|published|wrote|authored))\s+(?P<obj>[A-Z][A-Za-z]+(?: [A-Za-z&]+)*)",
483 r"(?P<subj>[A-Z][A-Za-z]+(?: [A-Z][A-Za-z]+)*)\s+(?P<verb>(?:works? for|works? at|employed by|joined|left|resigned from))\s+(?P<obj>[A-Z][A-Za-z]+(?: [A-Za-z&]+)*)",
484 ];
485
486 let verb_patterns = verb_strs
487 .iter()
488 .filter_map(|s| Regex::new(s).ok())
489 .collect();
490
491 Self { verb_patterns }
492 }
493
494 pub fn extract(&self, text: &str) -> Vec<SvoTriple> {
496 let mut triples = Vec::new();
497 let sentences = split_sentences(text);
498
499 for sentence in &sentences {
500 for pattern in &self.verb_patterns {
501 for caps in pattern.captures_iter(sentence) {
502 let subj = caps.name("subj").map(|m| m.as_str().trim().to_string());
503 let verb = caps.name("verb").map(|m| m.as_str().trim().to_string());
504 let obj = caps.name("obj").map(|m| m.as_str().trim().to_string());
505
506 if let (Some(subject), Some(predicate), Some(object)) = (subj, verb, obj) {
507 triples.push(SvoTriple {
508 subject,
509 predicate,
510 object,
511 confidence: 0.70,
512 });
513 }
514 }
515 }
516 }
517
518 triples
519 }
520}
521
522pub fn simple_coreference(text: &str) -> Vec<CoreferenceCluster> {
540 lazy_static! {
541 static ref PRONOUN_RE: Regex =
542 Regex::new(r"\b(?i)(he|him|his|she|her|hers|it|its|they|them|their|theirs)\b")
543 .expect("PRONOUN_RE is valid");
544 static ref CAPITALIZED_NOUN_RE: Regex = Regex::new(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\b")
545 .expect("CAPITALIZED_NOUN_RE is valid");
546 }
547
548 let mut antecedents: Vec<(usize, usize, String)> = CAPITALIZED_NOUN_RE
550 .find_iter(text)
551 .map(|m| (m.start(), m.end(), m.as_str().to_string()))
552 .collect();
553
554 let pronouns: Vec<(usize, usize, String)> = PRONOUN_RE
556 .find_iter(text)
557 .map(|m| (m.start(), m.end(), m.as_str().to_lowercase()))
558 .collect();
559
560 if antecedents.is_empty() || pronouns.is_empty() {
561 return antecedents
563 .into_iter()
564 .map(|(start, end, name)| CoreferenceCluster {
565 canonical: name,
566 mentions: vec![(start, end)],
567 })
568 .collect();
569 }
570
571 let mut clusters: HashMap<String, Vec<(usize, usize)>> = HashMap::new();
574
575 for (start, end, name) in &antecedents {
577 clusters
578 .entry(name.clone())
579 .or_default()
580 .push((*start, *end));
581 }
582
583 for (p_start, p_end, pronoun) in &pronouns {
584 let prefer_person = matches!(
586 pronoun.as_str(),
587 "he" | "him" | "his" | "she" | "her" | "hers"
588 );
589
590 let candidate = antecedents
592 .iter()
593 .filter(|(a_start, _, _)| *a_start < *p_start)
594 .max_by_key(|(a_start, _, _)| *a_start);
595
596 if let Some((_, _, name)) = candidate {
597 let resolved_name = if prefer_person {
599 antecedents
600 .iter()
601 .filter(|(a_start, _, n)| *a_start < *p_start && n.contains(' '))
602 .max_by_key(|(a_start, _, _)| *a_start)
603 .map(|(_, _, n)| n)
604 .unwrap_or(name)
605 } else {
606 name
607 };
608
609 clusters
610 .entry(resolved_name.clone())
611 .or_default()
612 .push((*p_start, *p_end));
613 }
614 }
615
616 antecedents.sort_by_key(|(s, _, _)| *s);
618
619 clusters
620 .into_iter()
621 .map(|(canonical, mut mentions)| {
622 mentions.sort_by_key(|(s, _)| *s);
623 mentions.dedup();
624 CoreferenceCluster {
625 canonical,
626 mentions,
627 }
628 })
629 .collect()
630}
631
/// Splits `text` into trimmed sentences, each keeping its terminator.
///
/// A `.`, `!` or `?` ends a sentence unless it is directly followed by:
/// * an alphanumeric character, which keeps `3.14`, `bob@work.org` and `e.g`
///   in one piece, or
/// * another terminator, so `...` and `?!` stay attached to their sentence
///   instead of becoming punctuation-only sentences.
///
/// Trailing text without a terminator is returned as a final sentence.
/// Whitespace-only pieces are dropped.
fn split_sentences(text: &str) -> Vec<String> {
    fn is_terminator(c: char) -> bool {
        matches!(c, '.' | '!' | '?')
    }

    let mut sentences = Vec::new();
    let mut current = String::new();
    let mut chars = text.chars().peekable();

    while let Some(ch) = chars.next() {
        current.push(ch);
        if !is_terminator(ch) {
            continue;
        }

        // Look one character ahead to decide whether this is a real boundary.
        let sentence_continues = chars
            .peek()
            .map_or(false, |&next| next.is_alphanumeric() || is_terminator(next));
        if sentence_continues {
            continue;
        }

        let sentence = current.trim();
        if !sentence.is_empty() {
            sentences.push(sentence.to_string());
        }
        current.clear();
    }

    let tail = current.trim();
    if !tail.is_empty() {
        sentences.push(tail.to_string());
    }
    sentences
}
657
/// Breaks `phrase` into lowercase tokens, treating every non-alphanumeric
/// character as a separator and discarding empty pieces.
fn tokenize_phrase(phrase: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in phrase.split(|c: char| !c.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
666
/// Returns the built-in English stop-word set (all lowercase): articles,
/// conjunctions, prepositions, pronouns, auxiliaries and common adverbs.
fn default_stop_words() -> HashSet<String> {
    const STOP_WORDS: &[&str] = &[
        // articles, conjunctions, prepositions
        "a", "an", "the", "and", "or", "but", "nor", "for", "yet", "so",
        "in", "on", "at", "to", "of", "with", "by", "from", "as", "into",
        "through", "during", "before", "after", "above", "below", "between",
        "out", "off", "over", "under", "again", "about", "against", "along",
        "around", "up", "down",
        // pronouns and determiners
        "i", "me", "my", "we", "our", "you", "your", "he", "him", "his",
        "she", "her", "it", "its", "they", "them", "their", "what", "which",
        "who", "this", "that", "these", "those",
        // auxiliaries and modals
        "is", "am", "are", "was", "were", "be", "been", "being", "have",
        "has", "had", "do", "does", "did", "will", "would", "shall",
        "should", "may", "might", "must", "can", "could",
        // negation, adverbs, quantifiers, subordinators
        "not", "no", "very", "just", "here", "there", "when", "where", "why",
        "how", "all", "each", "every", "both", "few", "more", "most",
        "other", "some", "such", "only", "same", "than", "too", "also",
        "any", "because", "if", "while",
    ];

    let mut set = HashSet::with_capacity(STOP_WORDS.len());
    for word in STOP_WORDS {
        set.insert(String::from(*word));
    }
    set
}
683
/// Unit tests for the extractors and helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Both addresses in the sentence are reported by the static helper.
    #[test]
    fn test_extract_emails_static() {
        let text = "Reach Alice at alice@example.com or bob@work.org.";
        let emails = AdvancedNerExtractor::extract_emails(text);
        assert_eq!(emails.len(), 2);
        assert!(emails.iter().any(|e| e.text == "alice@example.com"));
        assert!(emails.iter().any(|e| e.text == "bob@work.org"));
    }

    /// At least the `example.com` URL is found.
    #[test]
    fn test_extract_urls_static() {
        let text = "Visit https://www.example.com and http://docs.rs for docs.";
        let urls = AdvancedNerExtractor::extract_urls(text);
        assert!(!urls.is_empty());
        assert!(urls.iter().any(|e| e.text.contains("example.com")));
    }

    /// Some date is recognised in a sentence containing two date formats.
    #[test]
    fn test_extract_dates_static() {
        let text = "The event is on January 15, 2024 or 2024-01-15.";
        let dates = AdvancedNerExtractor::extract_dates(text);
        assert!(!dates.is_empty());
    }

    /// Money and percentage amounts yield numeric entities.
    #[test]
    fn test_extract_numbers_static() {
        let text = "The price is $29.99 and the discount is 15%.";
        let numbers = AdvancedNerExtractor::extract_numbers(text);
        assert!(!numbers.is_empty());
    }

    /// A custom pattern can be registered and the instance API returns entities.
    #[test]
    fn test_instance_extract() {
        let mut extractor = AdvancedNerExtractor::new();
        extractor
            .add_pattern(EntityType::Custom("ticker".to_string()), r"\b[A-Z]{2,5}\b")
            .expect("pattern is valid");
        let entities =
            extractor.extract("Contact sales@acme.com or visit https://acme.com for ACME stock.");
        assert!(!entities.is_empty());
    }

    /// RAKE returns positive scores in descending order.
    #[test]
    fn test_rake_extractor_basic() {
        let text = "Automatic keyword extraction uses statistical methods to find important phrases. \
                    Statistical keyword extraction is useful for document analysis and information retrieval.";
        let rake = RakeExtractor::new();
        let keyphrases = rake.extract(text);
        assert!(!keyphrases.is_empty());
        for (_, score) in &keyphrases {
            assert!(*score > 0.0, "score should be positive, got {}", score);
        }
        let scores: Vec<f64> = keyphrases.iter().map(|(_, s)| *s).collect();
        for i in 1..scores.len() {
            assert!(
                scores[i - 1] >= scores[i],
                "keyphrases should be sorted descending"
            );
        }
    }

    /// With a custom stop-word list the content words still surface.
    #[test]
    fn test_rake_extractor_with_stopwords() {
        let stopwords = vec!["the".to_string(), "is".to_string(), "a".to_string()];
        let rake = RakeExtractor::with_stopwords(stopwords);
        let text = "The quick brown fox is a good jumper.";
        let keyphrases = rake.extract(text);
        assert!(keyphrases
            .iter()
            .any(|(p, _)| p.to_lowercase().contains("quick")
                || p.to_lowercase().contains("fox")
                || p.to_lowercase().contains("brown")));
    }

    /// The built-in patterns recognise the role-of relation and every returned
    /// triple is fully populated.
    #[test]
    fn test_svo_relation_extractor() {
        let extractor = SvoRelationExtractor::new();
        let text = "Tim Cook is the CEO of Apple. \
                    Satya Nadella founded Microsoft Research. \
                    Google acquired DeepMind.";
        let triples = extractor.extract(text);
        // The previous assertion (`!empty || empty`) could never fail.
        assert!(!triples.is_empty(), "expected at least one SVO triple");
        assert!(
            triples
                .iter()
                .any(|t| t.subject == "Tim Cook" && t.object == "Apple"),
            "expected a (Tim Cook, .., Apple) triple"
        );
        for t in &triples {
            assert!(!t.subject.is_empty());
            assert!(!t.predicate.is_empty());
            assert!(!t.object.is_empty());
        }
    }

    /// At least one pronoun is attached to an antecedent cluster.
    #[test]
    fn test_simple_coreference() {
        let text = "John Smith founded Acme Corp. He became its CEO.";
        let clusters = simple_coreference(text);
        assert!(!clusters.is_empty());
        let has_linked = clusters.iter().any(|c| c.mentions.len() > 1);
        assert!(has_linked, "expected at least one pronoun to be linked");
    }

    /// Of two entities starting at the same offset, the more confident survives.
    #[test]
    fn test_dedup_overlapping() {
        let entities = vec![
            Entity {
                text: "abc".to_string(),
                entity_type: EntityType::Email,
                start: 0,
                end: 3,
                confidence: 0.9,
            },
            Entity {
                text: "ab".to_string(),
                entity_type: EntityType::Custom("x".to_string()),
                start: 0,
                end: 2,
                confidence: 0.5,
            },
        ];
        let result = dedup_overlapping(entities);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "abc");
    }
}