1use crate::error::{Result, TextError};
13use lazy_static::lazy_static;
14use regex::Regex;
15use std::collections::HashMap;
16
/// Category assigned to a span of text recognised as an entity.
///
/// Values can be compared for equality and hashed, so they are usable as
/// map keys or set members.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum EntityType {
    /// A person's name.
    Person,
    /// A company, agency or other organization.
    Organization,
    /// A geographic place.
    Location,
    /// A calendar date.
    Date,
    /// A time of day.
    Time,
    /// A plain numeric value.
    Number,
    /// An e-mail address.
    Email,
    /// A web address.
    Url,
    /// A telephone number.
    PhoneNumber,
    /// A monetary amount.
    Currency,
    /// A percentage.
    Percentage,
    /// A caller-defined category, identified by the contained label.
    Custom(String),
}
49
50impl std::fmt::Display for EntityType {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 Self::Person => write!(f, "PERSON"),
54 Self::Organization => write!(f, "ORGANIZATION"),
55 Self::Location => write!(f, "LOCATION"),
56 Self::Date => write!(f, "DATE"),
57 Self::Time => write!(f, "TIME"),
58 Self::Number => write!(f, "NUMBER"),
59 Self::Email => write!(f, "EMAIL"),
60 Self::Url => write!(f, "URL"),
61 Self::PhoneNumber => write!(f, "PHONE_NUMBER"),
62 Self::Currency => write!(f, "CURRENCY"),
63 Self::Percentage => write!(f, "PERCENTAGE"),
64 Self::Custom(label) => write!(f, "CUSTOM({})", label),
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
71pub struct Entity {
72 pub text: String,
74 pub entity_type: EntityType,
76 pub start: usize,
78 pub end: usize,
80 pub score: f32,
82}
83
84#[derive(Default)]
86pub struct NerConfig {
87 pub case_sensitive: bool,
89 pub custom_patterns: Vec<(String, EntityType)>,
91 pub gazetteer: HashMap<String, EntityType>,
93}
94
95impl NerConfig {
96 pub fn new() -> Self {
98 Self::default()
99 }
100}
101
102lazy_static! {
107 static ref RE_EMAIL: Regex = Regex::new(
109 r"(?i)\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b"
110 ).expect("email regex");
111
112 static ref RE_URL: Regex = Regex::new(
114 r"(?i)https?://[^\s<>\x22\{\}\|\\\^\[\]`]+"
115 ).expect("url regex");
116
117 static ref RE_PHONE: Regex = Regex::new(
119 r"(?:(?:\+?1[-.\s]?)?(?:\(\d{3}\)|\d{3})[-.\s]?\d{3}[-.\s]?\d{4})\b"
120 ).expect("phone regex");
121
122 static ref RE_DATE: Regex = Regex::new(
124 r"(?i)(?:\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b|\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,?\s+\d{4}\b|\b\d{1,2}(?:st|nd|rd|th)?\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{4}\b)"
125 ).expect("date regex");
126
127 static ref RE_TIME: Regex = Regex::new(
129 r"(?i)\b\d{1,2}:\d{2}(?::\d{2})?(?:\s*[AP]M)?\b"
130 ).expect("time regex");
131
132 static ref RE_CURRENCY: Regex = Regex::new(
134 r"(?:[\$\x{20AC}\x{00A3}\x{00A5}])\s*\d[\d,]*(?:\.\d{1,2})?|\d[\d,]*(?:\.\d{1,2})?\s*(?:USD|EUR|GBP|JPY|CAD|AUD|CHF|CNY)\b"
135 ).expect("currency regex");
136
137 static ref RE_PERCENTAGE: Regex = Regex::new(
139 r"(?i)\b\d+(?:\.\d+)?\s*(?:%|percent\b)"
140 ).expect("percentage regex");
141
142 static ref RE_NUMBER: Regex = Regex::new(
144 r"\b(?:\d+(?:\.\d+)?[eE][+\-]?\d+|\d+(?:\.\d+)?|\d+(?:st|nd|rd|th))\b"
145 ).expect("number regex");
146
147 static ref RE_PERSON_PREFIX: Regex = Regex::new(
149 r"\b(?:Dr|Prof|Mr|Mrs|Ms|Miss|Rev|Gen|Col|Capt|Lt|Sgt|Cpl|Pte|Sir|Lord|Lady|dr|prof|mr|mrs|ms|miss|rev|gen|col|capt|lt|sgt|cpl|pte|sir|lord|lady)\.?\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)"
150 ).expect("person prefix regex");
151
152 static ref RE_ORG_SUFFIX: Regex = Regex::new(
154 r"\b([A-Z][A-Za-z&\s]+(?:Inc|LLC|Ltd|Corp|Co|GmbH|AG|SA|PLC|LLP|LP|NV|BV|AB|AS|Pty)\.?)\b"
155 ).expect("org suffix regex");
156}
157
/// Returns the built-in list of place names that are tagged as locations.
///
/// The names are stored in their canonical capitalisation.
fn default_location_gazetteer() -> &'static [&'static str] {
    const LOCATIONS: &[&str] = &[
        // Continents and large regions.
        "Africa", "America", "Antarctica", "Arctic", "Asia", "Australia", "Europe",
        // Countries.
        "China", "France", "Germany", "India", "Italy", "Japan", "Russia", "Spain",
        "United States", "United Kingdom", "Canada", "Brazil", "Mexico", "Argentina",
        "South Korea", "North Korea", "Saudi Arabia", "South Africa",
        // Major cities.
        "New York", "London", "Paris", "Tokyo", "Beijing", "Shanghai", "Sydney", "Moscow",
        "Berlin", "Madrid", "Rome", "Seoul", "Mumbai", "Dubai", "Los Angeles", "Chicago",
        "San Francisco", "Houston", "Phoenix",
        // US states.
        "California", "Texas", "Florida", "Illinois", "Pennsylvania", "Ohio", "Georgia",
        "Michigan", "New Jersey", "Virginia", "Washington", "Arizona", "Massachusetts",
        "Tennessee", "Indiana",
    ];
    LOCATIONS
}
225
/// Returns the built-in list of organization names.
///
/// NOTE(review): several entries ("WHO", "UN", "Meta", "Apple", "Slack",
/// "Zoom") are also ordinary words, so they can produce false positives when
/// the gazetteer is matched without regard to case.
fn default_org_gazetteer() -> &'static [&'static str] {
    const ORGANIZATIONS: &[&str] = &[
        // Technology companies and products.
        "Google", "Apple", "Microsoft", "Amazon", "Meta", "Netflix", "Tesla", "IBM", "Intel",
        "Oracle", "SAP", "Adobe", "Salesforce", "Twitter", "LinkedIn", "Facebook", "WhatsApp",
        "Instagram", "YouTube", "TikTok", "Snapchat", "Uber", "Lyft", "Airbnb", "Spotify",
        "Slack", "Zoom", "Dropbox",
        // Government agencies and international bodies.
        "NASA", "CIA", "FBI", "NSA", "UN", "NATO", "WHO", "IMF", "WTO",
        // Universities.
        "Harvard", "MIT", "Stanford", "Oxford", "Cambridge",
    ];
    ORGANIZATIONS
}
272
273pub struct NerExtractor {
290 config: NerConfig,
291 compiled_custom: Vec<(Regex, EntityType)>,
293 effective_gazetteer: HashMap<String, EntityType>,
295}
296
297impl NerExtractor {
298 pub fn new(config: NerConfig) -> Self {
300 let compiled_custom: Vec<(Regex, EntityType)> = config
301 .custom_patterns
302 .iter()
303 .filter_map(|(pattern, etype)| Regex::new(pattern).ok().map(|re| (re, etype.clone())))
304 .collect();
305
306 let mut effective_gazetteer: HashMap<String, EntityType> = HashMap::new();
307
308 for loc in default_location_gazetteer() {
309 let key = if config.case_sensitive {
310 loc.to_string()
311 } else {
312 loc.to_lowercase()
313 };
314 effective_gazetteer.insert(key, EntityType::Location);
315 }
316 for org in default_org_gazetteer() {
317 let key = if config.case_sensitive {
318 org.to_string()
319 } else {
320 org.to_lowercase()
321 };
322 effective_gazetteer.insert(key, EntityType::Organization);
323 }
324 for (word, etype) in &config.gazetteer {
325 let key = if config.case_sensitive {
326 word.clone()
327 } else {
328 word.to_lowercase()
329 };
330 effective_gazetteer.insert(key, etype.clone());
331 }
332
333 Self {
334 config,
335 compiled_custom,
336 effective_gazetteer,
337 }
338 }
339
340 pub fn try_new(config: NerConfig) -> Result<Self> {
343 for (pattern, _) in &config.custom_patterns {
344 Regex::new(pattern).map_err(|e| {
345 TextError::InvalidInput(format!(
346 "Custom NER pattern '{}' is invalid: {}",
347 pattern, e
348 ))
349 })?;
350 }
351 Ok(Self::new(config))
352 }
353
354 pub fn add_gazetteer_entry(&mut self, word: &str, entity_type: EntityType) {
356 let key = if self.config.case_sensitive {
357 word.to_string()
358 } else {
359 word.to_lowercase()
360 };
361 self.effective_gazetteer.insert(key, entity_type.clone());
362 self.config.gazetteer.insert(word.to_string(), entity_type);
363 }
364
365 pub fn extract(&self, text: &str) -> Result<Vec<Entity>> {
370 if text.is_empty() {
371 return Ok(Vec::new());
372 }
373
374 let mut candidates: Vec<Entity> = Vec::new();
375
376 self.apply_pattern(text, &RE_EMAIL, EntityType::Email, 1.0, &mut candidates);
378 self.apply_pattern(text, &RE_URL, EntityType::Url, 1.0, &mut candidates);
379 self.apply_pattern(
380 text,
381 &RE_PHONE,
382 EntityType::PhoneNumber,
383 1.0,
384 &mut candidates,
385 );
386 self.apply_pattern(text, &RE_DATE, EntityType::Date, 1.0, &mut candidates);
387 self.apply_pattern(text, &RE_TIME, EntityType::Time, 1.0, &mut candidates);
388 self.apply_pattern(
389 text,
390 &RE_CURRENCY,
391 EntityType::Currency,
392 1.0,
393 &mut candidates,
394 );
395 self.apply_pattern(
396 text,
397 &RE_PERCENTAGE,
398 EntityType::Percentage,
399 1.0,
400 &mut candidates,
401 );
402 self.apply_pattern(text, &RE_NUMBER, EntityType::Number, 1.0, &mut candidates);
403
404 self.extract_persons(text, &mut candidates);
406
407 self.extract_organizations(text, &mut candidates);
409
410 self.extract_gazetteer(text, &mut candidates);
412
413 for (re, etype) in &self.compiled_custom {
415 self.apply_pattern(text, re, etype.clone(), 1.0, &mut candidates);
416 }
417
418 let resolved = resolve_overlaps(candidates);
420 Ok(resolved)
421 }
422
423 fn apply_pattern(
428 &self,
429 text: &str,
430 re: &Regex,
431 etype: EntityType,
432 score: f32,
433 out: &mut Vec<Entity>,
434 ) {
435 for m in re.find_iter(text) {
436 out.push(Entity {
437 text: m.as_str().to_string(),
438 entity_type: etype.clone(),
439 start: m.start(),
440 end: m.end(),
441 score,
442 });
443 }
444 }
445
446 fn extract_persons(&self, text: &str, out: &mut Vec<Entity>) {
447 for cap in RE_PERSON_PREFIX.captures_iter(text) {
448 if let Some(full) = cap.get(0) {
449 out.push(Entity {
450 text: full.as_str().to_string(),
451 entity_type: EntityType::Person,
452 start: full.start(),
453 end: full.end(),
454 score: 0.9,
455 });
456 }
457 }
458 }
459
460 fn extract_organizations(&self, text: &str, out: &mut Vec<Entity>) {
461 for cap in RE_ORG_SUFFIX.captures_iter(text) {
462 if let Some(m) = cap.get(1) {
463 out.push(Entity {
464 text: m.as_str().to_string(),
465 entity_type: EntityType::Organization,
466 start: m.start(),
467 end: m.end(),
468 score: 0.85,
469 });
470 }
471 }
472 }
473
474 fn extract_gazetteer(&self, text: &str, out: &mut Vec<Entity>) {
475 let lookup_text = if self.config.case_sensitive {
476 text.to_string()
477 } else {
478 text.to_lowercase()
479 };
480
481 let mut entries: Vec<(&String, &EntityType)> = self.effective_gazetteer.iter().collect();
482 entries.sort_by_key(|(k, _)| std::cmp::Reverse(k.len()));
483
484 for (entry, etype) in entries {
485 let escaped = regex::escape(entry);
486 let pattern = format!(r"(?i)\b{}\b", escaped);
487 if let Ok(re) = Regex::new(&pattern) {
488 for m in re.find_iter(&lookup_text) {
489 let original_text = &text[m.start()..m.end()];
490 out.push(Entity {
491 text: original_text.to_string(),
492 entity_type: etype.clone(),
493 start: m.start(),
494 end: m.end(),
495 score: 1.0,
496 });
497 }
498 }
499 }
500 }
501}
502
503fn resolve_overlaps(mut entities: Vec<Entity>) -> Vec<Entity> {
508 if entities.is_empty() {
509 return entities;
510 }
511
512 entities.sort_by(|a, b| {
513 a.start
514 .cmp(&b.start)
515 .then_with(|| (b.end - b.start).cmp(&(a.end - a.start)))
516 .then_with(|| {
517 b.score
518 .partial_cmp(&a.score)
519 .unwrap_or(std::cmp::Ordering::Equal)
520 })
521 });
522
523 let mut result: Vec<Entity> = Vec::new();
524 let mut last_end: usize = 0;
525
526 for entity in entities {
527 if entity.start >= last_end {
528 last_end = entity.end;
529 result.push(entity);
530 }
531 }
532
533 result
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an extractor with the default configuration.
    fn default_extractor() -> NerExtractor {
        NerExtractor::new(NerConfig::default())
    }

    /// Reports whether any entity in `entities` has type `wanted`.
    fn has_type(entities: &[Entity], wanted: &EntityType) -> bool {
        entities.iter().any(|e| &e.entity_type == wanted)
    }

    /// Collects the entities of type `wanted`, preserving their order.
    fn of_type<'a>(entities: &'a [Entity], wanted: &EntityType) -> Vec<&'a Entity> {
        entities
            .iter()
            .filter(|e| &e.entity_type == wanted)
            .collect()
    }

    /// Reports whether any entity is a custom one carrying `label`.
    fn has_custom(entities: &[Entity], label: &str) -> bool {
        entities.iter().any(|e| {
            matches!(
                &e.entity_type,
                EntityType::Custom(found) if found == label
            )
        })
    }

    #[test]
    fn test_email_extraction() {
        let entities = default_extractor()
            .extract("Please contact support@example.com for help.")
            .expect("should succeed");
        let emails = of_type(&entities, &EntityType::Email);
        assert!(!emails.is_empty(), "Should detect at least one email");
        assert_eq!(emails[0].text, "support@example.com");
    }

    #[test]
    fn test_multiple_emails() {
        let text = "Send to alice@foo.com and bob@bar.org please.";
        let entities = default_extractor().extract(text).expect("ok");
        assert_eq!(of_type(&entities, &EntityType::Email).len(), 2);
    }

    #[test]
    fn test_url_extraction() {
        let entities = default_extractor()
            .extract("Visit https://www.rust-lang.org for docs.")
            .expect("ok");
        assert!(has_type(&entities, &EntityType::Url));
    }

    #[test]
    fn test_phone_extraction() {
        let entities = default_extractor()
            .extract("Call us at (800) 555-1234.")
            .expect("ok");
        assert!(
            has_type(&entities, &EntityType::PhoneNumber),
            "Should detect phone number"
        );
    }

    #[test]
    fn test_iso_date() {
        let entities = default_extractor()
            .extract("Event on 2025-06-15.")
            .expect("ok");
        assert!(has_type(&entities, &EntityType::Date));
    }

    #[test]
    fn test_written_date() {
        let entities = default_extractor()
            .extract("He was born on March 5, 1990.")
            .expect("ok");
        assert!(has_type(&entities, &EntityType::Date));
    }

    #[test]
    fn test_currency_dollar() {
        let entities = default_extractor()
            .extract("The price is $42.99.")
            .expect("ok");
        assert!(
            has_type(&entities, &EntityType::Currency),
            "Should detect currency"
        );
    }

    #[test]
    fn test_percentage() {
        let entities = default_extractor()
            .extract("Growth rate is 15.3%.")
            .expect("ok");
        assert!(has_type(&entities, &EntityType::Percentage));
    }

    #[test]
    fn test_integer_number() {
        let entities = default_extractor()
            .extract("There are 42 items.")
            .expect("ok");
        assert!(has_type(&entities, &EntityType::Number));
    }

    #[test]
    fn test_person_with_title() {
        let entities = default_extractor()
            .extract("We met Dr. Jane Smith yesterday.")
            .expect("ok");
        assert!(
            has_type(&entities, &EntityType::Person),
            "Should detect person with title"
        );
    }

    #[test]
    fn test_org_with_suffix() {
        let entities = default_extractor()
            .extract("She works at Acme Corp.")
            .expect("ok");
        assert!(
            has_type(&entities, &EntityType::Organization),
            "Should detect organization"
        );
    }

    #[test]
    fn test_gazetteer_location() {
        let entities = default_extractor()
            .extract("The summit was held in Paris.")
            .expect("ok");
        let found_paris = of_type(&entities, &EntityType::Location)
            .into_iter()
            .any(|e| e.text.to_lowercase() == "paris");
        assert!(found_paris, "Should detect Paris as location via gazetteer");
    }

    #[test]
    fn test_gazetteer_organization() {
        let entities = default_extractor()
            .extract("Google announced new products.")
            .expect("ok");
        assert!(
            has_type(&entities, &EntityType::Organization),
            "Should detect Google as organization"
        );
    }

    #[test]
    fn test_custom_pattern() {
        let mut config = NerConfig::default();
        config.custom_patterns.push((
            r"\b[A-Z]{3,5}-\d{4}\b".to_string(),
            EntityType::Custom("TICKET_ID".to_string()),
        ));
        let entities = NerExtractor::new(config)
            .extract("Issue JIRA-1234 is resolved.")
            .expect("ok");
        assert!(has_custom(&entities, "TICKET_ID"));
    }

    #[test]
    fn test_invalid_custom_pattern_returns_error() {
        let mut config = NerConfig::default();
        config
            .custom_patterns
            .push((r"[invalid".to_string(), EntityType::Custom("X".to_string())));
        assert!(NerExtractor::try_new(config).is_err());
    }

    #[test]
    fn test_add_gazetteer_entry() {
        let mut extractor = default_extractor();
        extractor.add_gazetteer_entry("Rustacean", EntityType::Custom("COMMUNITY".to_string()));
        let entities = extractor
            .extract("The Rustacean organized an event.")
            .expect("ok");
        assert!(has_custom(&entities, "COMMUNITY"));
    }

    #[test]
    fn test_entities_non_overlapping() {
        let text = "Email info@test.com, call (555) 123-4567.";
        let entities = default_extractor().extract(text).expect("ok");
        // Each adjacent pair must be disjoint, the later one starting at or
        // after the end of the earlier one.
        for pair in entities.windows(2) {
            assert!(
                pair[1].start >= pair[0].end,
                "Entities should not overlap"
            );
        }
    }

    #[test]
    fn test_empty_text() {
        let entities = default_extractor().extract("").expect("ok");
        assert!(entities.is_empty());
    }

    #[test]
    fn test_email_score_is_one() {
        let entities = default_extractor()
            .extract("user@domain.com")
            .expect("ok");
        let emails = of_type(&entities, &EntityType::Email);
        assert!(!emails.is_empty());
        assert!((emails[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_entity_type_display() {
        assert_eq!(EntityType::Email.to_string(), "EMAIL");
        assert_eq!(EntityType::Person.to_string(), "PERSON");
        assert_eq!(
            EntityType::Custom("FOO".to_string()).to_string(),
            "CUSTOM(FOO)"
        );
    }

    #[test]
    fn test_mixed_entities() {
        let text = "On 2025-01-15 at 10:30, Dr. John Smith emailed john@example.com.";
        let entities = default_extractor().extract(text).expect("ok");
        let types: std::collections::HashSet<String> =
            entities.iter().map(|e| e.entity_type.to_string()).collect();
        assert!(types.contains("DATE"), "missing DATE in {:?}", types);
        assert!(types.contains("EMAIL"), "missing EMAIL in {:?}", types);
    }
}