use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;

use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use regex::Regex;
use serde::Deserialize;
use unicode_normalization::UnicodeNormalization;

use crate::paths::AppPaths;
use crate::storage::entities::{NewEntity, NewRelationship};

const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
const MAX_SEQ_LEN: usize = 512;
const STRIDE: usize = 256;
const MAX_ENTS: usize = 30;
#[cfg(test)]
const TOP_K_RELATIONS: usize = 5;
const DEFAULT_RELATION: &str = "mentions";
const MIN_ENTITY_CHARS: usize = 2;

static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
static REGEX_SECTION_MARKER: OnceLock<Regex> = OnceLock::new();
static REGEX_BRAND_CAMEL: OnceLock<Regex> = OnceLock::new();

const ALL_CAPS_STOPWORDS: &[&str] = &[
    "ACEITE",
    "ACID",
    "ACK",
    "ACL",
    "ACRESCENTADO",
    "ADAPTER",
    "ADICIONADA",
    "ADICIONADAS",
    "ADICIONADO",
    "ADICIONADOS",
    "ADICIONAR",
    "AGENTS",
    "AINDA",
    "ALL",
    "ALTA",
    "ALWAYS",
    "APENAS",
    "API",
    "ARTEFATOS",
    "ATIVA",
    "ATIVO",
    "BAIXA",
    "BANCO",
    "BLOQUEAR",
    "BORDA",
    "BUG",
    "CAPÍTULO",
    "CASO",
    "CEO",
    "CHECKLIST",
    "CLARO",
    "CLI",
    "COMPLETED",
    "CONFIRMADO",
    "CONFIRMARAM",
    "CONFIRME",
    "CONFIRMEI",
    "CONFIRMOU",
    "CONTRATO",
    "CRIE",
    "CRÍTICO",
    "CRITICAL",
    "CSV",
    "DDL",
    "DEFAULT",
    "DEFINIR",
    "DEPARTMENT",
    "DESC",
    "DEVE",
    "DEVEMOS",
    "DISCO",
    "DONE",
    "DSL",
    "DTO",
    "EFEITO",
    "ENTRADA",
    "EPERM",
    "ERROR",
    "ESCREVA",
    "ESCRITA",
    "ESRCH",
    "ESSA",
    "ESSE",
    "ESSENCIAL",
    "ESTA",
    "ESTADO",
    "ESTE",
    "ETAPA",
    "EVITAR",
    "EXEMPLO",
    "EXPANDIR",
    "EXPOR",
    "FALHA",
    "FASE",
    "FATO",
    "FIFO",
    "FIXED",
    "FIXME",
    "FLUXO",
    "FONTES",
    "FORBIDDEN",
    "FUNCIONA",
    "HACK",
    "HEARTBEAT",
    "HTTP",
    "HTTPS",
    "INATIVO",
    "JAMAIS",
    "JSON",
    "JWT",
    "LEITURA",
    "LLM",
    "MESMO",
    "METADADOS",
    "MUST",
    "NEGUE",
    "NEVER",
    "NOTE",
    "NUNCA",
    "OBRIGATORIA",
    "OBRIGATÓRIO",
    "PADRÃO",
    "PASSIVA",
    "PASSO",
    "PENDING",
    "PLAN",
    "PODEMOS",
    "PONTEIROS",
    "PROIBIDO",
    "PROJETO",
    "RECUSE",
    "REGRA",
    "REGRAS",
    "REQUIRED",
    "REQUISITO",
    "REST",
    "SEÇÃO",
    "SEMPRE",
    "SHALL",
    "SHOULD",
    "SOMENTE",
    "SOUL",
    "TODAS",
    "TODO",
    "TODOS",
    "TOKEN",
    "TOOLS",
    "TSV",
    "UI",
    "URL",
    "USAR",
    "VALIDAR",
    "VAMOS",
    "VOCÊ",
    "WARNING",
    "XML",
    "YAML",
];

const HTTP_METHODS: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE",
];
203
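/// Decides whether an ALL-CAPS token from the regex prefilter should be
/// dropped. Underscored identifiers such as `MAX_RETRY` are always kept;
/// plain words are rejected when they appear in the PT/EN stopword list or
/// are HTTP method names.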
fn is_filtered_all_caps(token: &str) -> bool {
    let is_identifier = token.contains('_');
    if is_identifier {
        return false;
    }
    ALL_CAPS_STOPWORDS.contains(&token) || HTTP_METHODS.contains(&token)
}

fn regex_email() -> &'static Regex {
    REGEX_EMAIL.get_or_init(|| {
        Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
            .expect("compile-time validated email regex literal")
    })
}

fn regex_url() -> &'static Regex {
    REGEX_URL.get_or_init(|| {
        Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#)
            .expect("compile-time validated URL regex literal")
    })
}

fn regex_uuid() -> &'static Regex {
    REGEX_UUID.get_or_init(|| {
        Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
            .expect("compile-time validated UUID regex literal")
    })
}

fn regex_all_caps() -> &'static Regex {
    REGEX_ALL_CAPS.get_or_init(|| {
        Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b")
            .expect("compile-time validated all-caps regex literal")
    })
}

fn regex_section_marker() -> &'static Regex {
    REGEX_SECTION_MARKER.get_or_init(|| {
        Regex::new("\\b(?:Etapa|Fase|Passo|Camada|Se\u{00e7}\u{00e3}o|Cap\u{00ed}tulo)\\s+\\d+\\b")
            .expect("compile-time validated section marker regex literal")
    })
}

fn regex_brand_camel() -> &'static Regex {
    REGEX_BRAND_CAMEL.get_or_init(|| {
        Regex::new(r"\b[A-Z][a-z]+[A-Z][A-Za-z]+\b")
            .expect("compile-time validated CamelCase brand regex literal")
    })
}

#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    pub name: String,
    pub entity_type: String,
}

#[derive(Debug, Clone)]
pub struct ExtractedUrl {
    pub url: String,
    pub offset: usize,
}

#[derive(Debug, Clone)]
pub struct ExtractionResult {
    pub entities: Vec<NewEntity>,
    pub relationships: Vec<NewRelationship>,
    pub relationships_truncated: bool,
    pub extraction_method: String,
    pub urls: Vec<ExtractedUrl>,
}

pub trait Extractor: Send + Sync {
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}

#[derive(Deserialize)]
struct ModelConfig {
    #[serde(default)]
    id2label: HashMap<String, String>,
    hidden_size: usize,
}

struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    id2label: HashMap<usize, String>,
}

impl BertNerModel {
    fn load(model_dir: &Path) -> Result<Self> {
        let config_path = model_dir.join("config.json");
        let weights_path = model_dir.join("model.safetensors");

        let config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("reading config.json at {config_path:?}"))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&config_str).context("parsing the NER model config.json")?;

        let id2label: HashMap<usize, String> = model_cfg
            .id2label
            .into_iter()
            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
            .collect();

        let num_labels = id2label.len().max(9);
        let hidden_size = model_cfg.hidden_size;

        let bert_config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("re-reading config.json for bert at {config_path:?}"))?;
        let bert_cfg: BertConfig =
            serde_json::from_str(&bert_config_str).context("parsing BertConfig")?;

        let device = Device::Cpu;

        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
                .with_context(|| format!("mapping {weights_path:?}"))?
        };
        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("loading BertModel")?;

        let cls_vb = vb.pp("classifier");
        let weight = cls_vb
            .get((num_labels, hidden_size), "weight")
            .context("loading classifier.weight from safetensors")?;
        let bias = cls_vb
            .get(num_labels, "bias")
            .context("loading classifier.bias from safetensors")?;
        let classifier = Linear::new(weight, Some(bias));

        Ok(Self {
            bert,
            classifier,
            device,
            id2label,
        })
    }

    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
        let len = token_ids.len();
        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
            .context("creating tensor input_ids")?;
        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
            .context("creating tensor token_type_ids")?;
        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
            .context("creating tensor attention_mask")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel forward pass")?;

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier forward pass")?;

        let logits_2d = logits.squeeze(0).context("removing batch dimension")?;

        let num_tokens = logits_2d.dim(0).context("dim(0)")?;

        let mut labels = Vec::with_capacity(num_tokens);
        for i in 0..num_tokens {
            let token_logits = logits_2d.get(i).context("get token logits")?;
            let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
            let argmax = vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| {
                    a.partial_cmp(b)
                        .expect("BERT NER logits invariant: no NaN in classifier output")
                })
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            let label = self
                .id2label
                .get(&argmax)
                .cloned()
                .unwrap_or_else(|| "O".to_string());
            labels.push(label);
        }

        Ok(labels)
    }
416
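    /// Batched variant of `predict`: pads every window on the right to the
    /// longest window in the batch, masks the padding out via the attention
    /// mask, and narrows each example's logits back to its real length before
    /// argmax decoding, so padding positions never produce labels.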
    fn predict_batch(&self, windows: &[(Vec<u32>, Vec<String>)]) -> Result<Vec<Vec<String>>> {
        let batch_size = windows.len();
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        if max_len == 0 {
            return Ok(vec![vec![]; batch_size]);
        }

        let mut padded_ids: Vec<Tensor> = Vec::with_capacity(batch_size);
        let mut padded_masks: Vec<Tensor> = Vec::with_capacity(batch_size);

        for (ids, _) in windows {
            let len = ids.len();
            let pad_right = max_len - len;

            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &self.device)
                .context("creating id tensor for batch")?;
            let t = t
                .pad_with_zeros(0, 0, pad_right)
                .context("padding id tensor")?;
            padded_ids.push(t);

            let mut mask_i64 = vec![1i64; len];
            mask_i64.extend(vec![0i64; pad_right]);
            let m = Tensor::from_vec(mask_i64, max_len, &self.device)
                .context("creating mask tensor for batch")?;
            padded_masks.push(m);
        }

        let input_ids = Tensor::stack(&padded_ids, 0).context("stack input_ids")?;
        let attn_mask = Tensor::stack(&padded_masks, 0).context("stack attn_mask")?;
        let token_type_ids = Tensor::zeros((batch_size, max_len), DType::I64, &self.device)
            .context("creating token_type_ids tensor for batch")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel batch forward pass")?;
        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier batch forward pass")?;
        let mut results = Vec::with_capacity(batch_size);
        for (i, (window_ids, _)) in windows.iter().enumerate() {
            let example_logits = logits.get(i).context("get example logits")?;
            let real_len = window_ids.len();
            let example_slice = example_logits
                .narrow(0, 0, real_len)
                .context("narrowing to real tokens")?;
            let logits_2d: Vec<Vec<f32>> = example_slice.to_vec2().context("to_vec2 logits")?;

            let labels: Vec<String> = logits_2d
                .iter()
                .map(|token_logits| {
                    let argmax = token_logits
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b)
                                .expect("BERT NER logits invariant: no NaN in classifier output")
                        })
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.id2label
                        .get(&argmax)
                        .cloned()
                        .unwrap_or_else(|| "O".to_string())
                })
                .collect();

            results.push(labels);
        }

        Ok(results)
    }
}
510
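/// Lazily loaded NER model, shared process-wide. A failed load is cached as
/// `None`, so the download/load is attempted at most once and extraction
/// degrades to regex-only for the rest of the process.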
static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();

fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
    NER_MODEL
        .get_or_init(|| match load_model(paths) {
            Ok(m) => Some(m),
            Err(e) => {
                tracing::warn!("NER model unavailable (graceful degradation): {e:#}");
                None
            }
        })
        .as_ref()
}

fn model_dir(paths: &AppPaths) -> PathBuf {
    paths.models.join("bert-multilingual-ner")
}
528
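/// Ensures all model files exist under `model_dir`, fetching any missing one
/// from the Hugging Face Hub on first run. Files already on disk are never
/// re-downloaded, so an interrupted first run resumes file by file.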
fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
    let dir = model_dir(paths);
    std::fs::create_dir_all(&dir).with_context(|| format!("creating model directory: {dir:?}"))?;

    let weights = dir.join("model.safetensors");
    let config = dir.join("config.json");
    let tokenizer = dir.join("tokenizer.json");

    if weights.exists() && config.exists() && tokenizer.exists() {
        return Ok(dir);
    }

    tracing::info!("Downloading NER model (first run, ~676 MB)...");
    crate::output::emit_progress_i18n(
        "Downloading NER model (first run, ~676 MB)...",
        crate::i18n::validation::runtime_pt::downloading_ner_model(),
    );

    let api = huggingface_hub::api::sync::Api::new().context("creating HF Hub client")?;
    let repo = api.model(MODEL_ID.to_string());

    for (remote, local) in &[
        ("model.safetensors", "model.safetensors"),
        ("config.json", "config.json"),
        ("onnx/tokenizer.json", "tokenizer.json"),
        ("tokenizer_config.json", "tokenizer_config.json"),
    ] {
        let dest = dir.join(local);
        if !dest.exists() {
            let src = repo
                .get(remote)
                .with_context(|| format!("downloading {remote} from the HF Hub"))?;
            std::fs::copy(&src, &dest)
                .with_context(|| format!("copying {local} into the cache"))?;
        }
    }

    Ok(dir)
}

fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
    let dir = ensure_model_files(paths)?;
    BertNerModel::load(&dir)
}
575
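/// Regex-only entity pass, always available even without the BERT model:
/// emails, UUIDs, and unfiltered ALL-CAPS tokens become `concept` entities,
/// while CamelCase brand names become `organization` entities. Section
/// markers ("Etapa 3", "Fase 1", ...) are blanked out of the body first so
/// they cannot leak into any match.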
fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
    let mut entities = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let add = |entities: &mut Vec<ExtractedEntity>,
               seen: &mut std::collections::HashSet<String>,
               name: &str,
               entity_type: &str| {
        let name = name.trim().to_string();
        if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
            entities.push(ExtractedEntity {
                name,
                entity_type: entity_type.to_string(),
            });
        }
    };

    let cleaned = regex_section_marker().replace_all(body, " ");
    let cleaned = cleaned.as_ref();

    for m in regex_email().find_iter(cleaned) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_uuid().find_iter(cleaned) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_all_caps().find_iter(cleaned) {
        let candidate = m.as_str();
        if !is_filtered_all_caps(candidate) {
            add(&mut entities, &mut seen, candidate, "concept");
        }
    }
    for m in regex_brand_camel().find_iter(cleaned) {
        let name = m.as_str();
        if !ALL_CAPS_STOPWORDS.contains(&name.to_uppercase().as_str()) {
            add(&mut entities, &mut seen, name, "organization");
        }
    }

    entities
}
624
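/// Extracts unique URLs from `body`, trimming runs of trailing punctuation
/// (backticks, commas, periods, semicolons, closing brackets) that belong to
/// the surrounding prose, and discarding anything shorter than 10 characters.
///
/// ```ignore
/// // Sketch mirroring the unit tests below:
/// let urls = extract_urls("see https://docs.rs/crate for details");
/// assert_eq!(urls[0].url, "https://docs.rs/crate");
/// ```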
pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut result = Vec::new();
    for m in regex_url().find_iter(body) {
        let raw = m.as_str();
        let cleaned = raw
            .trim_end_matches('`')
            .trim_end_matches(',')
            .trim_end_matches('.')
            .trim_end_matches(';')
            .trim_end_matches(')')
            .trim_end_matches(']')
            .trim_end_matches('}');
        if cleaned.len() >= 10 && seen.insert(cleaned.to_string()) {
            result.push(ExtractedUrl {
                url: cleaned.to_string(),
                offset: m.start(),
            });
        }
    }
    result
}
650
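/// Decodes IOB/BIO label sequences into entities, gluing `##` subword pieces
/// back onto their parent token: `["John", "Doe"]` with `["B-PER", "I-PER"]`
/// yields one `person` entity named "John Doe". Portuguese verbs the model
/// commonly mislabels as PER are skipped, and ALL-CAPS stopwords and section
/// markers are dropped when an entity is flushed.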
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<String> = None;

    let flush =
        |parts: &mut Vec<String>, typ: &mut Option<String>, entities: &mut Vec<ExtractedEntity>| {
            if let Some(t) = typ.take() {
                let name = parts.join(" ").trim().to_string();
                let is_single_caps = !name.contains(' ')
                    && name == name.to_uppercase()
                    && name.len() >= MIN_ENTITY_CHARS;
                let should_skip = is_single_caps && is_filtered_all_caps(&name);
                let is_section_marker = regex_section_marker().is_match(&name);
                if name.len() >= MIN_ENTITY_CHARS && !should_skip && !is_section_marker {
                    entities.push(ExtractedEntity {
                        name,
                        entity_type: t,
                    });
                }
                parts.clear();
            }
        };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        const PT_VERB_FALSE_POSITIVES: &[&str] = &[
            "L\u{00ea}",
            "V\u{00ea}",
            "C\u{00e1}",
            "P\u{00f4}r",
            "Ser",
            "Vir",
            "Ver",
            "Dar",
            "Ler",
            "Ter",
        ];

        let entity_type = match bio_type {
            "DATE" => "date",
            "PER" => {
                if PT_VERB_FALSE_POSITIVES.contains(&token.as_str()) {
                    flush(&mut current_parts, &mut current_type, &mut entities);
                    continue;
                }
                "person"
            }
            "ORG" => {
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    "tool"
                } else {
                    "organization"
                }
            }
            "LOC" => "location",
            other => other,
        };

        if prefix == "B" {
            if token.starts_with("##") {
                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
                continue;
            }
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type.to_string());
        } else if prefix == "I" && current_type.is_some() {
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}
774
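/// Test-only baseline builder: pairs each entity with up to
/// `TOP_K_RELATIONS` following entities in extraction order, deduplicated
/// and capped at `max_relationships_per_memory()`. Kept under `#[cfg(test)]`
/// so the cap and dedup behaviour stay covered by unit tests.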
#[cfg(test)]
fn build_relationships(entities: &[NewEntity]) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let n = entities.len().min(MAX_ENTS);
    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();

    let mut hit_cap = false;
    'outer: for i in 0..n {
        if rels.len() >= max_rels {
            hit_cap = true;
            break;
        }

        let mut for_entity = 0usize;
        for j in (i + 1)..n {
            if for_entity >= TOP_K_RELATIONS {
                break;
            }
            if rels.len() >= max_rels {
                hit_cap = true;
                break 'outer;
            }

            let src = &entities[i].name;
            let tgt = &entities[j].name;
            let key = (src.clone(), tgt.clone());

            if seen.contains(&key) {
                continue;
            }
            seen.insert(key);

            rels.push(NewRelationship {
                source: src.clone(),
                target: tgt.clone(),
                relation: DEFAULT_RELATION.to_string(),
                strength: 0.5,
                description: None,
            });
            for_entity += 1;
        }
    }

    if hit_cap {
        tracing::warn!(
            "relationships truncated to {max_rels} (with {n} entities, up to ~{} pairings per entity were possible)",
            n.saturating_sub(1)
        );
    }

    (rels, hit_cap)
}
844
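/// Production relationship builder: links two entities only when their names
/// co-occur (case-insensitively) within the same sentence. Sentences are
/// split naively on `.`, `!`, `?`, and newlines; pairs are deduplicated by
/// lowercased name pair and capped at `max_relationships_per_memory()`.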
fn build_relationships_by_sentence_cooccurrence(
    body: &str,
    entities: &[NewEntity],
) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let lower_names: Vec<(usize, String)> = entities
        .iter()
        .take(MAX_ENTS)
        .enumerate()
        .map(|(i, e)| (i, e.name.to_lowercase()))
        .collect();

    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();
    let mut hit_cap = false;

    for sentence in body.split(['.', '!', '?', '\n']) {
        if sentence.trim().is_empty() {
            continue;
        }
        let lower_sentence = sentence.to_lowercase();
        let present: Vec<usize> = lower_names
            .iter()
            .filter(|(_, name)| !name.is_empty() && lower_sentence.contains(name.as_str()))
            .map(|(i, _)| *i)
            .collect();

        if present.len() < 2 {
            continue;
        }

        for i in 0..present.len() {
            for j in (i + 1)..present.len() {
                if rels.len() >= max_rels {
                    hit_cap = true;
                    tracing::warn!(
                        "relationships truncated to {max_rels} during sentence-level pairing"
                    );
                    return (rels, hit_cap);
                }
                let a = &entities[present[i]];
                let b = &entities[present[j]];
                let key = (a.name.to_lowercase(), b.name.to_lowercase());
                if seen.insert(key) {
                    rels.push(NewRelationship {
                        source: a.name.clone(),
                        target: b.name.clone(),
                        relation: DEFAULT_RELATION.to_string(),
                        strength: 0.5,
                        description: None,
                    });
                }
            }
        }
    }

    (rels, hit_cap)
}
917
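/// Runs NER over the whole body with overlapping windows of `MAX_SEQ_LEN`
/// (512) tokens advanced by `STRIDE` (256), so consecutive windows overlap by
/// 256 tokens; a 900-token body, for example, yields windows 0..512, 256..768,
/// and 512..900. Windows are sorted by length, batched through
/// `predict_batch`, and on batch failure each window is retried individually.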
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("loading NER tokenizer: {e}"))?;

    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("encoding NER input: {e}"))?;

    let mut all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let mut all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    let max_tokens = crate::constants::extraction_max_tokens();
    if all_ids.len() > max_tokens {
        tracing::warn!(
            target: "extraction",
            original_tokens = all_ids.len(),
            capped_tokens = max_tokens,
            "NER input truncated to cap; later body region will be skipped by NER (regex prefilter still covers full body)"
        );
        all_ids.truncate(max_tokens);
        all_tokens.truncate(max_tokens);
    }

    let mut windows: Vec<(Vec<u32>, Vec<String>)> = Vec::new();
    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        windows.push((
            all_ids[start..end].to_vec(),
            all_tokens[start..end].to_vec(),
        ));
        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    windows.sort_by_key(|(ids, _)| ids.len());

    let batch_size = crate::constants::ner_batch_size();
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    for chunk in windows.chunks(batch_size) {
        match model.predict_batch(chunk) {
            Ok(batch_labels) => {
                for (labels, (_, tokens)) in batch_labels.iter().zip(chunk.iter()) {
                    for ent in iob_to_entities(tokens, labels) {
                        if seen.insert(ent.name.clone()) {
                            entities.push(ent);
                        }
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "batch NER failed (chunk of {} windows): {e:#}; falling back to single-window",
                    chunk.len()
                );
                for (ids, tokens) in chunk {
                    let mask = vec![1u32; ids.len()];
                    match model.predict(ids, &mask) {
                        Ok(labels) => {
                            for ent in iob_to_entities(tokens, &labels) {
                                if seen.insert(ent.name.clone()) {
                                    entities.push(ent);
                                }
                            }
                        }
                        Err(e2) => {
                            tracing::warn!("NER window fallback also failed: {e2:#}");
                        }
                    }
                }
            }
        }
    }

    Ok(entities)
}
1022
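/// Re-attaches short version suffixes that tokenization tends to split off:
/// if the body contains "GPT-5" or "GPT-4o" and the entity is "GPT", the
/// entity is renamed to include the suffix. Suffixes longer than 7
/// characters are treated as prose and ignored.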
fn extend_with_numeric_suffix(entities: Vec<ExtractedEntity>, body: &str) -> Vec<ExtractedEntity> {
    static SUFFIX_RE: OnceLock<Regex> = OnceLock::new();
    let suffix_re = SUFFIX_RE.get_or_init(|| {
        Regex::new(r"^([\-\s]+\d+(?:\.\d+)?[a-z]?)")
            .expect("compile-time validated numeric suffix regex literal")
    });

    entities
        .into_iter()
        .map(|ent| {
            if let Some(pos) = body.find(&ent.name) {
                let after_pos = pos + ent.name.len();
                if after_pos < body.len() {
                    let after = &body[after_pos..];
                    if let Some(m) = suffix_re.find(after) {
                        let suffix = m.as_str();
                        if suffix.len() <= 7 {
                            let extended = format!("{}{}", ent.name, suffix);
                            return ExtractedEntity {
                                name: extended,
                                entity_type: ent.entity_type,
                            };
                        }
                    }
                }
            }
            ent
        })
        .collect()
}
1064
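/// Adds versioned model names found in the body ("Claude 4 Sonnet",
/// "Mixtral 8x7B", "GPT-4o") as `concept` entities when no existing entity
/// already has that name (compared case-insensitively). Matches longer than
/// 24 characters are rejected as likely prose, and the `MAX_ENTS` cap is
/// still honoured.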
fn augment_versioned_model_names(
    entities: Vec<ExtractedEntity>,
    body: &str,
) -> Vec<ExtractedEntity> {
    static VERSIONED_MODEL_RE: OnceLock<Regex> = OnceLock::new();
    let model_re = VERSIONED_MODEL_RE.get_or_init(|| {
        Regex::new(
            r"\b([A-Z][A-Za-z]{2,15})[\s\-]+(\d+(?:\.\d+)?(?:[a-z]|x\d+[A-Za-z]?)?)(?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))?\b",
        )
        .expect("compile-time validated versioned model regex literal")
    });

    let mut existing_lc: std::collections::HashSet<String> =
        entities.iter().map(|ent| ent.name.to_lowercase()).collect();
    let mut result = entities;

    for caps in model_re.captures_iter(body) {
        let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
        if full_match.is_empty() || full_match.len() > 24 {
            continue;
        }
        let normalized_lc = full_match.to_lowercase();
        if existing_lc.contains(&normalized_lc) {
            continue;
        }
        if result.len() >= MAX_ENTS {
            break;
        }
        existing_lc.insert(normalized_lc);
        result.push(ExtractedEntity {
            name: full_match.to_string(),
            entity_type: "concept".to_string(),
        });
    }

    result
}
1131
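/// Merges regex and NER candidates. Two candidates collide when they share an
/// entity type and one NFKC-normalized, lowercased name contains the other
/// ("Open" vs "OpenAI"); the longer name wins. The containment scan is
/// quadratic but bounded by the `MAX_ENTS` cap of 30.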
fn merge_and_deduplicate(
    regex_ents: Vec<ExtractedEntity>,
    ner_ents: Vec<ExtractedEntity>,
) -> Vec<ExtractedEntity> {
    let mut by_lc: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    let mut result: Vec<ExtractedEntity> = Vec::new();
    let mut truncated = false;

    let total_input = regex_ents.len() + ner_ents.len();
    for ent in regex_ents.into_iter().chain(ner_ents) {
        let name_lc = ent.name.nfkc().collect::<String>().to_lowercase();
        let key = format!("{}\0{}", ent.entity_type, name_lc);

        let mut collision_idx: Option<usize> = None;
        for (existing_key, idx) in &by_lc {
            let type_prefix = format!("{}\0", ent.entity_type);
            if !existing_key.starts_with(&type_prefix) {
                continue;
            }
            let existing_name_lc = &existing_key[type_prefix.len()..];
            if existing_name_lc == name_lc
                || existing_name_lc.contains(name_lc.as_str())
                || name_lc.contains(existing_name_lc)
            {
                collision_idx = Some(*idx);
                break;
            }
        }
        match collision_idx {
            Some(idx) => {
                if ent.name.len() > result[idx].name.len() {
                    let old_name_lc = result[idx].name.nfkc().collect::<String>().to_lowercase();
                    let old_key = format!("{}\0{}", result[idx].entity_type, old_name_lc);
                    by_lc.remove(&old_key);
                    result[idx] = ent;
                    by_lc.insert(key, idx);
                }
            }
            None => {
                by_lc.insert(key, result.len());
                result.push(ent);
            }
        }
        if result.len() >= MAX_ENTS {
            truncated = true;
            break;
        }
    }

    if truncated {
        tracing::warn!(
            "extraction truncated at {MAX_ENTS} entities (input had {total_input} candidates before deduplication)"
        );
    }

    result
}

fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
    extracted
        .into_iter()
        .map(|e| NewEntity {
            name: e.name,
            entity_type: e.entity_type,
            description: None,
        })
        .collect()
}
1225
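/// Full extraction pipeline: regex prefilter, then BERT NER when the model
/// loads (degrading gracefully to regex-only otherwise), followed by
/// merge/dedup, numeric-suffix extension, versioned-model augmentation, a
/// final section-marker filter, and sentence co-occurrence relationships.
///
/// ```ignore
/// // Sketch of the degraded path when no model is available on disk:
/// let result = extract_graph_auto("contact: dev@acme.io uses MAX_RETRY=3", &paths)?;
/// assert_eq!(result.extraction_method, "regex-only");
/// ```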
pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
    let regex_entities = apply_regex_prefilter(body);

    let mut bert_used = false;
    let ner_entities = match get_or_init_model(paths) {
        Some(model) => match run_ner_sliding_window(model, body, paths) {
            Ok(ents) => {
                bert_used = true;
                ents
            }
            Err(e) => {
                tracing::warn!("NER failed, falling back to regex-only extraction: {e:#}");
                Vec::new()
            }
        },
        None => Vec::new(),
    };

    let merged = merge_and_deduplicate(regex_entities, ner_entities);
    let extended = extend_with_numeric_suffix(merged, body);
    let with_models = augment_versioned_model_names(extended, body);
    let with_models: Vec<ExtractedEntity> = with_models
        .into_iter()
        .filter(|e| !regex_section_marker().is_match(&e.name))
        .collect();
    let entities = to_new_entities(with_models);
    let (relationships, relationships_truncated) =
        build_relationships_by_sentence_cooccurrence(body, &entities);

    let extraction_method = if bert_used {
        "bert+regex-batch".to_string()
    } else {
        "regex-only".to_string()
    };

    let urls = extract_urls(body);

    Ok(ExtractionResult {
        entities,
        relationships,
        relationships_truncated,
        extraction_method,
        urls,
    })
}

pub struct RegexExtractor;

impl Extractor for RegexExtractor {
    fn extract(&self, body: &str) -> Result<ExtractionResult> {
        let regex_entities = apply_regex_prefilter(body);
        let entities = to_new_entities(regex_entities);
        let (relationships, relationships_truncated) =
            build_relationships_by_sentence_cooccurrence(body, &entities);
        let urls = extract_urls(body);
        Ok(ExtractionResult {
            entities,
            relationships,
            relationships_truncated,
            extraction_method: "regex-only".to_string(),
            urls,
        })
    }
}
1297
#[cfg(test)]
mod tests {
    use super::*;

    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        AppPaths {
            db: PathBuf::from("/tmp/test.sqlite"),
            models: PathBuf::from("/tmp/test_models"),
        }
    }

    #[test]
    fn regex_email_captures_address() {
        let ents = apply_regex_prefilter("contact: someone@company.com for more info");
        assert!(ents
            .iter()
            .any(|e| e.name == "someone@company.com" && e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_filters_pt_rule_word() {
        let ents = apply_regex_prefilter("NUNCA do this. PROIBIDO use X. DEVE follow Y.");
        assert!(
            !ents.iter().any(|e| e.name == "NUNCA"),
            "NUNCA must be filtered as a stopword"
        );
        assert!(
            !ents.iter().any(|e| e.name == "PROIBIDO"),
            "PROIBIDO must be filtered"
        );
        assert!(
            !ents.iter().any(|e| e.name == "DEVE"),
            "DEVE must be filtered"
        );
    }

    #[test]
    fn regex_all_caps_accepts_underscored_constant() {
        let ents = apply_regex_prefilter("configure MAX_RETRY=3 and API_TIMEOUT=30");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
    }

    #[test]
    fn regex_all_caps_accepts_domain_acronym() {
        let ents = apply_regex_prefilter("OPENAI launched GPT-5 with NVIDIA H100");
        assert!(ents.iter().any(|e| e.name == "OPENAI"));
        assert!(ents.iter().any(|e| e.name == "NVIDIA"));
    }

    #[test]
    fn regex_url_does_not_appear_in_apply_regex_prefilter() {
        let ents = apply_regex_prefilter("see https://docs.rs/crate for details");
        assert!(
            !ents.iter().any(|e| e.name.starts_with("https://")),
            "URLs must not appear as entities after the P0-2 split"
        );
    }

    #[test]
    fn extract_urls_captures_https() {
        let urls = extract_urls("see https://docs.rs/crate for details");
        assert_eq!(urls.len(), 1);
        assert_eq!(urls[0].url, "https://docs.rs/crate");
        assert!(urls[0].offset > 0);
    }
1370
    #[test]
    fn extract_urls_trim_sufixo_pontuacao() {
        let urls = extract_urls("link: https://example.com/path. fim");
        assert!(!urls.is_empty());
        assert!(
            !urls[0].url.ends_with('.'),
            "trailing period must be trimmed"
        );
    }

    #[test]
    fn extract_urls_dedupes_repeated() {
        let body = "https://example.com referenciado aqui e depois aqui https://example.com";
        let urls = extract_urls(body);
        assert_eq!(urls.len(), 1, "repeated URLs must be deduplicated");
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
        assert!(ents.iter().any(|e| e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignores_short_words() {
        let ents = apply_regex_prefilter("use AI em seu projeto");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI is only 2 chars and must be ignored"
        );
    }
1409
    #[test]
    fn iob_decodes_per_to_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, "person");
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_strip_subword_b_prefix() {
        let tokens = vec!["Open".to_string(), "##AI".to_string()];
        let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
            "should merge ##AI or discard"
        );
    }

    #[test]
    fn iob_subword_orphan_discards() {
        let tokens = vec!["##AI".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "orphan subword without an active entity must be discarded"
        );
    }

    #[test]
    fn iob_maps_date_to_date_v1025() {
        let tokens = vec!["January".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE must be emitted as an entity in v1.0.25"
        );
        assert_eq!(ents[0].entity_type, "date");
    }

    #[test]
    fn iob_maps_org_to_organization_v1025() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "organization");
    }

    #[test]
    fn iob_maps_org_sdk_to_tool() {
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "tool");
    }

    #[test]
    fn iob_maps_loc_to_location_v1025() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "location");
    }
1488
    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, truncated) = build_relationships(&entities);
        let max_rels = crate::constants::max_relationships_per_memory();
        assert!(rels.len() <= max_rels, "must respect max_rels={max_rels}");
        if rels.len() == max_rels {
            assert!(truncated, "truncated must be true when the cap is reached");
        }
    }

    #[test]
    fn build_relationships_without_duplicates() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, _truncated) = build_relationships(&entities);
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "duplicate pair found");
        }
    }
1523
    #[test]
    fn merge_dedupes_by_lowercase_name() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(
            merged.len(),
            1,
            "rust and Rust with the same type are the same entity"
        );
    }

    #[test]
    fn regex_extractor_implements_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_returns_ok_without_model() {
        let paths = make_paths();
        let body = "contato: teste@exemplo.com com MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
    }

    #[test]
    fn stopwords_filter_v1024_terms() {
        let body = "ACEITE ACK ACL BORDA CHECKLIST COMPLETED CONFIRME \
                    DEVEMOS DONE FIXED NEGUE PENDING PLAN PODEMOS RECUSE TOKEN VAMOS";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for word in &[
            "ACEITE",
            "ACK",
            "ACL",
            "BORDA",
            "CHECKLIST",
            "COMPLETED",
            "CONFIRME",
            "DEVEMOS",
            "DONE",
            "FIXED",
            "NEGUE",
            "PENDING",
            "PLAN",
            "PODEMOS",
            "RECUSE",
            "TOKEN",
            "VAMOS",
        ] {
            assert!(
                !names.contains(word),
                "v1.0.24 stopword {word} should be filtered but was found in entities"
            );
        }
    }

    #[test]
    fn dedup_normalizes_unicode_combining_marks() {
        let nfc = vec![ExtractedEntity {
            name: "Caf\u{e9}".to_string(),
            entity_type: "concept".to_string(),
        }];
        let nfd_name = "Cafe\u{301}".to_string();
        let nfd = vec![ExtractedEntity {
            name: nfd_name,
            entity_type: "concept".to_string(),
        }];
        let merged = merge_and_deduplicate(nfc, nfd);
        assert_eq!(
            merged.len(),
            1,
            "NFC 'Caf\\u{{e9}}' and NFD 'Cafe\\u{{301}}' must deduplicate to 1 entity after NFKC normalization"
        );
    }
1621
    #[test]
    fn predict_batch_output_count_matches_input() {
        let w1_ids: Vec<u32> = vec![101, 100, 102];
        let w1_tok: Vec<String> = vec!["[CLS]".into(), "hello".into(), "[SEP]".into()];
        let w2_ids: Vec<u32> = vec![101, 100, 200, 300, 102];
        let w2_tok: Vec<String> = vec![
            "[CLS]".into(),
            "world".into(),
            "foo".into(),
            "bar".into(),
            "[SEP]".into(),
        ];
        let windows: Vec<(Vec<u32>, Vec<String>)> =
            vec![(w1_ids.clone(), w1_tok), (w2_ids.clone(), w2_tok)];

        let device = Device::Cpu;
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap();
        assert_eq!(max_len, 5, "max_len must be 5");

        let mut padded_ids: Vec<Tensor> = Vec::new();
        for (ids, _) in &windows {
            let len = ids.len();
            let pad_right = max_len - len;
            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &device).unwrap();
            let t = t.pad_with_zeros(0, 0, pad_right).unwrap();
            assert_eq!(
                t.dims(),
                &[max_len],
                "each window must have shape (max_len,) after padding"
            );
            padded_ids.push(t);
        }

        let stacked = Tensor::stack(&padded_ids, 0).unwrap();
        assert_eq!(
            stacked.dims(),
            &[2, max_len],
            "stack must produce (batch_size=2, max_len=5)"
        );

        let fake_logits_data: Vec<f32> = vec![0.0f32; 2 * max_len * 9];
        let fake_logits =
            Tensor::from_vec(fake_logits_data, (2usize, max_len, 9usize), &device).unwrap();
        for (i, (ids, _)) in windows.iter().enumerate() {
            let real_len = ids.len();
            let example = fake_logits.get(i).unwrap();
            let sliced = example.narrow(0, 0, real_len).unwrap();
            assert_eq!(
                sliced.dims(),
                &[real_len, 9],
                "narrow must keep only the {real_len} real tokens"
            );
        }
    }
1688
    #[test]
    fn predict_batch_empty_windows_returns_empty() {
        let windows: Vec<(Vec<u32>, Vec<String>)> = vec![];
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        assert_eq!(max_len, 0, "zero windows → max_len 0");
        let result: Vec<Vec<String>> = if max_len == 0 {
            Vec::new()
        } else {
            unreachable!()
        };
        assert!(result.is_empty());
    }
1705
    #[test]
    fn ner_batch_size_default_is_8() {
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
        assert_eq!(crate::constants::ner_batch_size(), 8);
    }

    #[test]
    fn ner_batch_size_env_override_clamped() {
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "64");
        assert_eq!(crate::constants::ner_batch_size(), 32, "must clamp to 32");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "0");
        assert_eq!(crate::constants::ner_batch_size(), 1, "must clamp to 1");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "4");
        assert_eq!(
            crate::constants::ner_batch_size(),
            4,
            "valid value preserved"
        );

        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
    }
1732
    #[test]
    fn extraction_method_regex_only_unchanged() {
        let result = RegexExtractor.extract("contact: dev@acme.io").unwrap();
        assert_eq!(
            result.extraction_method, "regex-only",
            "RegexExtractor must return regex-only"
        );
    }

    #[test]
    fn extend_suffix_pure_numeric_unchanged() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-5 in the project");
        assert_eq!(
            result[0].name, "GPT-5",
            "purely numeric suffix must be extended"
        );
    }

    #[test]
    fn extend_suffix_alphanumeric_letter_after_digit() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-4o for advanced tasks");
        assert_eq!(result[0].name, "GPT-4o", "suffix '4o' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_b_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Llama".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "Llama-5b open-weight model");
        assert_eq!(result[0].name, "Llama-5b", "suffix '5b' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_x_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Mistral".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "testing Mistral-8x in production");
        assert_eq!(result[0].name, "Mistral-8x", "suffix '8x' must be accepted");
    }
1792
    #[test]
    fn augment_versioned_gpt4o() {
        let result = augment_versioned_model_names(vec![], "using GPT-4o for analysis");
        assert!(
            result.iter().any(|e| e.name == "GPT-4o"),
            "GPT-4o must be captured by augment, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_claude_4_sonnet() {
        let result =
            augment_versioned_model_names(vec![], "best model: Claude 4 Sonnet released today");
        assert!(
            result.iter().any(|e| e.name == "Claude 4 Sonnet"),
            "Claude 4 Sonnet must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_llama_3_pro() {
        let result =
            augment_versioned_model_names(vec![], "fine-tuning com Llama 3 Pro localmente");
        assert!(
            result.iter().any(|e| e.name == "Llama 3 Pro"),
            "Llama 3 Pro must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_mixtral_8x7b() {
        let result =
            augment_versioned_model_names(vec![], "executando Mixtral 8x7B no servidor local");
        assert!(
            result.iter().any(|e| e.name == "Mixtral 8x7B"),
            "Mixtral 8x7B must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_does_not_duplicate_existing() {
        let existing = vec![ExtractedEntity {
            name: "Claude 4".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = augment_versioned_model_names(existing, "using Claude 4 in the project");
        let count = result.iter().filter(|e| e.name == "Claude 4").count();
        assert_eq!(count, 1, "Claude 4 must not be duplicated");
    }
1853
    #[test]
    fn stopwords_filter_url_jwt_api_v1025() {
        let body = "We use URL, JWT, and API REST in our LLM-powered CLI via HTTP/HTTPS and UI.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for blocked in &[
            "URL", "JWT", "API", "REST", "LLM", "CLI", "HTTP", "HTTPS", "UI",
        ] {
            assert!(
                !names.contains(blocked),
                "v1.0.25 stopword {blocked} leaked as entity; found names: {names:?}"
            );
        }
    }

    #[test]
    fn section_markers_etapa_fase_filtered_v1025() {
        let body = "Etapa 3 do plano: implementar Fase 1 da Migra\u{e7}\u{e3}o.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Etapa") || e.name.contains("Fase")),
            "section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn section_markers_passo_secao_filtered_v1025() {
        let body = "Siga Passo 2 conforme Se\u{e7}\u{e3}o 3 do manual.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Passo") || e.name.contains("Se\u{e7}\u{e3}o")),
            "Passo/Se\\u{{e7}}\\u{{e3}}o section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn brand_camelcase_extracted_as_organization_v1025() {
        let body = "OpenAI launched GPT-4 and PostgreSQL added pgvector.";
        let ents = apply_regex_prefilter(body);
        let openai = ents.iter().find(|e| e.name == "OpenAI");
        assert!(
            openai.is_some(),
            "OpenAI must be extracted by CamelCase brand regex; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
        assert_eq!(
            openai.unwrap().entity_type,
            "organization",
            "brand CamelCase must map to organization (V008)"
        );
    }

    #[test]
    fn brand_postgresql_extracted_as_organization_v1025() {
        let body = "migrating from MySQL to PostgreSQL for better performance.";
        let ents = apply_regex_prefilter(body);
        assert!(
            ents.iter()
                .any(|e| e.name == "PostgreSQL" && e.entity_type == "organization"),
            "PostgreSQL must be extracted as organization; entities: {:?}",
            ents.iter()
                .map(|e| (&e.name, &e.entity_type))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn iob_org_maps_to_organization_not_project_v1025() {
        let tokens = vec!["Microsoft".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type, "organization",
            "B-ORG must map to organization (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_loc_maps_to_location_not_concept_v1025() {
        let tokens = vec!["S\u{e3}o".to_string(), "Paulo".to_string()];
        let labels = vec!["B-LOC".to_string(), "I-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type, "location",
            "B-LOC must map to location (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_date_maps_to_date_not_discarded_v1025() {
        let tokens = vec!["2025".to_string(), "-".to_string(), "12".to_string()];
        let labels = vec![
            "B-DATE".to_string(),
            "I-DATE".to_string(),
            "I-DATE".to_string(),
        ];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE entity must be emitted (V008); entities: {ents:?}"
        );
        assert_eq!(ents[0].entity_type, "date");
    }

    #[test]
    fn pt_verb_le_filtered_as_per_v1025() {
        let tokens = vec!["L\u{ea}".to_string(), "o".to_string(), "livro".to_string()];
        let labels = vec!["B-PER".to_string(), "O".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents
                .iter()
                .any(|e| e.name == "L\u{ea}" && e.entity_type == "person"),
            "PT verb 'L\\u{{ea}}' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn pt_verb_ver_filtered_as_per_v1025() {
        let tokens = vec!["Ver".to_string()];
        let labels = vec!["B-PER".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "PT verb 'Ver' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    fn entity(name: &str, entity_type: &str) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
        }
    }

    #[test]
    fn merge_resolves_sonne_vs_sonnet_keeps_longest_v1025() {
        let regex = vec![entity("Sonne", "concept")];
        let ner = vec![entity("Sonnet", "concept")];
        let result = merge_and_deduplicate(regex, ner);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "Sonnet");
    }

    #[test]
    fn merge_resolves_open_vs_openai_keeps_longest_v1025() {
        let regex = vec![
            entity("Open", "organization"),
            entity("OpenAI", "organization"),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "OpenAI");
    }

    #[test]
    fn merge_keeps_both_when_no_containment_v1025() {
        let regex = vec![entity("Alice", "person"), entity("Bob", "person")];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 2, "expected 2 entities, got: {result:?}");
    }

    #[test]
    fn merge_respects_entity_type_boundary_v1025() {
        let regex = vec![entity("Apple", "organization"), entity("Apple", "concept")];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            2,
            "expected 2 entities (different types), got: {result:?}"
        );
    }

    #[test]
    fn merge_case_insensitive_dedup_v1025() {
        let regex = vec![
            entity("OpenAI", "organization"),
            entity("openai", "organization"),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            1,
            "expected 1 entity after case-insensitive dedup, got: {result:?}"
        );
    }

    #[test]
    fn iob_section_marker_etapa_filtered_v1025() {
        let tokens = vec!["Etapa".to_string(), "3".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Etapa")),
            "section marker 'Etapa 3' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn iob_section_marker_fase_filtered_v1025() {
        let tokens = vec!["Fase".to_string(), "1".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Fase")),
            "section marker 'Fase 1' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn extract_graph_auto_handles_large_body_under_30s() {
        let body = "x ".repeat(40_000);
        let paths = make_paths();
        let start = std::time::Instant::now();
        let result = extract_graph_auto(&body, &paths).expect("extraction must not error");
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 30,
            "extract_graph_auto took {}s for 80 KB body (cap should keep it well under 30s)",
            elapsed.as_secs()
        );
        let _ = result.entities;
    }

    #[test]
    fn pt_uppercase_stopwords_filtered_v1031() {
        let body = "Para o ADAPTER funcionar com PROJETO em modo PASSIVA, devemos usar \
                    SOMENTE LEITURA conforme a REGRA OBRIGATORIA do EXEMPLO DEFAULT.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<String> = ents.iter().map(|e| e.name.to_uppercase()).collect();
        for stop in &[
            "ADAPTER",
            "PROJETO",
            "PASSIVA",
            "SOMENTE",
            "LEITURA",
            "REGRA",
            "OBRIGATORIA",
            "EXEMPLO",
            "DEFAULT",
        ] {
            assert!(
                !names.contains(&stop.to_string()),
                "v1.0.31 A11 stoplist failed: {stop} leaked as entity; got names: {names:?}"
            );
        }
    }

    #[test]
    fn pt_underscored_identifier_preserved_v1031() {
        let ents = apply_regex_prefilter("configure FLOWAIPER_API_KEY=foo and MAX_TIMEOUT=30");
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"FLOWAIPER_API_KEY"));
        assert!(names.contains(&"MAX_TIMEOUT"));
    }

    #[test]
    fn build_relationships_by_sentence_only_links_co_occurring_entities() {
        let body = "Alice met Bob at the conference. Carol works alone in another room.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Carol".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
        ];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(!truncated);
        assert_eq!(
            rels.len(),
            1,
            "only Alice/Bob should pair (same sentence); Carol is isolated"
        );
        let pair = (rels[0].source.as_str(), rels[0].target.as_str());
        assert!(
            matches!(pair, ("Alice", "Bob") | ("Bob", "Alice")),
            "unexpected pair {pair:?}"
        );
    }

    #[test]
    fn build_relationships_by_sentence_returns_empty_for_single_entity() {
        let body = "Alice is here.";
        let entities = vec![NewEntity {
            name: "Alice".to_string(),
            entity_type: "person".to_string(),
            description: None,
        }];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(rels.is_empty());
        assert!(!truncated);
    }

    #[test]
    fn build_relationships_by_sentence_dedupes_pairs_across_sentences() {
        let body = "Alice met Bob. Bob saw Alice again.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
        ];
        let (rels, _) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert_eq!(
            rels.len(),
            1,
            "Alice/Bob pair must be emitted only once even when co-occurring in multiple sentences"
        );
    }

    #[test]
    fn extraction_max_tokens_default_is_5000() {
        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
        assert_eq!(crate::constants::extraction_max_tokens(), 5_000);
    }

    #[test]
    fn extraction_max_tokens_env_override_clamped() {
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value below 512 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value above 100_000 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "8000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            8_000,
            "valid value must be honoured"
        );

        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
    }
}