1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::OnceLock;
9
10use anyhow::{Context, Result};
11use candle_core::{DType, Device, Tensor};
12use candle_nn::{Linear, Module, VarBuilder};
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use regex::Regex;
15use serde::Deserialize;
16use unicode_normalization::UnicodeNormalization;
17
18use crate::entity_type::EntityType;
19use crate::paths::AppPaths;
20use crate::storage::entities::{NewEntity, NewRelationship};
21
/// Hugging Face repo of the multilingual NER model downloaded on first run.
const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
/// Maximum tokens per inference window (BERT positional-embedding limit).
const MAX_SEQ_LEN: usize = 512;
/// Step between consecutive sliding windows; half of MAX_SEQ_LEN so each
/// token is seen by at least one window away from a window edge.
const STRIDE: usize = 256;
/// Hard cap on entities kept per memory after merge/dedup.
const MAX_ENTS: usize = 30;
// Only referenced by the test-only `build_relationships` helper below.
#[cfg(test)]
const TOP_K_RELATIONS: usize = 5;
/// Generic relation label used for all co-occurrence relationships.
const DEFAULT_RELATION: &str = "mentions";
/// Minimum entity-name length in bytes (NOTE: bytes, not chars).
const MIN_ENTITY_CHARS: usize = 2;
32
// Lazily-compiled regexes, each initialized once via its accessor below.
static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
static REGEX_SECTION_MARKER: OnceLock<Regex> = OnceLock::new();
static REGEX_BRAND_CAMEL: OnceLock<Regex> = OnceLock::new();
41
/// All-caps tokens that must NOT become entities: a mix of English/Portuguese
/// instruction words, common tech acronyms and status markers. Roughly
/// alphabetical (accented entries may sit out of strict byte order).
const ALL_CAPS_STOPWORDS: &[&str] = &[
    "ACEITE",
    "ACID",
    "ACK",
    "ACL",
    "ACRESCENTADO",
    "ADAPTER",
    "ADICIONADA",
    "ADICIONADAS",
    "ADICIONADO",
    "ADICIONADOS",
    "ADICIONAR",
    "AGENTS",
    "AINDA",
    "ALL",
    "ALTA",
    "ALWAYS",
    "APENAS",
    "API",
    "ARTEFATOS",
    "ATIVA",
    "ATIVO",
    "BAIXA",
    "BANCO",
    "BLOQUEAR",
    "BORDA",
    "BUG",
    "CAPÍTULO",
    "CASO",
    "CEO",
    "CHECKLIST",
    "CLARO",
    "CLI",
    "COMPLETED",
    "CONFIRMADO",
    "CONFIRMARAM",
    "CONFIRME",
    "CONFIRMEI",
    "CONFIRMOU",
    "CONTRATO",
    "CRIE",
    "CRÍTICO",
    "CRITICAL",
    "CSV",
    "DDL",
    "DEFAULT",
    "DEFINIR",
    "DEPARTMENT",
    "DESC",
    "DEVE",
    "DEVEMOS",
    "DISCO",
    "DONE",
    "DSL",
    "DTO",
    "EFEITO",
    "ENTRADA",
    "EPERM",
    "ERROR",
    "ESCREVA",
    "ESCRITA",
    "ESRCH",
    "ESSA",
    "ESSE",
    "ESSENCIAL",
    "ESTA",
    "ESTADO",
    "ESTE",
    "ETAPA",
    "EVITAR",
    "EXEMPLO",
    "EXPANDIR",
    "EXPOR",
    "FALHA",
    "FASE",
    "FATO",
    "FIFO",
    "FIXED",
    "FIXME",
    "FLUXO",
    "FONTES",
    "FORBIDDEN",
    "FUNCIONA",
    "HACK",
    "HEARTBEAT",
    "HTTP",
    "HTTPS",
    "INATIVO",
    "JAMAIS",
    "JSON",
    "JWT",
    "LEITURA",
    "LLM",
    "MESMO",
    "METADADOS",
    "MUST",
    "NEGUE",
    "NEVER",
    "NOTE",
    "NUNCA",
    "OBRIGATORIA",
    "OBRIGATÓRIO",
    "PADRÃO",
    "PASSIVA",
    "PASSO",
    "PENDING",
    "PLAN",
    "PODEMOS",
    "PONTEIROS",
    "PROIBIDO",
    "PROJETO",
    "RECUSE",
    "REGRA",
    "REGRAS",
    "REQUIRED",
    "REQUISITO",
    "REST",
    "SEÇÃO",
    "SEMPRE",
    "SHALL",
    "SHOULD",
    "SOMENTE",
    "SOUL",
    "TODAS",
    "TODO",
    "TODOS",
    "TOKEN",
    "TOOLS",
    "TSV",
    "UI",
    "URL",
    "USAR",
    "VALIDAR",
    "VAMOS",
    "VOCÊ",
    "WARNING",
    "XML",
    "YAML",
];

