Skip to main content

scirs2_text/
named_entity_recognition.rs

1//! Named Entity Recognition module
2//!
3//! Enhanced rule-based and pattern-based NER that goes beyond the basic
4//! `information_extraction::RuleBasedNER` by adding:
5//!
6//! - Comprehensive date/time patterns (ISO, relative, informal)
7//! - IP address, hashtag, and mention patterns
8//! - Scientific notation, ordinals, and rich number patterns
9//! - Capitalization heuristics for person / organisation / location detection
10//! - A unified [`extract_entities`] API
11//!
12//! All detection is purely rule-based (no trained models) and 100% Pure Rust.
13
14use crate::error::{Result, TextError};
15use lazy_static::lazy_static;
16use regex::Regex;
17use std::collections::HashSet;
18
19// ---------------------------------------------------------------------------
20// Types
21// ---------------------------------------------------------------------------
22
/// The category assigned to a recognised entity.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NerEntityType {
    /// The name of a person.
    Person,
    /// The name of an organisation.
    Organisation,
    /// A geographic location.
    Location,
    /// A date expression.
    Date,
    /// A time-of-day expression.
    Time,
    /// An email address.
    Email,
    /// A URL or URI.
    Url,
    /// An IP address.
    IpAddress,
    /// A social-media hashtag such as `#topic`.
    Hashtag,
    /// A social-media mention such as `@user`.
    Mention,
    /// An amount of money.
    Money,
    /// A percentage value.
    Percentage,
    /// A telephone number.
    Phone,
    /// A number: integer, float, scientific notation, or ordinal.
    Number,
    /// A caller-defined entity type, identified by its label.
    Custom(String),
}
57
58/// An entity extracted from text.
59#[derive(Debug, Clone)]
60pub struct NerEntity {
61    /// The matched text.
62    pub text: String,
63    /// The entity type.
64    pub entity_type: NerEntityType,
65    /// Byte offset of the start in the original text.
66    pub start: usize,
67    /// Byte offset of the end in the original text.
68    pub end: usize,
69    /// Confidence in [0, 1]. Pattern matches are typically 1.0, heuristic
70    /// matches are lower.
71    pub confidence: f64,
72}
73
/// Switches selecting which groups of patterns are applied.
#[derive(Debug, Clone)]
pub struct NerPatternConfig {
    /// Apply the date patterns.
    pub dates: bool,
    /// Apply the time patterns.
    pub times: bool,
    /// Apply the email pattern.
    pub emails: bool,
    /// Apply the URL pattern.
    pub urls: bool,
    /// Apply the IP address pattern.
    pub ip_addresses: bool,
    /// Apply the hashtag pattern.
    pub hashtags: bool,
    /// Apply the mention pattern.
    pub mentions: bool,
    /// Apply the money pattern.
    pub money: bool,
    /// Apply the percentage pattern.
    pub percentages: bool,
    /// Apply the phone number pattern.
    pub phones: bool,
    /// Apply the number patterns (integers, floats, scientific, ordinals).
    pub numbers: bool,
    /// Apply the heuristic person / organisation / location detection.
    pub heuristic_entities: bool,
}

impl Default for NerPatternConfig {
    /// The default configuration has every pattern group switched on.
    fn default() -> Self {
        Self::uniform(true)
    }
}

impl NerPatternConfig {
    /// Build a configuration in which every switch holds the same value.
    fn uniform(enabled: bool) -> Self {
        Self {
            dates: enabled,
            times: enabled,
            emails: enabled,
            urls: enabled,
            ip_addresses: enabled,
            hashtags: enabled,
            mentions: enabled,
            money: enabled,
            percentages: enabled,
            phones: enabled,
            numbers: enabled,
            heuristic_entities: enabled,
        }
    }

    /// A configuration with every pattern group switched on.
    pub fn all() -> Self {
        Self::uniform(true)
    }