/// HTTP verbs, filtered separately so they never surface as entities.
const HTTP_METHODS: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE",
];
204
205fn is_filtered_all_caps(token: &str) -> bool {
206 let is_identifier = token.contains('_');
208 if is_identifier {
209 return false;
210 }
211 ALL_CAPS_STOPWORDS.contains(&token) || HTTP_METHODS.contains(&token)
212}
213
214fn regex_email() -> &'static Regex {
215 REGEX_EMAIL.get_or_init(|| {
217 Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
218 .expect("compile-time validated email regex literal")
219 })
220}
221
222fn regex_url() -> &'static Regex {
223 REGEX_URL.get_or_init(|| {
225 Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#)
226 .expect("compile-time validated URL regex literal")
227 })
228}
229
230fn regex_uuid() -> &'static Regex {
231 REGEX_UUID.get_or_init(|| {
233 Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
234 .expect("compile-time validated UUID regex literal")
235 })
236}
237
238fn regex_all_caps() -> &'static Regex {
239 REGEX_ALL_CAPS.get_or_init(|| {
240 Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b")
241 .expect("compile-time validated all-caps regex literal")
242 })
243}
244
245fn regex_section_marker() -> &'static Regex {
246 REGEX_SECTION_MARKER.get_or_init(|| {
247 Regex::new("\\b(?:Etapa|Fase|Passo|Camada|Se\u{00e7}\u{00e3}o|Cap\u{00ed}tulo)\\s+\\d+\\b")
254 .expect("compile-time validated section marker regex literal")
255 })
256}
257
258fn regex_brand_camel() -> &'static Regex {
259 REGEX_BRAND_CAMEL.get_or_init(|| {
260 Regex::new(r"\b[A-Z][a-z]+[A-Z][A-Za-z]+\b")
263 .expect("compile-time validated CamelCase brand regex literal")
264 })
265}
266
/// An entity candidate produced by the regex prefilter or the NER model,
/// before conversion into a storage `NewEntity`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    /// Surface form as found in the text (trimmed).
    pub name: String,
    /// Coarse type assigned by the extraction heuristics.
    pub entity_type: EntityType,
}
272
/// A URL found in the body, kept separate from entities.
#[derive(Debug, Clone)]
pub struct ExtractedUrl {
    /// URL with trailing punctuation stripped.
    pub url: String,
    /// Byte offset of the first occurrence of the raw match in the body.
    pub offset: usize,
}
280
/// Output of a full extraction pass over one memory body.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    /// Deduplicated entities, capped at MAX_ENTS.
    pub entities: Vec<NewEntity>,
    /// Pairwise "mentions" relationships between co-occurring entities.
    pub relationships: Vec<NewRelationship>,
    /// True when the relationship cap was hit and some pairs were dropped.
    pub relationships_truncated: bool,
    /// Which pipeline produced this result ("bert+regex-batch" or "regex-only").
    pub extraction_method: String,
    /// URLs found in the body (extracted independently of entities).
    pub urls: Vec<ExtractedUrl>,
}
294
/// Strategy interface for entity/relationship extraction from a memory body.
pub trait Extractor: Send + Sync {
    /// Extracts entities, relationships and URLs from `body`.
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}
298
/// Subset of the model's config.json needed to build the classifier head.
#[derive(Deserialize)]
struct ModelConfig {
    /// Label-id (as JSON string key) → IOB label name; empty when absent.
    #[serde(default)]
    id2label: HashMap<String, String>,
    /// BERT hidden dimension; sizes the classifier weight matrix.
    hidden_size: usize,
}
305
/// CPU BERT token-classification model: backbone plus a linear head.
struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    /// Classifier output index → IOB label (e.g. "B-PER", "I-ORG", "O").
    id2label: HashMap<usize, String>,
}
312
impl BertNerModel {
    /// Loads config and safetensors weights from `model_dir` (layout produced
    /// by `ensure_model_files`: config.json + model.safetensors).
    fn load(model_dir: &Path) -> Result<Self> {
        let config_path = model_dir.join("config.json");
        let weights_path = model_dir.join("model.safetensors");

        let config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("lendo config.json em {config_path:?}"))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&config_str).context("parseando config.json do modelo NER")?;

        // id2label keys arrive as JSON strings ("0", "1", …); keep only the
        // ones that parse as usize.
        let id2label: HashMap<usize, String> = model_cfg
            .id2label
            .into_iter()
            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
            .collect();

        // Fall back to 9 labels when id2label is missing — presumably the
        // standard IOB set (O + B/I for four entity classes); confirm against
        // the model card if this ever changes.
        let num_labels = id2label.len().max(9);
        let hidden_size = model_cfg.hidden_size;

        // The same config.json is deserialized twice: once as our ModelConfig
        // (labels + hidden size) and once as candle's BertConfig.
        let bert_config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("relendo config.json para bert em {config_path:?}"))?;
        let bert_cfg: BertConfig =
            serde_json::from_str(&bert_config_str).context("parseando BertConfig")?;

        let device = Device::Cpu;

        // SAFETY: memory-maps the weights file. Sound as long as the file in
        // the local model cache is not truncated or rewritten while mapped.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
                .with_context(|| format!("mapping {weights_path:?}"))?
        };
        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("loading BertModel")?;

        // Token-classification head: a single Linear layer stored under the
        // "classifier" prefix of the safetensors file.
        let cls_vb = vb.pp("classifier");
        let weight = cls_vb
            .get((num_labels, hidden_size), "weight")
            .context("carregando classifier.weight do safetensors")?;
        let bias = cls_vb
            .get(num_labels, "bias")
            .context("carregando classifier.bias do safetensors")?;
        let classifier = Linear::new(weight, Some(bias));

        Ok(Self {
            bert,
            classifier,
            device,
            id2label,
        })
    }

    /// Runs NER over a single window, returning one label per input token
    /// (argmax over classifier logits; unknown label ids map to "O").
    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
        let len = token_ids.len();
        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        // Shape (1, len): batch of one window.
        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
            .context("creating tensor input_ids")?;
        // Single-segment input, so token_type_ids are all zeros.
        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
            .context("creating tensor token_type_ids")?;
        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
            .context("creating tensor attention_mask")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel forward pass")?;

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier forward pass")?;

        // (1, len, num_labels) → (len, num_labels).
        let logits_2d = logits.squeeze(0).context("removing batch dimension")?;

        let num_tokens = logits_2d.dim(0).context("dim(0)")?;

        let mut labels = Vec::with_capacity(num_tokens);
        for i in 0..num_tokens {
            let token_logits = logits_2d.get(i).context("get token logits")?;
            let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
            // Greedy decode: pick the label with the highest logit.
            let argmax = vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            let label = self
                .id2label
                .get(&argmax)
                .cloned()
                .unwrap_or_else(|| "O".to_string())
            labels.push(label);
        }

        Ok(labels)
    }

    /// Runs NER over several windows at once. Each window is a
    /// (token_ids, token_strings) pair; the strings are ignored here and only
    /// carried so callers can zip labels back onto tokens. Windows are padded
    /// to the longest one and padding is masked out of attention; logits are
    /// narrowed back to each window's real length before decoding.
    fn predict_batch(&self, windows: &[(Vec<u32>, Vec<String>)]) -> Result<Vec<Vec<String>>> {
        let batch_size = windows.len();
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        if max_len == 0 {
            return Ok(vec![vec![]; batch_size]);
        }

        let mut padded_ids: Vec<Tensor> = Vec::with_capacity(batch_size);
        let mut padded_masks: Vec<Tensor> = Vec::with_capacity(batch_size);

        for (ids, _) in windows {
            let len = ids.len();
            let pad_right = max_len - len;

            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            // Build at real length, then zero-pad up to max_len.
            let t = Tensor::from_vec(ids_i64, len, &self.device)
                .context("creating id tensor for batch")?;
            let t = t
                .pad_with_zeros(0, 0, pad_right)
                .context("padding id tensor")?;
            padded_ids.push(t);

            // Attention mask: 1 for real tokens, 0 for padding.
            let mut mask_i64 = vec![1i64; len];
            mask_i64.extend(vec![0i64; pad_right]);
            let m = Tensor::from_vec(mask_i64, max_len, &self.device)
                .context("creating mask tensor for batch")?;
            padded_masks.push(m);
        }

        let input_ids = Tensor::stack(&padded_ids, 0).context("stack input_ids")?;
        let attn_mask = Tensor::stack(&padded_masks, 0).context("stack attn_mask")?;
        // Single-segment inputs: token_type_ids all zero.
        let token_type_ids = Tensor::zeros((batch_size, max_len), DType::I64, &self.device)
            .context("creating token_type_ids tensor for batch")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel batch forward pass")?;
        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("forward pass batch classificador")?;
        let mut results = Vec::with_capacity(batch_size);
        for (i, (window_ids, _)) in windows.iter().enumerate() {
            let example_logits = logits.get(i).context("get logits exemplo")?;
            // Drop logits for the padded positions.
            let real_len = window_ids.len();
            let example_slice = example_logits
                .narrow(0, 0, real_len)
                .context("narrow para tokens reais")?;
            let logits_2d: Vec<Vec<f32>> = example_slice.to_vec2().context("to_vec2 logits")?;

            // Greedy per-token decode, same as `predict`.
            let labels: Vec<String> = logits_2d
                .iter()
                .map(|token_logits| {
                    let argmax = token_logits
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.id2label
                        .get(&argmax)
                        .cloned()
                        .unwrap_or_else(|| "O".to_string())
                })
                .collect();

            results.push(labels);
        }

        Ok(results)
    }
}
510
511static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();
512
513fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
514 NER_MODEL
515 .get_or_init(|| match load_model(paths) {
516 Ok(m) => Some(m),
517 Err(e) => {
518 tracing::warn!("NER model unavailable (graceful degradation): {e:#}");
519 None
520 }
521 })
522 .as_ref()
523}
524
525fn model_dir(paths: &AppPaths) -> PathBuf {
526 paths.models.join("bert-multilingual-ner")
527}
528
/// Ensures the model files exist locally, downloading any missing ones from
/// the Hugging Face Hub on first run. Returns the model directory.
fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
    let dir = model_dir(paths);
    std::fs::create_dir_all(&dir).with_context(|| format!("creating model directory: {dir:?}"))?;

    let weights = dir.join("model.safetensors");
    let config = dir.join("config.json");
    let tokenizer = dir.join("tokenizer.json");

    // Fast path: the three essential files are present.
    // NOTE(review): tokenizer_config.json is fetched below but not checked
    // here — presumably optional; confirm before relying on it.
    if weights.exists() && config.exists() && tokenizer.exists() {
        return Ok(dir);
    }

    tracing::info!("Downloading NER model (first run, ~676 MB)...");
    crate::output::emit_progress_i18n(
        "Downloading NER model (first run, ~676 MB)...",
        crate::i18n::validation::runtime_pt::downloading_ner_model(),
    );

    let api = huggingface_hub::api::sync::Api::new().context("creating HF Hub client")?;
    let repo = api.model(MODEL_ID.to_string());

    // (remote path in the repo, local filename). The tokenizer lives under
    // onnx/ in this repository and is flattened into the cache directory.
    for (remote, local) in &[
        ("model.safetensors", "model.safetensors"),
        ("config.json", "config.json"),
        ("onnx/tokenizer.json", "tokenizer.json"),
        ("tokenizer_config.json", "tokenizer_config.json"),
    ] {
        let dest = dir.join(local);
        // Skip files already present so an interrupted download can resume.
        if !dest.exists() {
            let src = repo
                .get(remote)
                .with_context(|| format!("baixando {remote} do HF Hub"))?;
            std::fs::copy(&src, &dest).with_context(|| format!("copiando {local} para cache"))?;
        }
    }

    Ok(dir)
}
570
571fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
572 let dir = ensure_model_files(paths)?;
573 BertNerModel::load(&dir)
574}
575
/// Cheap process-local fingerprint of a string, used for entity-name
/// deduplication. Not stable across Rust versions — never persist it.
#[inline]
fn hash_str(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    s.hash(&mut hasher);
    hasher.finish()
}
584
585fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
586 let mut entities = Vec::with_capacity(16);
587 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
588
589 let add = |entities: &mut Vec<ExtractedEntity>,
590 seen: &mut std::collections::HashSet<String>,
591 name: &str,
592 entity_type: EntityType| {
593 let name = name.trim().to_string();
594 if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
595 entities.push(ExtractedEntity { name, entity_type });
596 }
597 };
598
599 let cleaned = regex_section_marker().replace_all(body, " ");
602 let cleaned = cleaned.as_ref();
603
604 for m in regex_email().find_iter(cleaned) {
605 add(&mut entities, &mut seen, m.as_str(), EntityType::Concept);
607 }
608 for m in regex_uuid().find_iter(cleaned) {
609 add(&mut entities, &mut seen, m.as_str(), EntityType::Concept);
610 }
611 for m in regex_all_caps().find_iter(cleaned) {
612 let candidate = m.as_str();
613 if !is_filtered_all_caps(candidate) {
615 add(&mut entities, &mut seen, candidate, EntityType::Concept);
616 }
617 }
618 for m in regex_brand_camel().find_iter(cleaned) {
621 let name = m.as_str();
622 if !ALL_CAPS_STOPWORDS.contains(&name.to_uppercase().as_str()) {
624 add(&mut entities, &mut seen, name, EntityType::Organization);
625 }
626 }
627
628 entities
629}
630
631pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
635 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
636 let mut result = Vec::with_capacity(4);
637 for m in regex_url().find_iter(body) {
638 let raw = m.as_str();
639 let cleaned = raw
640 .trim_end_matches('`')
641 .trim_end_matches(',')
642 .trim_end_matches('.')
643 .trim_end_matches(';')
644 .trim_end_matches(')')
645 .trim_end_matches(']')
646 .trim_end_matches('}');
647 if cleaned.len() >= 10 && seen.insert(cleaned.to_string()) {
648 result.push(ExtractedUrl {
649 url: cleaned.to_string(),
650 offset: m.start(),
651 });
652 }
653 }
654 result
655}
656
/// Decodes IOB label sequences into entities: "B-X" starts a span, "I-X"
/// continues it, "O" (or any unknown label) ends it. WordPiece "##" pieces
/// are glued onto the previous part. Finished spans are filtered against
/// the all-caps stopwords and section markers.
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::with_capacity(tokens.len() / 4);
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<EntityType> = None;

    // Finalize the span in progress (if any): join parts, apply filters,
    // and reset the accumulator.
    let flush = |parts: &mut Vec<String>,
                 typ: &mut Option<EntityType>,
                 entities: &mut Vec<ExtractedEntity>| {
        if let Some(t) = typ.take() {
            let name = parts.join(" ").trim().to_string();
            // A single all-caps word goes through the same stopword filter
            // as the regex prefilter.
            let is_single_caps = !name.contains(' ')
                && name == name.to_uppercase()
                && name.len() >= MIN_ENTITY_CHARS;
            let should_skip = is_single_caps && is_filtered_all_caps(&name);
            let is_section_marker = regex_section_marker().is_match(&name);
            if name.len() >= MIN_ENTITY_CHARS && !should_skip && !is_section_marker {
                entities.push(ExtractedEntity {
                    name,
                    entity_type: t,
                });
            }
            parts.clear();
        }
    };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            // Malformed label: treat like "O".
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        // Capitalized Portuguese verbs the model commonly mislabels as PER
        // ("Lê", "Vê", "Cá", "Pôr", …).
        const PT_VERB_FALSE_POSITIVES: &[&str] = &[
            "L\u{00ea}",
            "V\u{00ea}",
            "C\u{00e1}",
            "P\u{00f4}r",
            "Ser",
            "Vir",
            "Ver",
            "Dar",
            "Ler",
            "Ter",
        ];

        let entity_type: EntityType = match bio_type {
            "DATE" => EntityType::Date,
            "PER" => {
                if PT_VERB_FALSE_POSITIVES.contains(&token.as_str()) {
                    flush(&mut current_parts, &mut current_type, &mut entities);
                    continue;
                }
                EntityType::Person
            }
            "ORG" => {
                // Heuristic: ORG tokens that look like software artifacts
                // are reclassified as tools.
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    EntityType::Tool
                } else {
                    EntityType::Organization
                }
            }
            "LOC" => EntityType::Location,
            _ => EntityType::Concept,
        };

        if prefix == "B" {
            if token.starts_with("##") {
                // A "B" on a subword piece: glue it to the previous part
                // instead of starting a new span.
                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
                continue;
            }
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type);
        } else if prefix == "I" && current_type.is_some() {
            // Continuation: subword pieces merge into the last part, whole
            // words become a new space-separated part.
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    // Flush any span still open at end of input.
    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}
782
/// Test-only pairing strategy: links each entity to up to TOP_K_RELATIONS
/// later entities (first MAX_ENTS only) with the generic "mentions"
/// relation, capped at the configured per-memory maximum.
/// Returns (relationships, hit_cap).
#[cfg(test)]
fn build_relationships(entities: &[NewEntity]) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let n = entities.len().min(MAX_ENTS);
    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();

    let mut hit_cap = false;
    'outer: for i in 0..n {
        if rels.len() >= max_rels {
            hit_cap = true;
            break;
        }

        // Per-entity fanout counter, bounded by TOP_K_RELATIONS.
        let mut for_entity = 0usize;
        for j in (i + 1)..n {
            if for_entity >= TOP_K_RELATIONS {
                break;
            }
            if rels.len() >= max_rels {
                hit_cap = true;
                break 'outer;
            }

            // NOTE(review): since j > i, each (i, j) pair is visited at most
            // once, so this insert can never fail — kept as a safety net.
            let key = (i.min(j), i.max(j));
            if !seen.insert(key) {
                continue;
            }

            rels.push(NewRelationship {
                source: entities[i].name.clone(),
                target: entities[j].name.clone(),
                relation: DEFAULT_RELATION.to_string(),
                strength: 0.5,
                description: None,
            });
            for_entity += 1;
        }
    }

    if hit_cap {
        tracing::warn!(
            "relationships truncated to {max_rels} (with {n} entities, theoretical max was ~{}x combinations)",
            n.saturating_sub(1)
        );
    }

    (rels, hit_cap)
}
849
/// Builds "mentions" relationships between entities that appear in the same
/// sentence (split on '.', '!', '?', newline). Each unordered pair is linked
/// at most once; output is capped at the configured per-memory maximum.
/// Returns (relationships, hit_cap).
fn build_relationships_by_sentence_cooccurrence(
    body: &str,
    entities: &[NewEntity],
) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    // Pre-lowercase the first MAX_ENTS entity names, keeping original indices.
    let lower_names: Vec<(usize, String)> = entities
        .iter()
        .take(MAX_ENTS)
        .enumerate()
        .map(|(i, e)| (i, e.name.to_lowercase()))
        .collect();

    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
    let mut hit_cap = false;

    for sentence in body.split(['.', '!', '?', '\n']) {
        if sentence.trim().is_empty() {
            continue;
        }
        let lower_sentence = sentence.to_lowercase();
        // NOTE(review): plain substring containment — short names can match
        // inside longer words; acceptable given MIN_ENTITY_CHARS, but worth
        // keeping in mind.
        let present: Vec<usize> = lower_names
            .iter()
            .filter(|(_, name)| !name.is_empty() && lower_sentence.contains(name.as_str()))
            .map(|(i, _)| *i)
            .collect();

        if present.len() < 2 {
            continue;
        }

        // All unordered pairs of entities present in this sentence.
        for i in 0..present.len() {
            for j in (i + 1)..present.len() {
                if rels.len() >= max_rels {
                    hit_cap = true;
                    tracing::warn!(
                        "relationships truncated to {max_rels} during sentence-level pairing"
                    );
                    return (rels, hit_cap);
                }
                let ei = present[i];
                let ej = present[j];
                // Normalized key so (a, b) and (b, a) dedup together.
                let key = (ei.min(ej), ei.max(ej));
                if seen.insert(key) {
                    rels.push(NewRelationship {
                        source: entities[ei].name.clone(),
                        target: entities[ej].name.clone(),
                        relation: DEFAULT_RELATION.to_string(),
                        strength: 0.5,
                        description: None,
                    });
                }
            }
        }
    }

    (rels, hit_cap)
}
922
/// Tokenizes `body`, slices it into overlapping windows (MAX_SEQ_LEN long,
/// STRIDE apart), runs the model in batches, and decodes IOB labels into
/// entities deduplicated by name hash. Input is truncated to the configured
/// token cap; the regex prefilter still covers the full body.
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("loading NER tokenizer: {e}"))?;

    // No special tokens ([CLS]/[SEP]) — windows are raw token slices.
    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("encoding NER input: {e}"))?;

    let mut all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let mut all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    // Cap total tokens to bound inference cost on very large bodies.
    let max_tokens = crate::constants::extraction_max_tokens();
    if all_ids.len() > max_tokens {
        tracing::warn!(
            target: "extraction",
            original_tokens = all_ids.len(),
            capped_tokens = max_tokens,
            "NER input truncated to cap; later body region will be skipped by NER (regex prefilter still covers full body)"
        );
        all_ids.truncate(max_tokens);
        all_tokens.truncate(max_tokens);
    }

    // Build overlapping (ids, tokens) windows; with STRIDE = MAX_SEQ_LEN/2
    // consecutive windows overlap by half.
    let mut windows: Vec<(Vec<u32>, Vec<String>)> = Vec::new();
    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        windows.push((
            all_ids[start..end].to_vec(),
            all_tokens[start..end].to_vec(),
        ));
        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    // Sort by length — presumably so same-size windows batch together and
    // padding waste per batch is minimized. Order doesn't matter downstream:
    // entities are deduplicated by name.
    windows.sort_by_key(|(ids, _)| ids.len());

    let batch_size = crate::constants::ner_batch_size();
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    // Dedup across windows by name hash (overlapping windows repeat spans).
    let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();

    for chunk in windows.chunks(batch_size) {
        match model.predict_batch(chunk) {
            Ok(batch_labels) => {
                for (labels, (_, tokens)) in batch_labels.iter().zip(chunk.iter()) {
                    for ent in iob_to_entities(tokens, labels) {
                        if seen.insert(hash_str(&ent.name)) {
                            entities.push(ent);
                        }
                    }
                }
            }
            Err(e) => {
                // Batch failed: retry each window individually so one bad
                // window can't sink the whole chunk.
                tracing::warn!(
                    "batch NER failed (chunk of {} windows): {e:#} — falling back to single-window",
                    chunk.len()
                );
                for (ids, tokens) in chunk {
                    let mask = vec![1u32; ids.len()];
                    match model.predict(ids, &mask) {
                        Ok(labels) => {
                            for ent in iob_to_entities(tokens, &labels) {
                                if seen.insert(hash_str(&ent.name)) {
                                    entities.push(ent);
                                }
                            }
                        }
                        Err(e2) => {
                            tracing::warn!("NER window fallback also failed: {e2:#}");
                        }
                    }
                }
            }
        }
    }

    Ok(entities)
}
1027
1028fn extend_with_numeric_suffix(entities: Vec<ExtractedEntity>, body: &str) -> Vec<ExtractedEntity> {
1035 static SUFFIX_RE: OnceLock<Regex> = OnceLock::new();
1036 let suffix_re = SUFFIX_RE.get_or_init(|| {
1039 Regex::new(r"^([\-\s]+\d+(?:\.\d+)?[a-z]?)")
1040 .expect("compile-time validated numeric suffix regex literal")
1041 });
1042
1043 entities
1044 .into_iter()
1045 .map(|ent| {
1046 if let Some(pos) = body.find(&ent.name) {
1048 let after_pos = pos + ent.name.len();
1049 if after_pos < body.len() {
1050 let after = &body[after_pos..];
1051 if let Some(m) = suffix_re.find(after) {
1052 let suffix = m.as_str();
1053 if suffix.len() <= 7 {
1056 let mut extended = String::with_capacity(ent.name.len() + suffix.len());
1057 extended.push_str(&ent.name);
1058 extended.push_str(suffix);
1059 return ExtractedEntity {
1060 name: extended,
1061 entity_type: ent.entity_type,
1062 };
1063 }
1064 }
1065 }
1066 }
1067 ent
1068 })
1069 .collect()
1070}
1071
/// Adds versioned product/model names found in `body` (e.g. "Claude 3.5",
/// "GPT 4o"-style "Word <number>" patterns, optionally followed by a known
/// tier word) as Concept entities, skipping names already present
/// (case-insensitive) and stopping at MAX_ENTS total.
fn augment_versioned_model_names(
    entities: Vec<ExtractedEntity>,
    body: &str,
) -> Vec<ExtractedEntity> {
    static VERSIONED_MODEL_RE: OnceLock<Regex> = OnceLock::new();
    let model_re = VERSIONED_MODEL_RE.get_or_init(|| {
        // Capitalized word (3-16 chars) + separator + version number, with an
        // optional tier suffix (Sonnet/Opus/…) included in the full match.
        Regex::new(
            r"\b([A-Z][A-Za-z]{2,15})[\s\-]+(\d+(?:\.\d+)?(?:[a-z]|x\d+[A-Za-z]?)?)(?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))?\b",
        )
        .expect("compile-time validated versioned model regex literal")
    });

    // Case-insensitive view of names already extracted.
    let mut existing_lc: std::collections::HashSet<String> =
        entities.iter().map(|ent| ent.name.to_lowercase()).collect();
    let mut result = entities;

    for caps in model_re.captures_iter(body) {
        let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
        // Length cap rejects degenerate long matches.
        // NOTE(review): this regex also matches section-marker-like phrases
        // ("Etapa 3"); the caller filters those out afterwards.
        if full_match.is_empty() || full_match.len() > 24 {
            continue;
        }
        let normalized_lc = full_match.to_lowercase();
        if existing_lc.contains(&normalized_lc) {
            continue;
        }
        if result.len() >= MAX_ENTS {
            break;
        }
        existing_lc.insert(normalized_lc);
        result.push(ExtractedEntity {
            name: full_match.to_string(),
            entity_type: EntityType::Concept,
        });
    }

    result
}
1138
/// Merges regex and NER candidates, deduplicating per entity type on
/// NFKC-lowercased names. Two same-type names collide when equal OR when one
/// contains the other (substring), in which case the longer surface form
/// wins. Output is capped at MAX_ENTS.
fn merge_and_deduplicate(
    regex_ents: Vec<ExtractedEntity>,
    ner_ents: Vec<ExtractedEntity>,
) -> Vec<ExtractedEntity> {
    // Map of "type\0name_lc" -> index into `result`.
    let mut by_lc: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    let mut result: Vec<ExtractedEntity> = Vec::new();
    let mut truncated = false;

    let total_input = regex_ents.len() + ner_ents.len();
    // Regex candidates first, so on ties they are the ones NER merges into.
    for ent in regex_ents.into_iter().chain(ner_ents) {
        // NFKC-normalize before lowercasing so visually-identical Unicode
        // variants collapse together.
        let name_lc = ent.name.nfkc().collect::<String>().to_lowercase();
        // Composite key: entity type + NUL + normalized name.
        let key = {
            let et = ent.entity_type.as_str();
            let mut k = String::with_capacity(et.len() + 1 + name_lc.len());
            k.push_str(et);
            k.push('\0');
            k.push_str(&name_lc);
            k
        };

        let type_prefix = {
            let et = ent.entity_type.as_str();
            let mut p = String::with_capacity(et.len() + 1);
            p.push_str(et);
            p.push('\0');
            p
        };
        // Linear scan over the map to find substring collisions.
        // NOTE(review): O(n) per candidate (O(n²) total), but bounded by the
        // MAX_ENTS break below, so n stays small.
        let mut collision_idx: Option<usize> = None;
        for (existing_key, idx) in &by_lc {
            // Only compare within the same entity type.
            if !existing_key.starts_with(&type_prefix) {
                continue;
            }
            let existing_name_lc = &existing_key[type_prefix.len()..];
            if existing_name_lc == name_lc
                || existing_name_lc.contains(name_lc.as_str())
                || name_lc.contains(existing_name_lc)
            {
                collision_idx = Some(*idx);
                break;
            }
        }
        match collision_idx {
            Some(idx) => {
                // Keep the longer surface form; re-key the map entry to the
                // new name so later candidates compare against it.
                if ent.name.len() > result[idx].name.len() {
                    let old_name_lc = result[idx].name.nfkc().collect::<String>().to_lowercase();
                    let old_key = {
                        let et = result[idx].entity_type.as_str();
                        let mut k = String::with_capacity(et.len() + 1 + old_name_lc.len());
                        k.push_str(et);
                        k.push('\0');
                        k.push_str(&old_name_lc);
                        k
                    };
                    by_lc.remove(&old_key);
                    result[idx] = ent;
                    by_lc.insert(key, idx);
                }
            }
            None => {
                by_lc.insert(key, result.len());
                result.push(ent);
            }
        }
        if result.len() >= MAX_ENTS {
            truncated = true;
            break;
        }
    }

    if truncated {
        tracing::warn!(
            "extraction truncated at {MAX_ENTS} entities (input had {total_input} candidates before deduplication)"
        );
    }

    result
}
1241
1242fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
1243 extracted
1244 .into_iter()
1245 .map(|e| NewEntity {
1246 name: e.name,
1247 entity_type: e.entity_type,
1248 description: None,
1249 })
1250 .collect()
1251}
1252
/// Full extraction pipeline: regex prefilter always runs; BERT NER runs when
/// the model is available (graceful degradation to regex-only otherwise).
/// Candidates are merged/deduplicated, extended with numeric suffixes and
/// versioned model names, scrubbed of section markers, then paired into
/// sentence-level co-occurrence relationships. URLs are extracted separately.
pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
    let regex_entities = apply_regex_prefilter(body);

    let mut bert_used = false;
    let ner_entities = match get_or_init_model(paths) {
        Some(model) => match run_ner_sliding_window(model, body, paths) {
            Ok(ents) => {
                bert_used = true;
                ents
            }
            Err(e) => {
                tracing::warn!("NER failed, falling back to regex-only extraction: {e:#}");
                Vec::new()
            }
        },
        None => Vec::new(),
    };

    let merged = merge_and_deduplicate(regex_entities, ner_entities);
    let extended = extend_with_numeric_suffix(merged, body);
    let with_models = augment_versioned_model_names(extended, body);
    // Final safety net: drop anything that still looks like a section marker
    // (the versioned-model regex can produce "Etapa 3"-style names).
    let with_models: Vec<ExtractedEntity> = with_models
        .into_iter()
        .filter(|e| !regex_section_marker().is_match(&e.name))
        .collect();
    let entities = to_new_entities(with_models);
    let (relationships, relationships_truncated) =
        build_relationships_by_sentence_cooccurrence(body, &entities);

    // Method tag records whether NER actually contributed.
    let extraction_method = if bert_used {
        "bert+regex-batch".to_string()
    } else {
        "regex-only".to_string()
    };

    let urls = extract_urls(body);

    Ok(ExtractionResult {
        entities,
        relationships,
        relationships_truncated,
        extraction_method,
        urls,
    })
}
1305
/// Extractor that never touches the NER model: regex prefilter only.
pub struct RegexExtractor;