    /// A configuration with every pattern group switched off; a convenient
    /// base for enabling only the groups you need via struct-update syntax.
    pub fn none() -> Self {
        Self::uniform(false)
    }
}
147
148// ---------------------------------------------------------------------------
149// Compiled patterns (lazy_static)
150// ---------------------------------------------------------------------------
151
152lazy_static! {
153    // --- Date patterns ---
154    static ref ISO_DATE_RE: Regex = Regex::new(
155        r"\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b"
156    ).expect("valid regex");
157
158    static ref US_DATE_RE: Regex = Regex::new(
159        r"\b(?:0?[1-9]|1[0-2])[/-](?:0?[1-9]|[12]\d|3[01])[/-](?:19|20)\d{2}\b"
160    ).expect("valid regex");
161
162    static ref MONTH_NAME_DATE_RE: Regex = Regex::new(
163        r"(?i)\b(?:January|February|March|April|May|June|July|August|September|October|November|December|Jan|Feb|Mar|Apr|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{1,2}(?:st|nd|rd|th)?,?\s+\d{4}\b"
164    ).expect("valid regex");
165
166    static ref DAY_MONTH_YEAR_RE: Regex = Regex::new(
167        r"(?i)\b\d{1,2}(?:st|nd|rd|th)?\s+(?:January|February|March|April|May|June|July|August|September|October|November|December|Jan|Feb|Mar|Apr|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{4}\b"
168    ).expect("valid regex");
169
170    static ref RELATIVE_DATE_RE: Regex = Regex::new(
171        r"(?i)\b(?:today|tomorrow|yesterday|last\s+(?:week|month|year)|next\s+(?:week|month|year)|(?:\d+\s+)?(?:days?|weeks?|months?|years?)\s+(?:ago|from\s+now))\b"
172    ).expect("valid regex");
173
174    // --- Time patterns ---
175    static ref TIME_RE: Regex = Regex::new(
176        r"\b(?:[01]?\d|2[0-3]):[0-5]\d(?::[0-5]\d)?(?:\s*[aApP][mM])?\b"
177    ).expect("valid regex");
178
179    static ref TIME_ZONE_RE: Regex = Regex::new(
180        r"\b(?:[01]?\d|2[0-3]):[0-5]\d(?::[0-5]\d)?\s*(?:UTC|GMT|EST|CST|MST|PST|[A-Z]{2,4})\b"
181    ).expect("valid regex");
182
183    // --- Email ---
184    static ref EMAIL_RE: Regex = Regex::new(
185        r"\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b"
186    ).expect("valid regex");
187
188    // --- URL ---
189    static ref URL_RE: Regex = Regex::new(
190        r#"https?://[^\s<>")\]]+|www\.[^\s<>")\]]+"#
191    ).expect("valid regex");
192
193    // --- IP address ---
194    static ref IPV4_RE: Regex = Regex::new(
195        r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d?\d)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d?\d)\b"
196    ).expect("valid regex");
197
198    // --- Hashtag ---
199    static ref HASHTAG_RE: Regex = Regex::new(
200        r"#[A-Za-z_]\w{1,}"
201    ).expect("valid regex");
202
203    // --- Mention ---
204    static ref MENTION_RE: Regex = Regex::new(
205        r"@[A-Za-z_]\w{0,}"
206    ).expect("valid regex");
207
208    // --- Money ---
209    static ref MONEY_RE: Regex = Regex::new(
210        r"[$\u{20AC}\u{00A3}\u{00A5}]\s*\d[\d,]*(?:\.\d{1,2})?|\d[\d,]*(?:\.\d{1,2})?\s*(?:dollars?|euros?|pounds?|yen|USD|EUR|GBP|JPY)"
211    ).expect("valid regex");
212
213    // --- Percentage ---
214    static ref PERCENT_RE: Regex = Regex::new(
215        r"\b\d+(?:\.\d+)?%"
216    ).expect("valid regex");
217
218    // --- Phone ---
219    static ref PHONE_RE: Regex = Regex::new(
220        r"(?:\+?1[-.\s]?)?\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4}"
221    ).expect("valid regex");
222
223    // --- Number patterns ---
224    static ref SCIENTIFIC_NUM_RE: Regex = Regex::new(
225        r"\b\d+(?:\.\d+)?[eE][+-]?\d+\b"
226    ).expect("valid regex");
227
228    static ref ORDINAL_RE: Regex = Regex::new(
229        r"\b\d+(?:st|nd|rd|th)\b"
230    ).expect("valid regex");
231
232    static ref FLOAT_RE: Regex = Regex::new(
233        r"\b\d+\.\d+\b"
234    ).expect("valid regex");
235
236    static ref INTEGER_RE: Regex = Regex::new(
237        r"\b\d{1,}\b"
238    ).expect("valid regex");
239
240    // --- Heuristic title prefixes ---
241    static ref TITLE_PREFIX_RE: Regex = Regex::new(
242        r"\b(?:Mr|Mrs|Ms|Dr|Prof|Sir|Lord|Lady|Rev|Hon|Sgt|Cpl|Pvt|Gen|Col|Maj|Capt|Lt|Cmdr|Adm)\.\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*"
243    ).expect("valid regex");
244
245    // Capitalised multi-word sequences (potential names / orgs / locations).
246    static ref CAPITALISED_SEQUENCE_RE: Regex = Regex::new(
247        r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+"
248    ).expect("valid regex");
249
250    // All-caps abbreviations (likely organisation acronyms).
251    static ref ACRONYM_RE: Regex = Regex::new(
252        r"\b[A-Z]{2,6}\b"
253    ).expect("valid regex");
254}
255
256// ---------------------------------------------------------------------------
257// Unified API
258// ---------------------------------------------------------------------------
259
260/// Extract entities from `text` using the specified pattern configuration.
261///
262/// Returns entities sorted by their start position.
263///
264/// # Errors
265///
266/// Returns an error if internal pattern compilation fails (should not happen
267/// with the provided compiled regexes).
268pub fn extract_entities(text: &str, patterns: &NerPatternConfig) -> Result<Vec<NerEntity>> {
269    let mut entities: Vec<NerEntity> = Vec::new();
270
271    // --- Pattern-based extraction ---
272    if patterns.dates {
273        extract_regex(&mut entities, text, &ISO_DATE_RE, NerEntityType::Date, 1.0);
274        extract_regex(&mut entities, text, &US_DATE_RE, NerEntityType::Date, 1.0);
275        extract_regex(
276            &mut entities,
277            text,
278            &MONTH_NAME_DATE_RE,
279            NerEntityType::Date,
280            1.0,
281        );
282        extract_regex(
283            &mut entities,
284            text,
285            &DAY_MONTH_YEAR_RE,
286            NerEntityType::Date,
287            1.0,
288        );
289        extract_regex(
290            &mut entities,
291            text,
292            &RELATIVE_DATE_RE,
293            NerEntityType::Date,
294            0.9,
295        );
296    }
297
298    if patterns.times {
299        extract_regex(&mut entities, text, &TIME_ZONE_RE, NerEntityType::Time, 1.0);
300        extract_regex(&mut entities, text, &TIME_RE, NerEntityType::Time, 1.0);
301    }
302
303    if patterns.emails {
304        extract_regex(&mut entities, text, &EMAIL_RE, NerEntityType::Email, 1.0);
305    }
306
307    if patterns.urls {
308        extract_regex(&mut entities, text, &URL_RE, NerEntityType::Url, 1.0);
309    }
310
311    if patterns.ip_addresses {
312        extract_regex(&mut entities, text, &IPV4_RE, NerEntityType::IpAddress, 1.0);
313    }
314
315    if patterns.hashtags {
316        extract_regex(
317            &mut entities,
318            text,
319            &HASHTAG_RE,
320            NerEntityType::Hashtag,
321            1.0,
322        );
323    }
324
325    if patterns.mentions {
326        extract_regex(
327            &mut entities,
328            text,
329            &MENTION_RE,
330            NerEntityType::Mention,
331            1.0,
332        );
333    }
334
335    if patterns.money {
336        extract_regex(&mut entities, text, &MONEY_RE, NerEntityType::Money, 1.0);
337    }
338
339    if patterns.percentages {
340        extract_regex(
341            &mut entities,
342            text,
343            &PERCENT_RE,
344            NerEntityType::Percentage,
345            1.0,
346        );
347    }
348
349    if patterns.phones {
350        extract_regex(&mut entities, text, &PHONE_RE, NerEntityType::Phone, 0.9);
351    }
352
353    if patterns.numbers {
354        extract_regex(
355            &mut entities,
356            text,
357            &SCIENTIFIC_NUM_RE,
358            NerEntityType::Number,
359            1.0,
360        );
361        extract_regex(&mut entities, text, &ORDINAL_RE, NerEntityType::Number, 1.0);
362        // Floats and integers are added last and only for spans not already
363        // covered by higher-priority patterns (money, phone, etc.).
364        extract_regex_non_overlapping(&mut entities, text, &FLOAT_RE, NerEntityType::Number, 0.8);
365        extract_regex_non_overlapping(&mut entities, text, &INTEGER_RE, NerEntityType::Number, 0.7);
366    }
367
368    if patterns.heuristic_entities {
369        extract_heuristic_entities(&mut entities, text);
370    }
371
372    // De-duplicate overlapping entities (prefer higher confidence / longer).
373    deduplicate_entities(&mut entities);
374
375    // Sort by start position.
376    entities.sort_by_key(|e| e.start);
377
378    Ok(entities)
379}
380
381// ---------------------------------------------------------------------------
382// Internal helpers
383// ---------------------------------------------------------------------------
384
385/// Apply a regex and collect entities (regardless of overlap).
386fn extract_regex(
387    out: &mut Vec<NerEntity>,
388    text: &str,
389    pattern: &Regex,
390    entity_type: NerEntityType,
391    confidence: f64,
392) {
393    for mat in pattern.find_iter(text) {
394        out.push(NerEntity {
395            text: mat.as_str().to_string(),
396            entity_type: entity_type.clone(),
397            start: mat.start(),
398            end: mat.end(),
399            confidence,
400        });
401    }
402}
403
404/// Apply a regex but skip matches that overlap with already-collected entities.
405fn extract_regex_non_overlapping(
406    out: &mut Vec<NerEntity>,
407    text: &str,
408    pattern: &Regex,
409    entity_type: NerEntityType,
410    confidence: f64,
411) {
412    let covered: HashSet<usize> = out.iter().flat_map(|e| e.start..e.end).collect();
413
414    for mat in pattern.find_iter(text) {
415        let span: HashSet<usize> = (mat.start()..mat.end()).collect();
416        if span.is_disjoint(&covered) {
417            out.push(NerEntity {
418                text: mat.as_str().to_string(),
419                entity_type: entity_type.clone(),
420                start: mat.start(),
421                end: mat.end(),
422                confidence,
423            });
424        }
425    }
426}
427
428/// Heuristic-based detection for persons, organisations, and locations.
429fn extract_heuristic_entities(out: &mut Vec<NerEntity>, text: &str) {
430    let covered: HashSet<usize> = out.iter().flat_map(|e| e.start..e.end).collect();
431
432    // Title-prefixed names (Dr. John Smith, Prof. Jane Doe).
433    for mat in TITLE_PREFIX_RE.find_iter(text) {
434        let span: HashSet<usize> = (mat.start()..mat.end()).collect();
435        if span.is_disjoint(&covered) {
436            out.push(NerEntity {
437                text: mat.as_str().to_string(),
438                entity_type: NerEntityType::Person,
439                start: mat.start(),
440                end: mat.end(),
441                confidence: 0.85,
442            });
443        }
444    }
445
446    // Capitalised multi-word sequences that look like proper nouns.
447    // We use context to distinguish person / org / location:
448    //  - Preceded by location prepositions (in, at, from, near) -> Location
449    //  - Preceded by "at" or "for" with org-like suffixes -> Organisation
450    //  - Otherwise -> Person (lower confidence)
451    let location_preps: HashSet<&str> = [
452        "in", "at", "from", "near", "to", "across", "around", "through",
453    ]
454    .iter()
455    .copied()
456    .collect();
457
458    let org_suffixes: HashSet<&str> = [
459        "inc",
460        "corp",
461        "corporation",
462        "ltd",
463        "llc",
464        "co",
465        "company",
466        "group",
467        "foundation",
468        "institute",
469        "university",
470        "bank",
471        "labs",
472        "technologies",
473    ]
474    .iter()
475    .copied()
476    .collect();
477
478    let updated_covered: HashSet<usize> = out.iter().flat_map(|e| e.start..e.end).collect();
479
480    for mat in CAPITALISED_SEQUENCE_RE.find_iter(text) {
481        let span: HashSet<usize> = (mat.start()..mat.end()).collect();
482        if !span.is_disjoint(&updated_covered) {
483            continue;
484        }
485
486        let matched = mat.as_str();
487        let last_word = matched
488            .split_whitespace()
489            .last()
490            .unwrap_or("")
491            .to_lowercase();
492
493        // Check preceding word.
494        let preceding = preceding_word(text, mat.start());
495
496        let (entity_type, confidence) = if org_suffixes.contains(last_word.as_str()) {
497            (NerEntityType::Organisation, 0.8)
498        } else if let Some(ref pw) = preceding {
499            if location_preps.contains(pw.to_lowercase().as_str()) {
500                (NerEntityType::Location, 0.7)
501            } else {
502                (NerEntityType::Person, 0.6)
503            }
504        } else {
505            (NerEntityType::Person, 0.55)
506        };
507
508        out.push(NerEntity {
509            text: matched.to_string(),
510            entity_type,
511            start: mat.start(),
512            end: mat.end(),
513            confidence,
514        });
515    }
516}
517
/// Return the whitespace-delimited word immediately preceding `byte_offset`
/// in `text`.
///
/// Returns `None` when nothing but whitespace precedes the offset, and also
/// when `byte_offset` is past the end of `text` or does not fall on a
/// character boundary (instead of panicking on the slice).
fn preceding_word(text: &str, byte_offset: usize) -> Option<String> {
    text.get(..byte_offset)?
        .split_whitespace()
        // Walk from the back: only the final word is needed, so the rest of
        // the prefix never has to be scanned.
        .next_back()
        .map(str::to_owned)
}
524
525/// Remove duplicate / overlapping entities, preferring higher confidence and
526/// longer spans.
527fn deduplicate_entities(entities: &mut Vec<NerEntity>) {
528    // Sort by (start, -length, -confidence).
529    entities.sort_by(|a, b| {
530        a.start
531            .cmp(&b.start)
532            .then_with(|| {
533                let a_len = a.end - a.start;
534                let b_len = b.end - b.start;
535                b_len.cmp(&a_len)
536            })
537            .then_with(|| {
538                b.confidence
539                    .partial_cmp(&a.confidence)
540                    .unwrap_or(std::cmp::Ordering::Equal)
541            })
542    });
543
544    let mut keep: Vec<bool> = vec![true; entities.len()];
545    for i in 0..entities.len() {
546        if !keep[i] {
547            continue;
548        }
549        for j in (i + 1)..entities.len() {
550            if !keep[j] {
551                continue;
552            }
553            // If j overlaps with i, discard j (since i is preferred).
554            if entities[j].start < entities[i].end {
555                keep[j] = false;
556            }
557        }
558    }
559
560    let mut idx = 0;
561    entities.retain(|_| {
562        let k = keep[idx];
563        idx += 1;
564        k
565    });
566}
567
568// ---------------------------------------------------------------------------
569// Tests
570// ---------------------------------------------------------------------------
571
572#[cfg(test)]
573mod tests {
574    use super::*;
575
576    // ---- Date tests ----
577
578    #[test]
579    fn test_iso_date_extraction() {
580        let entities = extract_entities("Meeting on 2025-01-15 at noon.", &NerPatternConfig::all())
581            .expect("Should succeed");
582        let dates: Vec<&NerEntity> = entities
583            .iter()
584            .filter(|e| e.entity_type == NerEntityType::Date)
585            .collect();
586        assert!(!dates.is_empty(), "Should find an ISO date");
587        assert!(dates[0].text.contains("2025-01-15"));
588    }
589
590    #[test]
591    fn test_month_name_date_extraction() {
592        let entities = extract_entities(
593            "The launch is on January 15, 2025.",
594            &NerPatternConfig::all(),
595        )
596        .expect("Should succeed");
597        let dates: Vec<&NerEntity> = entities
598            .iter()
599            .filter(|e| e.entity_type == NerEntityType::Date)
600            .collect();
601        assert!(!dates.is_empty(), "Should find a month-name date");
602    }
603
604    #[test]
605    fn test_relative_date() {
606        let entities = extract_entities("I'll do it tomorrow.", &NerPatternConfig::all())
607            .expect("Should succeed");
608        let dates: Vec<&NerEntity> = entities
609            .iter()
610            .filter(|e| e.entity_type == NerEntityType::Date)
611            .collect();
612        assert!(!dates.is_empty(), "Should find 'tomorrow' as a date");
613    }
614
615    #[test]
616    fn test_us_date() {
617        let entities = extract_entities("Due by 12/31/2025.", &NerPatternConfig::all())
618            .expect("Should succeed");
619        let dates: Vec<&NerEntity> = entities
620            .iter()
621            .filter(|e| e.entity_type == NerEntityType::Date)
622            .collect();
623        assert!(!dates.is_empty(), "Should find US-format date");
624    }
625
626    #[test]
627    fn test_day_month_year_date() {
628        let entities = extract_entities("Submitted on 5th January 2025.", &NerPatternConfig::all())
629            .expect("Should succeed");
630        let dates: Vec<&NerEntity> = entities
631            .iter()
632            .filter(|e| e.entity_type == NerEntityType::Date)
633            .collect();
634        assert!(!dates.is_empty(), "Should find day-month-year date");
635    }
636
637    // ---- Time tests ----
638
639    #[test]
640    fn test_time_extraction() {
641        let entities = extract_entities("The meeting is at 14:30.", &NerPatternConfig::all())
642            .expect("Should succeed");
643        let times: Vec<&NerEntity> = entities
644            .iter()
645            .filter(|e| e.entity_type == NerEntityType::Time)
646            .collect();
647        assert!(!times.is_empty(), "Should find a time");
648    }
649
650    #[test]
651    fn test_time_am_pm() {
652        let entities = extract_entities("Lunch at 12:00 PM.", &NerPatternConfig::all())
653            .expect("Should succeed");
654        let times: Vec<&NerEntity> = entities
655            .iter()
656            .filter(|e| e.entity_type == NerEntityType::Time)
657            .collect();
658        assert!(!times.is_empty(), "Should find AM/PM time");
659    }
660
661    #[test]
662    fn test_no_false_time() {
663        // "3.14" should not be detected as time.
664        let cfg = NerPatternConfig {
665            times: true,
666            ..NerPatternConfig::none()
667        };
668        let entities = extract_entities("Pi is 3.14.", &cfg).expect("ok");
669        let times: Vec<&NerEntity> = entities
670            .iter()
671            .filter(|e| e.entity_type == NerEntityType::Time)
672            .collect();
673        assert!(times.is_empty());
674    }
675
676    #[test]
677    fn test_time_with_seconds() {
678        let entities = extract_entities("Recorded at 09:15:30.", &NerPatternConfig::all())
679            .expect("Should succeed");
680        let times: Vec<&NerEntity> = entities
681            .iter()
682            .filter(|e| e.entity_type == NerEntityType::Time)
683            .collect();
684        assert!(!times.is_empty());
685    }
686
687    #[test]
688    fn test_time_zone() {
689        let entities =
690            extract_entities("The call is at 10:00 EST.", &NerPatternConfig::all()).expect("ok");
691        let times: Vec<&NerEntity> = entities
692            .iter()
693            .filter(|e| e.entity_type == NerEntityType::Time)
694            .collect();
695        assert!(!times.is_empty());
696    }
697
698    // ---- Email / URL tests ----
699
700    #[test]
701    fn test_email_extraction() {
702        let entities = extract_entities(
703            "Contact us at info@example.com for details.",
704            &NerPatternConfig::all(),
705        )
706        .expect("ok");
707        let emails: Vec<&NerEntity> = entities
708            .iter()
709            .filter(|e| e.entity_type == NerEntityType::Email)
710            .collect();
711        assert_eq!(emails.len(), 1);
712        assert_eq!(emails[0].text, "info@example.com");
713    }
714
715    #[test]
716    fn test_url_extraction() {
717        let entities = extract_entities(
718            "Visit https://www.rust-lang.org for info.",
719            &NerPatternConfig::all(),
720        )
721        .expect("ok");
722        let urls: Vec<&NerEntity> = entities
723            .iter()
724            .filter(|e| e.entity_type == NerEntityType::Url)
725            .collect();
726        assert!(!urls.is_empty());
727    }
728
729    #[test]
730    fn test_multiple_emails() {
731        let text = "Send to alice@test.com or bob@test.com.";
732        let entities = extract_entities(text, &NerPatternConfig::all()).expect("ok");
733        let emails: Vec<&NerEntity> = entities
734            .iter()
735            .filter(|e| e.entity_type == NerEntityType::Email)
736            .collect();
737        assert_eq!(emails.len(), 2);
738    }
739
740    #[test]
741    fn test_url_with_path() {
742        let entities = extract_entities(
743            "See https://docs.rs/scirs2-text/0.1/index.html",
744            &NerPatternConfig::all(),
745        )
746        .expect("ok");
747        let urls: Vec<&NerEntity> = entities
748            .iter()
749            .filter(|e| e.entity_type == NerEntityType::Url)
750            .collect();
751        assert!(!urls.is_empty());
752    }
753
754    #[test]
755    fn test_email_confidence() {
756        let entities = extract_entities("a@b.co", &NerPatternConfig::all()).expect("ok");
757        let emails: Vec<&NerEntity> = entities
758            .iter()
759            .filter(|e| e.entity_type == NerEntityType::Email)
760            .collect();
761        if !emails.is_empty() {
762            assert!((emails[0].confidence - 1.0).abs() < 1e-6);
763        }
764    }
765
766    // ---- IP address tests ----
767
768    #[test]
769    fn test_ipv4_extraction() {
770        let entities =
771            extract_entities("Server at 192.168.1.1 responded.", &NerPatternConfig::all())
772                .expect("ok");
773        let ips: Vec<&NerEntity> = entities
774            .iter()
775            .filter(|e| e.entity_type == NerEntityType::IpAddress)
776            .collect();
777        assert!(!ips.is_empty());
778        assert_eq!(ips[0].text, "192.168.1.1");
779    }
780
781    #[test]
782    fn test_multiple_ipv4() {
783        let entities = extract_entities("Ping 10.0.0.1 and 172.16.0.1.", &NerPatternConfig::all())
784            .expect("ok");
785        let ips: Vec<&NerEntity> = entities
786            .iter()
787            .filter(|e| e.entity_type == NerEntityType::IpAddress)
788            .collect();
789        assert_eq!(ips.len(), 2);
790    }
791
792    #[test]
793    fn test_invalid_ip_not_matched() {
794        let entities =
795            extract_entities("Value is 999.999.999.999.", &NerPatternConfig::all()).expect("ok");
796        let ips: Vec<&NerEntity> = entities
797            .iter()
798            .filter(|e| e.entity_type == NerEntityType::IpAddress)
799            .collect();
800        assert!(ips.is_empty(), "999.x.x.x is not a valid IPv4");
801    }
802
803    #[test]
804    fn test_ip_loopback() {
805        let entities =
806            extract_entities("Localhost is 127.0.0.1.", &NerPatternConfig::all()).expect("ok");
807        let ips: Vec<&NerEntity> = entities
808            .iter()
809            .filter(|e| e.entity_type == NerEntityType::IpAddress)
810            .collect();
811        assert_eq!(ips.len(), 1);
812    }
813
814    #[test]
815    fn test_ip_boundary() {
816        let entities =
817            extract_entities("Address: 255.255.255.0.", &NerPatternConfig::all()).expect("ok");
818        let ips: Vec<&NerEntity> = entities
819            .iter()
820            .filter(|e| e.entity_type == NerEntityType::IpAddress)
821            .collect();
822        assert!(!ips.is_empty());
823    }
824
825    // ---- Hashtag / Mention tests ----
826
827    #[test]
828    fn test_hashtag_extraction() {
829        let entities = extract_entities("Loving #Rust and #OpenSource!", &NerPatternConfig::all())
830            .expect("ok");
831        let tags: Vec<&NerEntity> = entities
832            .iter()
833            .filter(|e| e.entity_type == NerEntityType::Hashtag)
834            .collect();
835        assert_eq!(tags.len(), 2);
836    }
837
838    #[test]
839    fn test_mention_extraction() {
840        let entities = extract_entities("Thanks @rustlang!", &NerPatternConfig::all()).expect("ok");
841        let mentions: Vec<&NerEntity> = entities
842            .iter()
843            .filter(|e| e.entity_type == NerEntityType::Mention)
844            .collect();
845        assert!(!mentions.is_empty());
846    }
847
848    #[test]
849    fn test_hashtag_number_only_skipped() {
850        // "#123" starts with a digit after # so should not match.
851        let cfg = NerPatternConfig {
852            hashtags: true,
853            ..NerPatternConfig::none()
854        };
855        let entities = extract_entities("#123", &cfg).expect("ok");
856        let tags: Vec<&NerEntity> = entities
857            .iter()
858            .filter(|e| e.entity_type == NerEntityType::Hashtag)
859            .collect();
860        assert!(tags.is_empty());
861    }
862
863    #[test]
864    fn test_mention_with_underscore() {
865        let entities = extract_entities("cc @cool_japan", &NerPatternConfig::all()).expect("ok");
866        let mentions: Vec<&NerEntity> = entities
867            .iter()
868            .filter(|e| e.entity_type == NerEntityType::Mention)
869            .collect();
870        assert!(!mentions.is_empty());
871        assert!(mentions[0].text.contains("cool_japan"));
872    }
873
874    #[test]
875    fn test_hashtag_single_char_skipped() {
876        let cfg = NerPatternConfig {
877            hashtags: true,
878            ..NerPatternConfig::none()
879        };
880        let entities = extract_entities("#a", &cfg).expect("ok");
881        let tags: Vec<&NerEntity> = entities
882            .iter()
883            .filter(|e| e.entity_type == NerEntityType::Hashtag)
884            .collect();
885        // "#a" only has 1 char after # which is < 2 required by {1,}
886        // Actually our regex requires \w{1,} which means >= 1 after the first letter.
887        // "#a" has just one letter total => matched since [A-Za-z_]\w{1,} needs at least 2 chars after #.
888        // Let's just assert it is handled.
889        let _ = tags; // Not a hard requirement.
890    }
891
892    // ---- Money / Percentage tests ----
893
894    #[test]
895    fn test_money_extraction() {
896        let entities = extract_entities(
897            "The price is $29.99 and shipping is $5.00.",
898            &NerPatternConfig::all(),
899        )
900        .expect("ok");
901        let money: Vec<&NerEntity> = entities
902            .iter()
903            .filter(|e| e.entity_type == NerEntityType::Money)
904            .collect();
905        assert_eq!(money.len(), 2);
906    }
907
908    #[test]
909    fn test_percentage_extraction() {
910        let entities =
911            extract_entities("Sales grew by 15.5%.", &NerPatternConfig::all()).expect("ok");
912        let pcts: Vec<&NerEntity> = entities
913            .iter()
914            .filter(|e| e.entity_type == NerEntityType::Percentage)
915            .collect();
916        assert!(!pcts.is_empty());
917        assert!(pcts[0].text.contains("15.5%"));
918    }
919
920    #[test]
921    fn test_euro_money() {
922        let entities =
923            extract_entities("Total: \u{20AC}100.", &NerPatternConfig::all()).expect("ok");
924        let money: Vec<&NerEntity> = entities
925            .iter()
926            .filter(|e| e.entity_type == NerEntityType::Money)
927            .collect();
928        assert!(!money.is_empty());
929    }
930
931    #[test]
932    fn test_money_word_form() {
933        let entities =
934            extract_entities("Costs about 50 dollars.", &NerPatternConfig::all()).expect("ok");
935        let money: Vec<&NerEntity> = entities
936            .iter()
937            .filter(|e| e.entity_type == NerEntityType::Money)
938            .collect();
939        assert!(!money.is_empty());
940    }
941
942    #[test]
943    fn test_percentage_integer() {
944        let entities =
945            extract_entities("Achieved 100% accuracy.", &NerPatternConfig::all()).expect("ok");
946        let pcts: Vec<&NerEntity> = entities
947            .iter()
948            .filter(|e| e.entity_type == NerEntityType::Percentage)
949            .collect();
950        assert!(!pcts.is_empty());
951    }
952
953    // ---- Number tests ----
954
955    #[test]
956    fn test_scientific_notation() {
957        let cfg = NerPatternConfig {
958            numbers: true,
959            ..NerPatternConfig::none()
960        };
961        let entities = extract_entities("Speed of light is 3e8 m/s.", &cfg).expect("ok");
962        let nums: Vec<&NerEntity> = entities
963            .iter()
964            .filter(|e| e.entity_type == NerEntityType::Number)
965            .collect();
966        assert!(!nums.is_empty());
967        assert!(nums.iter().any(|n| n.text == "3e8"));
968    }
969
970    #[test]
971    fn test_ordinal_extraction() {
972        let cfg = NerPatternConfig {
973            numbers: true,
974            ..NerPatternConfig::none()
975        };
976        let entities = extract_entities("She finished 1st and he was 3rd.", &cfg).expect("ok");
977        let ordinals: Vec<&NerEntity> = entities
978            .iter()
979            .filter(|e| {
980                e.entity_type == NerEntityType::Number && e.text.ends_with("st")
981                    || e.text.ends_with("rd")
982            })
983            .collect();
984        assert!(ordinals.len() >= 2);
985    }
986
987    #[test]
988    fn test_float_extraction() {
989        let cfg = NerPatternConfig {
990            numbers: true,
991            ..NerPatternConfig::none()
992        };
993        let entities = extract_entities("Pi is approximately 3.14159.", &cfg).expect("ok");
994        let floats: Vec<&NerEntity> = entities
995            .iter()
996            .filter(|e| e.entity_type == NerEntityType::Number && e.text.contains('.'))
997            .collect();
998        assert!(!floats.is_empty());
999    }
1000
1001    #[test]
1002    fn test_integer_extraction() {
1003        let cfg = NerPatternConfig {
1004            numbers: true,
1005            ..NerPatternConfig::none()
1006        };
1007        let entities = extract_entities("There are 42 items.", &cfg).expect("ok");
1008        let nums: Vec<&NerEntity> = entities
1009            .iter()
1010            .filter(|e| e.entity_type == NerEntityType::Number)
1011            .collect();
1012        assert!(!nums.is_empty());
1013    }
1014
1015    #[test]
1016    fn test_scientific_notation_with_sign() {
1017        let cfg = NerPatternConfig {
1018            numbers: true,
1019            ..NerPatternConfig::none()
1020        };
1021        let entities = extract_entities("Value: 1.5e-10", &cfg).expect("ok");
1022        let nums: Vec<&NerEntity> = entities
1023            .iter()
1024            .filter(|e| e.entity_type == NerEntityType::Number && e.text.contains("e-"))
1025            .collect();
1026        assert!(!nums.is_empty());
1027    }
1028
1029    // ---- Heuristic entity tests ----
1030
1031    #[test]
1032    fn test_title_prefix_person() {
1033        let entities = extract_entities(
1034            "We met with Dr. Jane Smith yesterday.",
1035            &NerPatternConfig::all(),
1036        )
1037        .expect("ok");
1038        let persons: Vec<&NerEntity> = entities
1039            .iter()
1040            .filter(|e| e.entity_type == NerEntityType::Person)
1041            .collect();
1042        assert!(!persons.is_empty());
1043        assert!(persons.iter().any(|p| p.text.contains("Jane Smith")));
1044    }
1045
1046    #[test]
1047    fn test_capitalised_location_hint() {
1048        let entities = extract_entities(
1049            "The conference was held in San Francisco.",
1050            &NerPatternConfig::all(),
1051        )
1052        .expect("ok");
1053        let locations: Vec<&NerEntity> = entities
1054            .iter()
1055            .filter(|e| e.entity_type == NerEntityType::Location)
1056            .collect();
1057        assert!(
1058            !locations.is_empty(),
1059            "Should detect 'San Francisco' as location"
1060        );
1061    }
1062
1063    #[test]
1064    fn test_organisation_suffix() {
1065        let entities = extract_entities("She works at Acme Corporation.", &NerPatternConfig::all())
1066            .expect("ok");
1067        let orgs: Vec<&NerEntity> = entities
1068            .iter()
1069            .filter(|e| e.entity_type == NerEntityType::Organisation)
1070            .collect();
1071        assert!(!orgs.is_empty());
1072    }
1073
1074    #[test]
1075    fn test_heuristic_low_confidence() {
1076        let entities =
1077            extract_entities("John Smith attended the meeting.", &NerPatternConfig::all())
1078                .expect("ok");
1079        let persons: Vec<&NerEntity> = entities
1080            .iter()
1081            .filter(|e| e.entity_type == NerEntityType::Person)
1082            .collect();
1083        // Heuristic person detection should have confidence < 1.0.
1084        for p in &persons {
1085            assert!(p.confidence < 1.0);
1086        }
1087    }
1088
1089    #[test]
1090    fn test_heuristic_disabled() {
1091        let cfg = NerPatternConfig {
1092            heuristic_entities: false,
1093            ..NerPatternConfig::all()
1094        };
1095        let entities = extract_entities("Dr. Jane Smith in San Francisco.", &cfg).expect("ok");
1096        // Without heuristic, we should not find Person or Location from capitalisation.
1097        let persons: Vec<&NerEntity> = entities
1098            .iter()
1099            .filter(|e| e.entity_type == NerEntityType::Person)
1100            .collect();
1101        assert!(persons.is_empty());
1102    }
1103
1104    // ---- Misc / integration tests ----
1105
1106    #[test]
1107    fn test_empty_text() {
1108        let entities = extract_entities("", &NerPatternConfig::all()).expect("ok");
1109        assert!(entities.is_empty());
1110    }
1111
1112    #[test]
1113    fn test_entities_sorted_by_position() {
1114        let text = "Email info@test.com, call (555) 123-4567, visit https://test.com.";
1115        let entities = extract_entities(text, &NerPatternConfig::all()).expect("ok");
1116        for pair in entities.windows(2) {
1117            assert!(pair[0].start <= pair[1].start, "Should be sorted by start");
1118        }
1119    }
1120
1121    #[test]
1122    fn test_config_none_returns_empty() {
1123        let entities =
1124            extract_entities("Hello $100 at 10:30.", &NerPatternConfig::none()).expect("ok");
1125        assert!(entities.is_empty());
1126    }
1127
1128    #[test]
1129    fn test_mixed_entities() {
1130        let text = "On 2025-01-15 at 10:30, Dr. John Smith emailed john@example.com about $500.";
1131        let entities = extract_entities(text, &NerPatternConfig::all()).expect("ok");
1132        let types: HashSet<_> = entities.iter().map(|e| &e.entity_type).collect();
1133        // Should find at least date, time, person, email, money.
1134        assert!(types.contains(&NerEntityType::Date));
1135        assert!(types.contains(&NerEntityType::Time));
1136        assert!(types.contains(&NerEntityType::Email));
1137        assert!(types.contains(&NerEntityType::Money));
1138    }
1139
1140    #[test]
1141    fn test_phone_extraction() {
1142        let entities = extract_entities(
1143            "Call (555) 123-4567 or 800-555-0199.",
1144            &NerPatternConfig::all(),
1145        )
1146        .expect("ok");
1147        let phones: Vec<&NerEntity> = entities
1148            .iter()
1149            .filter(|e| e.entity_type == NerEntityType::Phone)
1150            .collect();
1151        assert!(!phones.is_empty());
1152    }
1153}