impl Extractor for RegexExtractor {
    /// Runs the regex-only pipeline: prefilter entities, sentence-level
    /// co-occurrence relationships, and URL extraction.
    fn extract(&self, body: &str) -> Result<ExtractionResult> {
        let regex_entities = apply_regex_prefilter(body);
        let entities = to_new_entities(regex_entities);
        let (relationships, relationships_truncated) =
            build_relationships_by_sentence_cooccurrence(body, &entities);
        let urls = extract_urls(body);
        Ok(ExtractionResult {
            entities,
            relationships,
            relationships_truncated,
            extraction_method: "regex-only".to_string(),
            urls,
        })
    }
}
1324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entity_type::EntityType;

    /// Serialises the tests that mutate process-global environment variables
    /// (`GRAPHRAG_NER_BATCH_SIZE`, `SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS`).
    ///
    /// Cargo runs `#[test]` functions on multiple threads by default, so
    /// unguarded `set_var`/`remove_var` calls in one test race with the
    /// asserts of another and make these tests flaky.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires the env-var lock, recovering from poisoning so that one
    /// panicking env test does not cascade into unrelated failures.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Builds throwaway `AppPaths`; nothing is ever created at these paths.
    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        AppPaths {
            db: PathBuf::from("/tmp/test.sqlite"),
            models: PathBuf::from("/tmp/test_models"),
        }
    }

    #[test]
    fn regex_email_captures_address() {
        let ents = apply_regex_prefilter("contact: someone@company.com for more info");
        assert!(ents
            .iter()
            .any(|e| e.name == "someone@company.com" && e.entity_type == EntityType::Concept));
    }

    #[test]
    fn regex_all_caps_filters_pt_rule_word() {
        let ents = apply_regex_prefilter("NUNCA do this. PROIBIDO use X. DEVE follow Y.");
        assert!(
            !ents.iter().any(|e| e.name == "NUNCA"),
            "NUNCA must be filtered as a stopword"
        );
        assert!(
            !ents.iter().any(|e| e.name == "PROIBIDO"),
            "PROIBIDO must be filtered"
        );
        assert!(
            !ents.iter().any(|e| e.name == "DEVE"),
            "DEVE must be filtered"
        );
    }

    #[test]
    fn regex_all_caps_accepts_underscored_constant() {
        let ents = apply_regex_prefilter("configure MAX_RETRY=3 and API_TIMEOUT=30");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
    }

    #[test]
    fn regex_all_caps_accepts_domain_acronym() {
        let ents = apply_regex_prefilter("OPENAI launched GPT-5 with NVIDIA H100");
        assert!(ents.iter().any(|e| e.name == "OPENAI"));
        assert!(ents.iter().any(|e| e.name == "NVIDIA"));
    }

    #[test]
    fn regex_url_does_not_appear_in_apply_regex_prefilter() {
        let ents = apply_regex_prefilter("see https://docs.rs/crate for details");
        assert!(
            !ents.iter().any(|e| e.name.starts_with("https://")),
            "URLs must not appear as entities after the P0-2 split"
        );
    }

    #[test]
    fn extract_urls_captures_https() {
        let urls = extract_urls("see https://docs.rs/crate for details");
        assert_eq!(urls.len(), 1);
        assert_eq!(urls[0].url, "https://docs.rs/crate");
        assert!(urls[0].offset > 0);
    }

    #[test]
    fn extract_urls_trim_sufixo_pontuacao() {
        let urls = extract_urls("link: https://example.com/path. fim");
        assert!(!urls.is_empty());
        assert!(
            !urls[0].url.ends_with('.'),
            "sufixo ponto deve ser removido"
        );
    }

    #[test]
    fn extract_urls_dedupes_repeated() {
        let body = "https://example.com referenciado aqui e depois aqui https://example.com";
        let urls = extract_urls(body);
        assert_eq!(urls.len(), 1, "URLs repetidas devem ser deduplicadas");
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
        assert!(ents.iter().any(|e| e.entity_type == EntityType::Concept));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignores_short_words() {
        let ents = apply_regex_prefilter("use AI em seu projeto");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI tem apenas 2 chars, deve ser ignorado"
        );
    }

    #[test]
    fn iob_decodes_per_to_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, EntityType::Person);
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_strip_subword_b_prefix() {
        // "##AI" is a WordPiece continuation token of the previous word.
        let tokens = vec!["Open".to_string(), "##AI".to_string()];
        let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
            "should merge ##AI or discard"
        );
    }

    #[test]
    fn iob_subword_orphan_discards() {
        let tokens = vec!["##AI".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "orphan subword without an active entity must be discarded"
        );
    }

    #[test]
    fn iob_maps_date_to_date_v1025() {
        let tokens = vec!["January".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE must be emitted as an entity in v1.0.25"
        );
        assert_eq!(ents[0].entity_type, EntityType::Date);
    }

    #[test]
    fn iob_maps_org_to_organization_v1025() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, EntityType::Organization);
    }

    #[test]
    fn iob_maps_org_sdk_to_tool() {
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, EntityType::Tool);
    }

    #[test]
    fn iob_maps_loc_to_location_v1025() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, EntityType::Location);
    }

    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: EntityType::Concept,
                description: None,
            })
            .collect();
        let (rels, truncated) = build_relationships(&entities);
        let max_rels = crate::constants::max_relationships_per_memory();
        assert!(rels.len() <= max_rels, "deve respeitar max_rels={max_rels}");
        if rels.len() == max_rels {
            assert!(truncated, "truncated deve ser true quando atingiu o cap");
        }
    }

    #[test]
    fn build_relationships_without_duplicates() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: EntityType::Concept,
                description: None,
            })
            .collect();
        let (rels, _truncated) = build_relationships(&entities);
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "par duplicado encontrado");
        }
    }

    #[test]
    fn merge_dedupes_by_lowercase_name() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: EntityType::Concept,
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: EntityType::Concept,
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(
            merged.len(),
            1,
            "rust and Rust with the same type are the same entity"
        );
    }

    #[test]
    fn regex_extractor_implements_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_returns_ok_without_model() {
        let paths = make_paths();
        let body = "contato: teste@exemplo.com com MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
    }

    #[test]
    fn stopwords_filter_v1024_terms() {
        let body = "ACEITE ACK ACL BORDA CHECKLIST COMPLETED CONFIRME \
            DEVEMOS DONE FIXED NEGUE PENDING PLAN PODEMOS RECUSE TOKEN VAMOS";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for word in &[
            "ACEITE",
            "ACK",
            "ACL",
            "BORDA",
            "CHECKLIST",
            "COMPLETED",
            "CONFIRME",
            "DEVEMOS",
            "DONE",
            "FIXED",
            "NEGUE",
            "PENDING",
            "PLAN",
            "PODEMOS",
            "RECUSE",
            "TOKEN",
            "VAMOS",
        ] {
            assert!(
                !names.contains(word),
                "v1.0.24 stopword {word} should be filtered but was found in entities"
            );
        }
    }

    #[test]
    fn dedup_normalizes_unicode_combining_marks() {
        // "é" precomposed (NFC) vs "e" + combining acute accent (NFD).
        let nfc = vec![ExtractedEntity {
            name: "Caf\u{e9}".to_string(),
            entity_type: EntityType::Concept,
        }];
        let nfd_name = "Cafe\u{301}".to_string();
        let nfd = vec![ExtractedEntity {
            name: nfd_name,
            entity_type: EntityType::Concept,
        }];
        let merged = merge_and_deduplicate(nfc, nfd);
        assert_eq!(
            merged.len(),
            1,
            "NFC 'Caf\\u{{e9}}' and NFD 'Cafe\\u{{301}}' must deduplicate to 1 entity after NFKC normalization"
        );
    }

    #[test]
    fn predict_batch_output_count_matches_input() {
        let w1_ids: Vec<u32> = vec![101, 100, 102];
        let w1_tok: Vec<String> = vec!["[CLS]".into(), "hello".into(), "[SEP]".into()];
        let w2_ids: Vec<u32> = vec![101, 100, 200, 300, 102];
        let w2_tok: Vec<String> = vec![
            "[CLS]".into(),
            "world".into(),
            "foo".into(),
            "bar".into(),
            "[SEP]".into(),
        ];
        let windows: Vec<(Vec<u32>, Vec<String>)> =
            vec![(w1_ids.clone(), w1_tok), (w2_ids.clone(), w2_tok)];

        let device = Device::Cpu;
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap();
        assert_eq!(max_len, 5, "max_len deve ser 5");

        let mut padded_ids: Vec<Tensor> = Vec::new();
        for (ids, _) in &windows {
            let len = ids.len();
            let pad_right = max_len - len;
            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &device).unwrap();
            let t = t.pad_with_zeros(0, 0, pad_right).unwrap();
            assert_eq!(
                t.dims(),
                &[max_len],
                "each window must have shape (max_len,) after padding"
            );
            padded_ids.push(t);
        }

        let stacked = Tensor::stack(&padded_ids, 0).unwrap();
        assert_eq!(
            stacked.dims(),
            &[2, max_len],
            "stack deve produzir (batch_size=2, max_len=5)"
        );

        // Zero-filled stand-in logits shaped (batch=2, max_len, 9 labels).
        let fake_logits_data: Vec<f32> = vec![0.0f32; 2 * max_len * 9];
        let fake_logits =
            Tensor::from_vec(fake_logits_data, (2usize, max_len, 9usize), &device).unwrap();
        for (i, (ids, _)) in windows.iter().enumerate() {
            let real_len = ids.len();
            let example = fake_logits.get(i).unwrap();
            let sliced = example.narrow(0, 0, real_len).unwrap();
            assert_eq!(
                sliced.dims(),
                &[real_len, 9],
                "narrow deve preservar apenas {real_len} tokens reais"
            );
        }
    }

    #[test]
    fn predict_batch_empty_windows_returns_empty() {
        let windows: Vec<(Vec<u32>, Vec<String>)> = vec![];
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        assert_eq!(max_len, 0, "zero windows → max_len 0");
        let result: Vec<Vec<String>> = if max_len == 0 {
            Vec::new()
        } else {
            unreachable!()
        };
        assert!(result.is_empty());
    }

    #[test]
    fn ner_batch_size_default_is_8() {
        // Env vars are process-global: serialize with the other env tests.
        let _guard = env_lock();
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
        assert_eq!(crate::constants::ner_batch_size(), 8);
    }

    #[test]
    fn ner_batch_size_env_override_clamped() {
        // Env vars are process-global: serialize with the other env tests.
        let _guard = env_lock();
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "64");
        assert_eq!(crate::constants::ner_batch_size(), 32, "deve clampar em 32");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "0");
        assert_eq!(crate::constants::ner_batch_size(), 1, "deve clampar em 1");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "4");
        assert_eq!(
            crate::constants::ner_batch_size(),
            4,
            "valid value preserved"
        );

        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
    }

    #[test]
    fn extraction_method_regex_only_unchanged() {
        let result = RegexExtractor.extract("contact: dev@acme.io").unwrap();
        assert_eq!(
            result.extraction_method, "regex-only",
            "RegexExtractor must return regex-only"
        );
    }

    #[test]
    fn extend_suffix_pure_numeric_unchanged() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-5 in the project");
        assert_eq!(
            result[0].name, "GPT-5",
            "purely numeric suffix must be extended"
        );
    }

    #[test]
    fn extend_suffix_alphanumeric_letter_after_digit() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-4o for advanced tasks");
        assert_eq!(result[0].name, "GPT-4o", "suffix '4o' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_b_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Llama".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "Llama-5b open-weight model");
        assert_eq!(result[0].name, "Llama-5b", "suffix '5b' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_x_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Mistral".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "testing Mistral-8x in production");
        assert_eq!(result[0].name, "Mistral-8x", "suffix '8x' must be accepted");
    }

    #[test]
    fn augment_versioned_gpt4o() {
        let result = augment_versioned_model_names(vec![], "using GPT-4o for analysis");
        assert!(
            result.iter().any(|e| e.name == "GPT-4o"),
            "GPT-4o must be captured by augment, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_claude_4_sonnet() {
        let result =
            augment_versioned_model_names(vec![], "best model: Claude 4 Sonnet released today");
        assert!(
            result.iter().any(|e| e.name == "Claude 4 Sonnet"),
            "Claude 4 Sonnet must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_llama_3_pro() {
        let result =
            augment_versioned_model_names(vec![], "fine-tuning com Llama 3 Pro localmente");
        assert!(
            result.iter().any(|e| e.name == "Llama 3 Pro"),
            "Llama 3 Pro deve ser capturado, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_mixtral_8x7b() {
        let result =
            augment_versioned_model_names(vec![], "executando Mixtral 8x7B no servidor local");
        assert!(
            result.iter().any(|e| e.name == "Mixtral 8x7B"),
            "Mixtral 8x7B deve ser capturado, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_does_not_duplicate_existing() {
        let existing = vec![ExtractedEntity {
            name: "Claude 4".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = augment_versioned_model_names(existing, "using Claude 4 in the project");
        let count = result.iter().filter(|e| e.name == "Claude 4").count();
        assert_eq!(count, 1, "Claude 4 must not be duplicated");
    }

    #[test]
    fn stopwords_filter_url_jwt_api_v1025() {
        let body = "We use URL, JWT, and API REST in our LLM-powered CLI via HTTP/HTTPS and UI.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for blocked in &[
            "URL", "JWT", "API", "REST", "LLM", "CLI", "HTTP", "HTTPS", "UI",
        ] {
            assert!(
                !names.contains(blocked),
                "v1.0.25 stopword {blocked} leaked as entity; found names: {names:?}"
            );
        }
    }

    #[test]
    fn section_markers_etapa_fase_filtered_v1025() {
        let body = "Etapa 3 do plano: implementar Fase 1 da Migra\u{e7}\u{e3}o.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Etapa") || e.name.contains("Fase")),
            "section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn section_markers_passo_secao_filtered_v1025() {
        let body = "Siga Passo 2 conforme Se\u{e7}\u{e3}o 3 do manual.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Passo") || e.name.contains("Se\u{e7}\u{e3}o")),
            "Passo/Se\\u{{e7}}\\u{{e3}}o section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn brand_camelcase_extracted_as_organization_v1025() {
        let body = "OpenAI launched GPT-4 and PostgreSQL added pgvector.";
        let ents = apply_regex_prefilter(body);
        let openai = ents.iter().find(|e| e.name == "OpenAI");
        assert!(
            openai.is_some(),
            "OpenAI must be extracted by CamelCase brand regex; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
        assert_eq!(
            openai.unwrap().entity_type,
            EntityType::Organization,
            "brand CamelCase must map to organization (V008)"
        );
    }

    #[test]
    fn brand_postgresql_extracted_as_organization_v1025() {
        let body = "migrating from MySQL to PostgreSQL for better performance.";
        let ents = apply_regex_prefilter(body);
        assert!(
            ents.iter()
                .any(|e| e.name == "PostgreSQL" && e.entity_type == EntityType::Organization),
            "PostgreSQL must be extracted as organization; entities: {:?}",
            ents.iter()
                .map(|e| (&e.name, &e.entity_type))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn iob_org_maps_to_organization_not_project_v1025() {
        let tokens = vec!["Microsoft".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type,
            EntityType::Organization,
            "B-ORG must map to organization (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_loc_maps_to_location_not_concept_v1025() {
        let tokens = vec!["S\u{e3}o".to_string(), "Paulo".to_string()];
        let labels = vec!["B-LOC".to_string(), "I-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type,
            EntityType::Location,
            "B-LOC must map to location (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_date_maps_to_date_not_discarded_v1025() {
        let tokens = vec!["2025".to_string(), "-".to_string(), "12".to_string()];
        let labels = vec![
            "B-DATE".to_string(),
            "I-DATE".to_string(),
            "I-DATE".to_string(),
        ];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE entity must be emitted (V008); entities: {ents:?}"
        );
        assert_eq!(ents[0].entity_type, EntityType::Date);
    }

    #[test]
    fn pt_verb_le_filtered_as_per_v1025() {
        let tokens = vec!["L\u{ea}".to_string(), "o".to_string(), "livro".to_string()];
        let labels = vec!["B-PER".to_string(), "O".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents
                .iter()
                .any(|e| e.name == "L\u{ea}" && e.entity_type == EntityType::Person),
            "PT verb 'L\\u{{ea}}' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn pt_verb_ver_filtered_as_per_v1025() {
        let tokens = vec!["Ver".to_string()];
        let labels = vec!["B-PER".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "PT verb 'Ver' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    /// Shorthand constructor for the merge/dedup tests below.
    fn entity(name: &str, entity_type: EntityType) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_string(),
            entity_type,
        }
    }

    #[test]
    fn merge_resolves_sonne_vs_sonnet_keeps_longest_v1025() {
        let regex = vec![entity("Sonne", EntityType::Concept)];
        let ner = vec![entity("Sonnet", EntityType::Concept)];
        let result = merge_and_deduplicate(regex, ner);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "Sonnet");
    }

    #[test]
    fn merge_resolves_open_vs_openai_keeps_longest_v1025() {
        let regex = vec![
            entity("Open", EntityType::Organization),
            entity("OpenAI", EntityType::Organization),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "OpenAI");
    }

    #[test]
    fn merge_keeps_both_when_no_containment_v1025() {
        let regex = vec![
            entity("Alice", EntityType::Person),
            entity("Bob", EntityType::Person),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 2, "expected 2 entities, got: {result:?}");
    }

    #[test]
    fn merge_respects_entity_type_boundary_v1025() {
        let regex = vec![
            entity("Apple", EntityType::Organization),
            entity("Apple", EntityType::Concept),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            2,
            "expected 2 entities (different types), got: {result:?}"
        );
    }

    #[test]
    fn merge_case_insensitive_dedup_v1025() {
        let regex = vec![
            entity("OpenAI", EntityType::Organization),
            entity("openai", EntityType::Organization),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            1,
            "expected 1 entity after case-insensitive dedup, got: {result:?}"
        );
    }

    #[test]
    fn iob_section_marker_etapa_filtered_v1025() {
        let tokens = vec!["Etapa".to_string(), "3".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Etapa")),
            "section marker 'Etapa 3' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn iob_section_marker_fase_filtered_v1025() {
        let tokens = vec!["Fase".to_string(), "1".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Fase")),
            "section marker 'Fase 1' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn extract_graph_auto_handles_large_body_under_30s() {
        // 40_000 repetitions of "x " → an 80 KB body.
        let body = "x ".repeat(40_000);
        let paths = make_paths();
        let start = std::time::Instant::now();
        let result = extract_graph_auto(&body, &paths).expect("extraction must not error");
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 30,
            "extract_graph_auto took {}s for 80 KB body (cap should keep it well under 30s)",
            elapsed.as_secs()
        );
        let _ = result.entities;
    }

    #[test]
    fn pt_uppercase_stopwords_filtered_v1031() {
        let body = "Para o ADAPTER funcionar com PROJETO em modo PASSIVA, devemos usar \
            SOMENTE LEITURA conforme a REGRA OBRIGATORIA do EXEMPLO DEFAULT.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<String> = ents.iter().map(|e| e.name.to_uppercase()).collect();
        for stop in &[
            "ADAPTER",
            "PROJETO",
            "PASSIVA",
            "SOMENTE",
            "LEITURA",
            "REGRA",
            "OBRIGATORIA",
            "EXEMPLO",
            "DEFAULT",
        ] {
            assert!(
                !names.contains(&stop.to_string()),
                "v1.0.31 A11 stoplist failed: {stop} leaked as entity; got names: {names:?}"
            );
        }
    }

    #[test]
    fn pt_underscored_identifier_preserved_v1031() {
        let ents = apply_regex_prefilter("configure FLOWAIPER_API_KEY=foo and MAX_TIMEOUT=30");
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"FLOWAIPER_API_KEY"));
        assert!(names.contains(&"MAX_TIMEOUT"));
    }

    #[test]
    fn build_relationships_by_sentence_only_links_co_occurring_entities() {
        let body = "Alice met Bob at the conference. Carol works alone in another room.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
            NewEntity {
                name: "Carol".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
        ];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(!truncated);
        assert_eq!(
            rels.len(),
            1,
            "only Alice/Bob should pair (same sentence); Carol is isolated"
        );
        let pair = (rels[0].source.as_str(), rels[0].target.as_str());
        assert!(
            matches!(pair, ("Alice", "Bob") | ("Bob", "Alice")),
            "unexpected pair {pair:?}"
        );
    }

    #[test]
    fn build_relationships_by_sentence_returns_empty_for_single_entity() {
        let body = "Alice is here.";
        let entities = vec![NewEntity {
            name: "Alice".to_string(),
            entity_type: EntityType::Person,
            description: None,
        }];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(rels.is_empty());
        assert!(!truncated);
    }

    #[test]
    fn build_relationships_by_sentence_dedupes_pairs_across_sentences() {
        let body = "Alice met Bob. Bob saw Alice again.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
        ];
        let (rels, _) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert_eq!(
            rels.len(),
            1,
            "Alice/Bob pair must be emitted only once even when co-occurring in multiple sentences"
        );
    }

    #[test]
    fn extraction_max_tokens_default_is_5000() {
        // Env vars are process-global: serialize with the other env tests.
        let _guard = env_lock();
        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
        assert_eq!(crate::constants::extraction_max_tokens(), 5_000);
    }

    #[test]
    fn extraction_max_tokens_env_override_clamped() {
        // Env vars are process-global: serialize with the other env tests.
        let _guard = env_lock();
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value below 512 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value above 100_000 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "8000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            8_000,
            "valid value must be honoured"
        );

        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
    }
}