1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::OnceLock;
9
10use anyhow::{Context, Result};
11use candle_core::{DType, Device, Tensor};
12use candle_nn::{Linear, Module, VarBuilder};
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use regex::Regex;
15use serde::Deserialize;
16use unicode_normalization::UnicodeNormalization;
17
18use crate::paths::AppPaths;
19use crate::storage::entities::{NewEntity, NewRelationship};
20
/// HF Hub repository of the multilingual BERT NER model.
const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
/// BERT hard limit on tokens per forward pass.
const MAX_SEQ_LEN: usize = 512;
/// Sliding-window step; consecutive windows overlap by MAX_SEQ_LEN - STRIDE tokens.
const STRIDE: usize = 256;
/// Hard cap on entities kept per extraction.
const MAX_ENTS: usize = 30;
/// Max outgoing relationships per entity (only used by the test-only
/// `build_relationships` pairing strategy).
#[cfg(test)]
const TOP_K_RELATIONS: usize = 5;
/// Relation label assigned to every co-occurrence edge.
const DEFAULT_RELATION: &str = "mentions";
/// Names with `len()` (bytes) below this are discarded as entities.
const MIN_ENTITY_CHARS: usize = 2;

// Lazily-compiled regexes, shared process-wide (compiled once on first use).
static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
static REGEX_SECTION_MARKER: OnceLock<Regex> = OnceLock::new();
static REGEX_BRAND_CAMEL: OnceLock<Regex> = OnceLock::new();
40
/// Uppercase tokens (mixed Portuguese/English) that the ALL-CAPS regex would
/// otherwise promote to entities: rule/emphasis words, doc-structure words and
/// generic tech acronyms. Membership is tested with a linear `contains`, so
/// ordering is cosmetic (kept roughly alphabetical).
const ALL_CAPS_STOPWORDS: &[&str] = &[
    "ACEITE",
    "ACK",
    "ACL",
    "ACRESCENTADO",
    "ADAPTER",
    "ADICIONADA",
    "ADICIONADAS",
    "ADICIONADO",
    "ADICIONADOS",
    "ADICIONAR",
    "AGENTS",
    "ALL",
    "ALTA",
    "ALWAYS",
    "API",
    "ARTEFATOS",
    "ATIVA",
    "ATIVO",
    "BAIXA",
    "BANCO",
    "BORDA",
    "BLOQUEAR",
    "BUG",
    "CAPÍTULO",
    "CASO",
    "CHECKLIST",
    "CLARO",
    "CLI",
    "COMPLETED",
    "CONFIRMADO",
    "CONFIRMARAM",
    "CONFIRMEI",
    "CONFIRMOU",
    "CONFIRME",
    "CONTRATO",
    "CRÍTICO",
    "CRITICAL",
    "CSV",
    "DEFAULT",
    "DEVE",
    "DEVEMOS",
    "DISCO",
    "DONE",
    "EFEITO",
    "ENTRADA",
    "ERROR",
    "ESCRITA",
    "ESSA",
    "ESSE",
    "ESSENCIAL",
    "ESTA",
    "ESTE",
    "ETAPA",
    "EVITAR",
    "EXEMPLO",
    "EXPANDIR",
    "EXPOR",
    "FALHA",
    "FASE",
    "FIXED",
    "FIXME",
    "FORBIDDEN",
    "HACK",
    "HEARTBEAT",
    "HTTP",
    "HTTPS",
    "INATIVO",
    "JAMAIS",
    "JSON",
    "JWT",
    "LEITURA",
    "LLM",
    "MUST",
    "NEGUE",
    "NEVER",
    "NOTE",
    "NUNCA",
    "OBRIGATORIA",
    "OBRIGATÓRIO",
    "PADRÃO",
    "PASSIVA",
    "PASSO",
    "PENDING",
    "PLAN",
    "PODEMOS",
    "PROIBIDO",
    "PROJETO",
    "RECUSE",
    "REGRA",
    "REGRAS",
    "REQUIRED",
    "REQUISITO",
    "REST",
    "SEÇÃO",
    "SEMPRE",
    "SHALL",
    "SHOULD",
    "SOMENTE",
    "SOUL",
    "TODAS",
    "TODO",
    "TODOS",
    "TOKEN",
    "TOOLS",
    "TSV",
    "UI",
    "URL",
    "USAR",
    "VALIDAR",
    "VAMOS",
    "VOCÊ",
    "WARNING",
    "XML",
    "YAML",
];

/// HTTP verbs are filtered separately so request examples ("GET /users")
/// don't become entities.
const HTTP_METHODS: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE",
];
180
181fn is_filtered_all_caps(token: &str) -> bool {
182 let is_identifier = token.contains('_');
184 if is_identifier {
185 return false;
186 }
187 ALL_CAPS_STOPWORDS.contains(&token) || HTTP_METHODS.contains(&token)
188}
189
190fn regex_email() -> &'static Regex {
191 REGEX_EMAIL.get_or_init(|| {
192 Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
193 .expect("compile-time validated email regex literal")
194 })
195}
196
197fn regex_url() -> &'static Regex {
198 REGEX_URL.get_or_init(|| {
199 Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#)
200 .expect("compile-time validated URL regex literal")
201 })
202}
203
204fn regex_uuid() -> &'static Regex {
205 REGEX_UUID.get_or_init(|| {
206 Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
207 .expect("compile-time validated UUID regex literal")
208 })
209}
210
211fn regex_all_caps() -> &'static Regex {
212 REGEX_ALL_CAPS.get_or_init(|| {
213 Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b")
214 .expect("compile-time validated all-caps regex literal")
215 })
216}
217
218fn regex_section_marker() -> &'static Regex {
219 REGEX_SECTION_MARKER.get_or_init(|| {
220 Regex::new("\\b(?:Etapa|Fase|Passo|Camada|Se\u{00e7}\u{00e3}o|Cap\u{00ed}tulo)\\s+\\d+\\b")
227 .expect("compile-time validated section marker regex literal")
228 })
229}
230
231fn regex_brand_camel() -> &'static Regex {
232 REGEX_BRAND_CAMEL.get_or_init(|| {
233 Regex::new(r"\b[A-Z][a-z]+[A-Z][A-Za-z]+\b")
236 .expect("compile-time validated CamelCase brand regex literal")
237 })
238}
239
/// An entity found in text, before conversion to a storage `NewEntity`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    pub name: String,
    pub entity_type: String, // e.g. "person", "organization", "concept", "tool"
}
245
/// A URL found in the body, tracked separately from entities.
#[derive(Debug, Clone)]
pub struct ExtractedUrl {
    pub url: String,
    /// Byte offset of the match start within the original body.
    pub offset: usize,
}
253
/// Output of one extraction pass over a memory body.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    pub entities: Vec<NewEntity>,
    pub relationships: Vec<NewRelationship>,
    /// True when the relationship cap was hit and pairs were dropped.
    pub relationships_truncated: bool,
    /// Pipeline that produced the result ("bert+regex-batch" or "regex-only").
    pub extraction_method: String,
    /// URLs found in the body (extracted independently of entities).
    pub urls: Vec<ExtractedUrl>,
}
267
/// Strategy interface for entity/relationship extraction from a text body.
pub trait Extractor: Send + Sync {
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}
271
/// Minimal view of the HF `config.json` — only the fields this module reads.
#[derive(Deserialize)]
struct ModelConfig {
    /// Label map keyed by stringified label index (HF convention).
    #[serde(default)]
    id2label: HashMap<String, String>,
    hidden_size: usize,
}
278
/// BERT token-classification model: encoder plus a linear head, run on CPU.
struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    /// Label id -> IOB tag string (e.g. "B-PER"), parsed from config.json.
    id2label: HashMap<usize, String>,
}
285
286impl BertNerModel {
287 fn load(model_dir: &Path) -> Result<Self> {
288 let config_path = model_dir.join("config.json");
289 let weights_path = model_dir.join("model.safetensors");
290
291 let config_str = std::fs::read_to_string(&config_path)
292 .with_context(|| format!("lendo config.json em {config_path:?}"))?;
293 let model_cfg: ModelConfig =
294 serde_json::from_str(&config_str).context("parseando config.json do modelo NER")?;
295
296 let id2label: HashMap<usize, String> = model_cfg
297 .id2label
298 .into_iter()
299 .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
300 .collect();
301
302 let num_labels = id2label.len().max(9);
303 let hidden_size = model_cfg.hidden_size;
304
305 let bert_config_str = std::fs::read_to_string(&config_path)
306 .with_context(|| format!("relendo config.json para bert em {config_path:?}"))?;
307 let bert_cfg: BertConfig =
308 serde_json::from_str(&bert_config_str).context("parseando BertConfig")?;
309
310 let device = Device::Cpu;
311
312 let vb = unsafe {
320 VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
321 .with_context(|| format!("mapping {weights_path:?}"))?
322 };
323 let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("loading BertModel")?;
324
325 let cls_vb = vb.pp("classifier");
328 let weight = cls_vb
329 .get((num_labels, hidden_size), "weight")
330 .context("carregando classifier.weight do safetensors")?;
331 let bias = cls_vb
332 .get(num_labels, "bias")
333 .context("carregando classifier.bias do safetensors")?;
334 let classifier = Linear::new(weight, Some(bias));
335
336 Ok(Self {
337 bert,
338 classifier,
339 device,
340 id2label,
341 })
342 }
343
344 fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
345 let len = token_ids.len();
346 let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
347 let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
348
349 let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
350 .context("creating tensor input_ids")?;
351 let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
352 .context("creating tensor token_type_ids")?;
353 let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
354 .context("creating tensor attention_mask")?;
355
356 let sequence_output = self
357 .bert
358 .forward(&input_ids, &token_type_ids, Some(&attn_mask))
359 .context("BertModel forward pass")?;
360
361 let logits = self
362 .classifier
363 .forward(&sequence_output)
364 .context("classifier forward pass")?;
365
366 let logits_2d = logits.squeeze(0).context("removing batch dimension")?;
367
368 let num_tokens = logits_2d.dim(0).context("dim(0)")?;
369
370 let mut labels = Vec::with_capacity(num_tokens);
371 for i in 0..num_tokens {
372 let token_logits = logits_2d.get(i).context("get token logits")?;
373 let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
374 let argmax = vec
375 .iter()
376 .enumerate()
377 .max_by(|(_, a), (_, b)| {
378 a.partial_cmp(b)
379 .expect("BERT NER logits invariant: no NaN in classifier output")
380 })
381 .map(|(idx, _)| idx)
382 .unwrap_or(0);
383 let label = self
384 .id2label
385 .get(&argmax)
386 .cloned()
387 .unwrap_or_else(|| "O".to_string());
388 labels.push(label);
389 }
390
391 Ok(labels)
392 }
393
394 fn predict_batch(&self, windows: &[(Vec<u32>, Vec<String>)]) -> Result<Vec<Vec<String>>> {
403 let batch_size = windows.len();
404 let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
405 if max_len == 0 {
406 return Ok(vec![vec![]; batch_size]);
407 }
408
409 let mut padded_ids: Vec<Tensor> = Vec::with_capacity(batch_size);
410 let mut padded_masks: Vec<Tensor> = Vec::with_capacity(batch_size);
411
412 for (ids, _) in windows {
413 let len = ids.len();
414 let pad_right = max_len - len;
415
416 let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
417 let t = Tensor::from_vec(ids_i64, len, &self.device)
419 .context("creating id tensor for batch")?;
420 let t = t
421 .pad_with_zeros(0, 0, pad_right)
422 .context("padding id tensor")?;
423 padded_ids.push(t);
424
425 let mut mask_i64 = vec![1i64; len];
427 mask_i64.extend(vec![0i64; pad_right]);
428 let m = Tensor::from_vec(mask_i64, max_len, &self.device)
429 .context("creating mask tensor for batch")?;
430 padded_masks.push(m);
431 }
432
433 let input_ids = Tensor::stack(&padded_ids, 0).context("stack input_ids")?;
435 let attn_mask = Tensor::stack(&padded_masks, 0).context("stack attn_mask")?;
436 let token_type_ids = Tensor::zeros((batch_size, max_len), DType::I64, &self.device)
437 .context("creating token_type_ids tensor for batch")?;
438
439 let sequence_output = self
441 .bert
442 .forward(&input_ids, &token_type_ids, Some(&attn_mask))
443 .context("BertModel batch forward pass")?;
444 let logits = self
447 .classifier
448 .forward(&sequence_output)
449 .context("forward pass batch classificador")?;
450 let mut results = Vec::with_capacity(batch_size);
453 for (i, (window_ids, _)) in windows.iter().enumerate() {
454 let example_logits = logits.get(i).context("get logits exemplo")?;
455 let real_len = window_ids.len();
457 let example_slice = example_logits
458 .narrow(0, 0, real_len)
459 .context("narrow para tokens reais")?;
460 let logits_2d: Vec<Vec<f32>> = example_slice.to_vec2().context("to_vec2 logits")?;
461
462 let labels: Vec<String> = logits_2d
463 .iter()
464 .map(|token_logits| {
465 let argmax = token_logits
466 .iter()
467 .enumerate()
468 .max_by(|(_, a), (_, b)| {
469 a.partial_cmp(b)
470 .expect("BERT NER logits invariant: no NaN in classifier output")
471 })
472 .map(|(idx, _)| idx)
473 .unwrap_or(0);
474 self.id2label
475 .get(&argmax)
476 .cloned()
477 .unwrap_or_else(|| "O".to_string())
478 })
479 .collect();
480
481 results.push(labels);
482 }
483
484 Ok(results)
485 }
486}
487
/// Lazily-loaded NER model. A failed load caches `None`, so loading is never
/// retried within the process (graceful degradation to regex-only).
static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();
489
490fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
491 NER_MODEL
492 .get_or_init(|| match load_model(paths) {
493 Ok(m) => Some(m),
494 Err(e) => {
495 tracing::warn!("NER model unavailable (graceful degradation): {e:#}");
496 None
497 }
498 })
499 .as_ref()
500}
501
/// Single source of truth for the on-disk NER model cache location.
fn model_dir(paths: &AppPaths) -> PathBuf {
    paths.models.join("bert-multilingual-ner")
}
505
506fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
507 let dir = model_dir(paths);
508 std::fs::create_dir_all(&dir).with_context(|| format!("creating model directory: {dir:?}"))?;
509
510 let weights = dir.join("model.safetensors");
511 let config = dir.join("config.json");
512 let tokenizer = dir.join("tokenizer.json");
513
514 if weights.exists() && config.exists() && tokenizer.exists() {
515 return Ok(dir);
516 }
517
518 tracing::info!("Downloading NER model (first run, ~676 MB)...");
519 crate::output::emit_progress_i18n(
520 "Downloading NER model (first run, ~676 MB)...",
521 crate::i18n::validation::runtime_pt::downloading_ner_model(),
522 );
523
524 let api = huggingface_hub::api::sync::Api::new().context("creating HF Hub client")?;
525 let repo = api.model(MODEL_ID.to_string());
526
527 for (remote, local) in &[
531 ("model.safetensors", "model.safetensors"),
532 ("config.json", "config.json"),
533 ("onnx/tokenizer.json", "tokenizer.json"),
534 ("tokenizer_config.json", "tokenizer_config.json"),
535 ] {
536 let dest = dir.join(local);
537 if !dest.exists() {
538 let src = repo
539 .get(remote)
540 .with_context(|| format!("baixando {remote} do HF Hub"))?;
541 std::fs::copy(&src, &dest).with_context(|| format!("copiando {local} para cache"))?;
542 }
543 }
544
545 Ok(dir)
546}
547
548fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
549 let dir = ensure_model_files(paths)?;
550 BertNerModel::load(&dir)
551}
552
553fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
554 let mut entities = Vec::new();
555 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
556
557 let add = |entities: &mut Vec<ExtractedEntity>,
558 seen: &mut std::collections::HashSet<String>,
559 name: &str,
560 entity_type: &str| {
561 let name = name.trim().to_string();
562 if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
563 entities.push(ExtractedEntity {
564 name,
565 entity_type: entity_type.to_string(),
566 });
567 }
568 };
569
570 let cleaned = regex_section_marker().replace_all(body, " ");
573 let cleaned = cleaned.as_ref();
574
575 for m in regex_email().find_iter(cleaned) {
576 add(&mut entities, &mut seen, m.as_str(), "concept");
578 }
579 for m in regex_uuid().find_iter(cleaned) {
580 add(&mut entities, &mut seen, m.as_str(), "concept");
581 }
582 for m in regex_all_caps().find_iter(cleaned) {
583 let candidate = m.as_str();
584 if !is_filtered_all_caps(candidate) {
586 add(&mut entities, &mut seen, candidate, "concept");
587 }
588 }
589 for m in regex_brand_camel().find_iter(cleaned) {
592 let name = m.as_str();
593 if !ALL_CAPS_STOPWORDS.contains(&name.to_uppercase().as_str()) {
595 add(&mut entities, &mut seen, name, "organization");
596 }
597 }
598
599 entities
600}
601
602pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
606 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
607 let mut result = Vec::new();
608 for m in regex_url().find_iter(body) {
609 let raw = m.as_str();
610 let cleaned = raw
611 .trim_end_matches('`')
612 .trim_end_matches(',')
613 .trim_end_matches('.')
614 .trim_end_matches(';')
615 .trim_end_matches(')')
616 .trim_end_matches(']')
617 .trim_end_matches('}');
618 if cleaned.len() >= 10 && seen.insert(cleaned.to_string()) {
619 result.push(ExtractedUrl {
620 url: cleaned.to_string(),
621 offset: m.start(),
622 });
623 }
624 }
625 result
626}
627
/// Decodes parallel token/IOB-label streams into entities.
///
/// Handles WordPiece subwords ("##xx" is glued onto the previous part),
/// maps BIO types to our entity-type strings, and drops Portuguese verb
/// false positives, section markers and filtered ALL-CAPS names.
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    // Accumulator for the entity currently being built (parts + its type).
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<String> = None;

    // Finalizes the in-progress entity (if any), applying the name filters.
    // Takes everything by parameter so the closure borrows nothing.
    let flush =
        |parts: &mut Vec<String>, typ: &mut Option<String>, entities: &mut Vec<ExtractedEntity>| {
            if let Some(t) = typ.take() {
                let name = parts.join(" ").trim().to_string();
                // Single ALL-CAPS words go back through the stopword filter;
                // multi-word names are assumed intentional.
                let is_single_caps = !name.contains(' ')
                    && name == name.to_uppercase()
                    && name.len() >= MIN_ENTITY_CHARS;
                let should_skip = is_single_caps && is_filtered_all_caps(&name);
                let is_section_marker = regex_section_marker().is_match(&name);
                if name.len() >= MIN_ENTITY_CHARS && !should_skip && !is_section_marker {
                    entities.push(ExtractedEntity {
                        name,
                        entity_type: t,
                    });
                }
                parts.clear();
            }
        };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        // "O" (outside) ends any entity in progress.
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        // Split "B-PER"/"I-ORG" into prefix + type; anything else (malformed
        // label) is treated like "O".
        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        // Capitalized Portuguese verbs/short words the model tags as PER
        // (escapes spell "Lê", "Vê", "Cá", "Pôr").
        const PT_VERB_FALSE_POSITIVES: &[&str] = &[
            "L\u{00ea}",
            "V\u{00ea}",
            "C\u{00e1}",
            "P\u{00f4}r",
            "Ser",
            "Vir",
            "Ver",
            "Dar",
            "Ler",
            "Ter",
        ];

        let entity_type = match bio_type {
            "DATE" => "date",
            "PER" => {
                if PT_VERB_FALSE_POSITIVES.contains(&token.as_str()) {
                    flush(&mut current_parts, &mut current_type, &mut entities);
                    continue;
                }
                "person"
            }
            "ORG" => {
                // Heuristic: tokens that look like software artifacts are
                // reclassified from organization to tool.
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    "tool"
                } else {
                    "organization"
                }
            }
            "LOC" => "location",
            // Unknown BIO types (e.g. MISC) pass through unchanged.
            other => other,
        };

        if prefix == "B" {
            if token.starts_with("##") {
                // A "B" on a subword continues the previous word instead of
                // starting a new entity; an orphan subword (no previous part)
                // is discarded.
                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
                continue;
            }
            // Genuine entity start: close the previous one, begin a new one.
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type.to_string());
        } else if prefix == "I" && current_type.is_some() {
            // Continuation: subwords glue onto the last part, whole tokens
            // become a new space-separated part. "I" without an open entity
            // is ignored.
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    // Close the final entity at end of stream.
    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}
751
/// Test-only pairing strategy: links each entity to up to TOP_K_RELATIONS of
/// the entities after it (positional, not semantic), bounded by the global
/// relationship cap. Returns the pairs and whether the cap was hit.
#[cfg(test)]
fn build_relationships(entities: &[NewEntity]) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let n = entities.len().min(MAX_ENTS);
    let mut rels: Vec<NewRelationship> = Vec::new();
    // Directed (source, target) pairs already emitted.
    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();

    let mut hit_cap = false;
    // The cap is checked both per outer entity and before each push so the
    // flag is set even when the cap is reached exactly at a boundary.
    'outer: for i in 0..n {
        if rels.len() >= max_rels {
            hit_cap = true;
            break;
        }

        let mut for_entity = 0usize;
        for j in (i + 1)..n {
            if for_entity >= TOP_K_RELATIONS {
                break;
            }
            if rels.len() >= max_rels {
                hit_cap = true;
                break 'outer;
            }

            let src = &entities[i].name;
            let tgt = &entities[j].name;
            let key = (src.clone(), tgt.clone());

            // Duplicate names can occur; skip repeated pairs without
            // counting them against the per-entity budget.
            if seen.contains(&key) {
                continue;
            }
            seen.insert(key);

            rels.push(NewRelationship {
                source: src.clone(),
                target: tgt.clone(),
                relation: DEFAULT_RELATION.to_string(),
                strength: 0.5,
                description: None,
            });
            for_entity += 1;
        }
    }

    if hit_cap {
        tracing::warn!(
            "relationships truncated to {max_rels} (with {n} entities, theoretical max was ~{}x combinations)",
            n.saturating_sub(1)
        );
    }

    (rels, hit_cap)
}
821
/// Links entities that co-occur in the same sentence (split on `.!?` and
/// newlines), using case-insensitive substring presence. Pairs are
/// deduplicated by lowercased names and bounded by the global relationship
/// cap; returns the pairs and whether the cap was hit.
fn build_relationships_by_sentence_cooccurrence(
    body: &str,
    entities: &[NewEntity],
) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    // Pre-lowercase the (capped) entity names once; pairs are later built
    // from the original indices.
    let lower_names: Vec<(usize, String)> = entities
        .iter()
        .take(MAX_ENTS)
        .enumerate()
        .map(|(i, e)| (i, e.name.to_lowercase()))
        .collect();

    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();
    let mut hit_cap = false;

    for sentence in body.split(['.', '!', '?', '\n']) {
        if sentence.trim().is_empty() {
            continue;
        }
        let lower_sentence = sentence.to_lowercase();
        // Indices of entities mentioned in this sentence. NOTE(review):
        // substring containment can over-match (e.g. "API" inside "APIs").
        let present: Vec<usize> = lower_names
            .iter()
            .filter(|(_, name)| !name.is_empty() && lower_sentence.contains(name.as_str()))
            .map(|(i, _)| *i)
            .collect();

        if present.len() < 2 {
            continue;
        }

        // All unordered pairs within the sentence, in positional order.
        for i in 0..present.len() {
            for j in (i + 1)..present.len() {
                if rels.len() >= max_rels {
                    hit_cap = true;
                    tracing::warn!(
                        "relationships truncated to {max_rels} during sentence-level pairing"
                    );
                    // Early return: once the cap is hit nothing more can be
                    // added, so remaining sentences are skipped too.
                    return (rels, hit_cap);
                }
                let a = &entities[present[i]];
                let b = &entities[present[j]];
                let key = (a.name.to_lowercase(), b.name.to_lowercase());
                if seen.insert(key) {
                    rels.push(NewRelationship {
                        source: a.name.clone(),
                        target: b.name.clone(),
                        relation: DEFAULT_RELATION.to_string(),
                        strength: 0.5,
                        description: None,
                    });
                }
            }
        }
    }

    (rels, hit_cap)
}
894
/// Runs BERT NER over `body` with overlapping sliding windows.
///
/// The body is tokenized once (no special tokens), capped at the configured
/// token limit, split into MAX_SEQ_LEN windows advancing by STRIDE, and
/// processed in batches. If a batch fails, each window in it is retried
/// individually; a window that still fails is logged and skipped. Entities
/// are deduplicated by exact name across windows.
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("loading NER tokenizer: {e}"))?;

    // `false`: no special tokens ([CLS]/[SEP]) — windows are raw slices.
    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("encoding NER input: {e}"))?;

    let mut all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let mut all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    // Cap total tokens; anything past the cap is only covered by the regex
    // prefilter, which always sees the full body.
    let max_tokens = crate::constants::extraction_max_tokens();
    if all_ids.len() > max_tokens {
        tracing::warn!(
            target: "extraction",
            original_tokens = all_ids.len(),
            capped_tokens = max_tokens,
            "NER input truncated to cap; later body region will be skipped by NER (regex prefilter still covers full body)"
        );
        all_ids.truncate(max_tokens);
        all_tokens.truncate(max_tokens);
    }

    // Build overlapping windows: [start, start+MAX_SEQ_LEN), step STRIDE.
    let mut windows: Vec<(Vec<u32>, Vec<String>)> = Vec::new();
    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        windows.push((
            all_ids[start..end].to_vec(),
            all_tokens[start..end].to_vec(),
        ));
        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    // Group similar lengths together so batches waste less padding. Window
    // order doesn't matter: results are deduplicated into a set of names.
    windows.sort_by_key(|(ids, _)| ids.len());

    let batch_size = crate::constants::ner_batch_size();
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    for chunk in windows.chunks(batch_size) {
        match model.predict_batch(chunk) {
            Ok(batch_labels) => {
                for (labels, (_, tokens)) in batch_labels.iter().zip(chunk.iter()) {
                    for ent in iob_to_entities(tokens, labels) {
                        if seen.insert(ent.name.clone()) {
                            entities.push(ent);
                        }
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "batch NER failed (chunk of {} windows): {e:#} — falling back to single-window",
                    chunk.len()
                );
                // Per-window fallback: all tokens are real, so the attention
                // mask is all ones.
                for (ids, tokens) in chunk {
                    let mask = vec![1u32; ids.len()];
                    match model.predict(ids, &mask) {
                        Ok(labels) => {
                            for ent in iob_to_entities(tokens, &labels) {
                                if seen.insert(ent.name.clone()) {
                                    entities.push(ent);
                                }
                            }
                        }
                        Err(e2) => {
                            // Best-effort: a lost window degrades recall only.
                            tracing::warn!("NER window fallback also failed: {e2:#}");
                        }
                    }
                }
            }
        }
    }

    Ok(entities)
}
999
1000fn extend_with_numeric_suffix(entities: Vec<ExtractedEntity>, body: &str) -> Vec<ExtractedEntity> {
1007 static SUFFIX_RE: OnceLock<Regex> = OnceLock::new();
1008 let suffix_re = SUFFIX_RE.get_or_init(|| {
1011 Regex::new(r"^([\-\s]+\d+(?:\.\d+)?[a-z]?)")
1012 .expect("compile-time validated numeric suffix regex literal")
1013 });
1014
1015 entities
1016 .into_iter()
1017 .map(|ent| {
1018 if let Some(pos) = body.find(&ent.name) {
1020 let after_pos = pos + ent.name.len();
1021 if after_pos < body.len() {
1022 let after = &body[after_pos..];
1023 if let Some(m) = suffix_re.find(after) {
1024 let suffix = m.as_str();
1025 if suffix.len() <= 7 {
1028 let extended = format!("{}{}", ent.name, suffix);
1029 return ExtractedEntity {
1030 name: extended,
1031 entity_type: ent.entity_type,
1032 };
1033 }
1034 }
1035 }
1036 }
1037 ent
1038 })
1039 .collect()
1040}
1041
1042fn augment_versioned_model_names(
1062 entities: Vec<ExtractedEntity>,
1063 body: &str,
1064) -> Vec<ExtractedEntity> {
1065 static VERSIONED_MODEL_RE: OnceLock<Regex> = OnceLock::new();
1066 let model_re = VERSIONED_MODEL_RE.get_or_init(|| {
1073 Regex::new(
1074 r"\b([A-Z][A-Za-z]{2,15})[\s\-]+(\d+(?:\.\d+)?(?:[a-z]|x\d+[A-Za-z]?)?)(?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))?\b",
1075 )
1076 .expect("compile-time validated versioned model regex literal")
1077 });
1078
1079 let mut existing_lc: std::collections::HashSet<String> =
1080 entities.iter().map(|ent| ent.name.to_lowercase()).collect();
1081 let mut result = entities;
1082
1083 for caps in model_re.captures_iter(body) {
1084 let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
1085 if full_match.is_empty() || full_match.len() > 24 {
1088 continue;
1089 }
1090 let normalized_lc = full_match.to_lowercase();
1091 if existing_lc.contains(&normalized_lc) {
1092 continue;
1093 }
1094 if result.len() >= MAX_ENTS {
1097 break;
1098 }
1099 existing_lc.insert(normalized_lc);
1100 result.push(ExtractedEntity {
1101 name: full_match.to_string(),
1102 entity_type: "concept".to_string(),
1103 });
1104 }
1105
1106 result
1107}
1108
/// Merges regex and NER entities, collapsing duplicates of the same type
/// whose NFKC-lowercased names are equal or substring-related; the longer
/// name wins. Output is capped at MAX_ENTS (regex entities get priority
/// because they are consumed first).
fn merge_and_deduplicate(
    regex_ents: Vec<ExtractedEntity>,
    ner_ents: Vec<ExtractedEntity>,
) -> Vec<ExtractedEntity> {
    // Maps "type\0name_lc" -> index in `result`; one key per kept entity.
    let mut by_lc: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    let mut result: Vec<ExtractedEntity> = Vec::new();
    let mut truncated = false;

    let total_input = regex_ents.len() + ner_ents.len();
    for ent in regex_ents.into_iter().chain(ner_ents) {
        // NFKC-normalize + lowercase so visually-equal names compare equal.
        let name_lc = ent.name.nfkc().collect::<String>().to_lowercase();
        let key = format!("{}\0{}", ent.entity_type, name_lc);

        // Linear scan for a same-type entity whose name equals or contains
        // (or is contained by) the candidate. O(n^2) overall, but n <= MAX_ENTS.
        let mut collision_idx: Option<usize> = None;
        for (existing_key, idx) in &by_lc {
            let type_prefix = format!("{}\0", ent.entity_type);
            if !existing_key.starts_with(&type_prefix) {
                continue;
            }
            let existing_name_lc = &existing_key[type_prefix.len()..];
            if existing_name_lc == name_lc
                || existing_name_lc.contains(name_lc.as_str())
                || name_lc.contains(existing_name_lc)
            {
                collision_idx = Some(*idx);
                break;
            }
        }
        match collision_idx {
            Some(idx) => {
                // Keep the longer (more specific) name; swap the map key so
                // future collisions compare against the kept name.
                if ent.name.len() > result[idx].name.len() {
                    let old_name_lc = result[idx].name.nfkc().collect::<String>().to_lowercase();
                    let old_key = format!("{}\0{}", result[idx].entity_type, old_name_lc);
                    by_lc.remove(&old_key);
                    result[idx] = ent;
                    by_lc.insert(key, idx);
                }
            }
            None => {
                by_lc.insert(key, result.len());
                result.push(ent);
            }
        }
        if result.len() >= MAX_ENTS {
            truncated = true;
            break;
        }
    }

    if truncated {
        tracing::warn!(
            "extraction truncated at {MAX_ENTS} entities (input had {total_input} candidates before deduplication)"
        );
    }

    result
}
1191
1192fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
1193 extracted
1194 .into_iter()
1195 .map(|e| NewEntity {
1196 name: e.name,
1197 entity_type: e.entity_type,
1198 description: None,
1199 })
1200 .collect()
1201}
1202
1203pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
1204 let regex_entities = apply_regex_prefilter(body);
1205
1206 let mut bert_used = false;
1207 let ner_entities = match get_or_init_model(paths) {
1208 Some(model) => match run_ner_sliding_window(model, body, paths) {
1209 Ok(ents) => {
1210 bert_used = true;
1211 ents
1212 }
1213 Err(e) => {
1214 tracing::warn!("NER failed, falling back to regex-only extraction: {e:#}");
1215 Vec::new()
1216 }
1217 },
1218 None => Vec::new(),
1219 };
1220
1221 let merged = merge_and_deduplicate(regex_entities, ner_entities);
1222 let extended = extend_with_numeric_suffix(merged, body);
1224 let with_models = augment_versioned_model_names(extended, body);
1228 let with_models: Vec<ExtractedEntity> = with_models
1232 .into_iter()
1233 .filter(|e| !regex_section_marker().is_match(&e.name))
1234 .collect();
1235 let entities = to_new_entities(with_models);
1236 let (relationships, relationships_truncated) =
1237 build_relationships_by_sentence_cooccurrence(body, &entities);
1238
1239 let extraction_method = if bert_used {
1240 "bert+regex-batch".to_string()
1241 } else {
1242 "regex-only".to_string()
1243 };
1244
1245 let urls = extract_urls(body);
1246
1247 Ok(ExtractionResult {
1248 entities,
1249 relationships,
1250 relationships_truncated,
1251 extraction_method,
1252 urls,
1253 })
1254}
1255
/// Extractor that uses only the regex prefilter (no model download/inference).
pub struct RegexExtractor;
1257
1258impl Extractor for RegexExtractor {
1259 fn extract(&self, body: &str) -> Result<ExtractionResult> {
1260 let regex_entities = apply_regex_prefilter(body);
1261 let entities = to_new_entities(regex_entities);
1262 let (relationships, relationships_truncated) =
1263 build_relationships_by_sentence_cooccurrence(body, &entities);
1264 let urls = extract_urls(body);
1265 Ok(ExtractionResult {
1266 entities,
1267 relationships,
1268 relationships_truncated,
1269 extraction_method: "regex-only".to_string(),
1270 urls,
1271 })
1272 }
1273}
1274
1275#[cfg(test)]
1276mod tests {
1277 use super::*;
1278
1279 fn make_paths() -> AppPaths {
1280 use std::path::PathBuf;
1281 AppPaths {
1282 db: PathBuf::from("/tmp/test.sqlite"),
1283 models: PathBuf::from("/tmp/test_models"),
1284 }
1285 }
1286
1287 #[test]
1288 fn regex_email_captures_address() {
1289 let ents = apply_regex_prefilter("contact: someone@company.com for more info");
1290 assert!(ents
1292 .iter()
1293 .any(|e| e.name == "someone@company.com" && e.entity_type == "concept"));
1294 }
1295
1296 #[test]
1297 fn regex_all_caps_filters_pt_rule_word() {
1298 let ents = apply_regex_prefilter("NUNCA do this. PROIBIDO use X. DEVE follow Y.");
1300 assert!(
1301 !ents.iter().any(|e| e.name == "NUNCA"),
1302 "NUNCA must be filtered as a stopword"
1303 );
1304 assert!(
1305 !ents.iter().any(|e| e.name == "PROIBIDO"),
1306 "PROIBIDO must be filtered"
1307 );
1308 assert!(
1309 !ents.iter().any(|e| e.name == "DEVE"),
1310 "DEVE must be filtered"
1311 );
1312 }
1313
1314 #[test]
1315 fn regex_all_caps_accepts_underscored_constant() {
1316 let ents = apply_regex_prefilter("configure MAX_RETRY=3 and API_TIMEOUT=30");
1318 assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
1319 assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
1320 }
1321
1322 #[test]
1323 fn regex_all_caps_accepts_domain_acronym() {
1324 let ents = apply_regex_prefilter("OPENAI launched GPT-5 with NVIDIA H100");
1326 assert!(ents.iter().any(|e| e.name == "OPENAI"));
1327 assert!(ents.iter().any(|e| e.name == "NVIDIA"));
1328 }
1329
1330 #[test]
1331 fn regex_url_does_not_appear_in_apply_regex_prefilter() {
1332 let ents = apply_regex_prefilter("see https://docs.rs/crate for details");
1334 assert!(
1335 !ents.iter().any(|e| e.name.starts_with("https://")),
1336 "URLs must not appear as entities after the P0-2 split"
1337 );
1338 }
1339
1340 #[test]
1341 fn extract_urls_captures_https() {
1342 let urls = extract_urls("see https://docs.rs/crate for details");
1343 assert_eq!(urls.len(), 1);
1344 assert_eq!(urls[0].url, "https://docs.rs/crate");
1345 assert!(urls[0].offset > 0);
1346 }
1347
1348 #[test]
1349 fn extract_urls_trim_sufixo_pontuacao() {
1350 let urls = extract_urls("link: https://example.com/path. fim");
1351 assert!(!urls.is_empty());
1352 assert!(
1353 !urls[0].url.ends_with('.'),
1354 "sufixo ponto deve ser removido"
1355 );
1356 }
1357
1358 #[test]
1359 fn extract_urls_dedupes_repeated() {
1360 let body = "https://example.com referenciado aqui e depois aqui https://example.com";
1361 let urls = extract_urls(body);
1362 assert_eq!(urls.len(), 1, "URLs repetidas devem ser deduplicadas");
1363 }
1364
1365 #[test]
1366 fn regex_uuid_captura_identificador() {
1367 let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
1368 assert!(ents.iter().any(|e| e.entity_type == "concept"));
1369 }
1370
1371 #[test]
1372 fn regex_all_caps_captura_constante() {
1373 let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
1374 assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
1375 assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
1376 }
1377
1378 #[test]
1379 fn regex_all_caps_ignores_short_words() {
1380 let ents = apply_regex_prefilter("use AI em seu projeto");
1381 assert!(
1382 !ents.iter().any(|e| e.name == "AI"),
1383 "AI tem apenas 2 chars, deve ser ignorado"
1384 );
1385 }
1386
1387 #[test]
1388 fn iob_decodes_per_to_person() {
1389 let tokens = vec![
1390 "John".to_string(),
1391 "Doe".to_string(),
1392 "trabalhou".to_string(),
1393 ];
1394 let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
1395 let ents = iob_to_entities(&tokens, &labels);
1396 assert_eq!(ents.len(), 1);
1397 assert_eq!(ents[0].entity_type, "person");
1398 assert!(ents[0].name.contains("John"));
1399 }
1400
1401 #[test]
1402 fn iob_strip_subword_b_prefix() {
1403 let tokens = vec!["Open".to_string(), "##AI".to_string()];
1406 let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
1407 let ents = iob_to_entities(&tokens, &labels);
1408 assert!(
1409 ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
1410 "should merge ##AI or discard"
1411 );
1412 }
1413
1414 #[test]
1415 fn iob_subword_orphan_discards() {
1416 let tokens = vec!["##AI".to_string()];
1418 let labels = vec!["B-ORG".to_string()];
1419 let ents = iob_to_entities(&tokens, &labels);
1420 assert!(
1421 ents.is_empty(),
1422 "orphan subword without an active entity must be discarded"
1423 );
1424 }
1425
1426 #[test]
1427 fn iob_maps_date_to_date_v1025() {
1428 let tokens = vec!["January".to_string(), "2024".to_string()];
1430 let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
1431 let ents = iob_to_entities(&tokens, &labels);
1432 assert_eq!(
1433 ents.len(),
1434 1,
1435 "DATE must be emitted as an entity in v1.0.25"
1436 );
1437 assert_eq!(ents[0].entity_type, "date");
1438 }
1439
1440 #[test]
1441 fn iob_maps_org_to_organization_v1025() {
1442 let tokens = vec!["Empresa".to_string()];
1444 let labels = vec!["B-ORG".to_string()];
1445 let ents = iob_to_entities(&tokens, &labels);
1446 assert_eq!(ents[0].entity_type, "organization");
1447 }
1448
1449 #[test]
1450 fn iob_maps_org_sdk_to_tool() {
1451 let tokens = vec!["tokio-sdk".to_string()];
1452 let labels = vec!["B-ORG".to_string()];
1453 let ents = iob_to_entities(&tokens, &labels);
1454 assert_eq!(ents[0].entity_type, "tool");
1455 }
1456
1457 #[test]
1458 fn iob_maps_loc_to_location_v1025() {
1459 let tokens = vec!["Brasil".to_string()];
1461 let labels = vec!["B-LOC".to_string()];
1462 let ents = iob_to_entities(&tokens, &labels);
1463 assert_eq!(ents[0].entity_type, "location");
1464 }
1465
1466 #[test]
1467 fn build_relationships_respeitam_max_rels() {
1468 let entities: Vec<NewEntity> = (0..20)
1469 .map(|i| NewEntity {
1470 name: format!("entidade_{i}"),
1471 entity_type: "concept".to_string(),
1472 description: None,
1473 })
1474 .collect();
1475 let (rels, truncated) = build_relationships(&entities);
1476 let max_rels = crate::constants::max_relationships_per_memory();
1477 assert!(rels.len() <= max_rels, "deve respeitar max_rels={max_rels}");
1478 if rels.len() == max_rels {
1479 assert!(truncated, "truncated deve ser true quando atingiu o cap");
1480 }
1481 }
1482
1483 #[test]
1484 fn build_relationships_without_duplicates() {
1485 let entities: Vec<NewEntity> = (0..5)
1486 .map(|i| NewEntity {
1487 name: format!("ent_{i}"),
1488 entity_type: "concept".to_string(),
1489 description: None,
1490 })
1491 .collect();
1492 let (rels, _truncated) = build_relationships(&entities);
1493 let mut pares: std::collections::HashSet<(String, String)> =
1494 std::collections::HashSet::new();
1495 for r in &rels {
1496 let par = (r.source.clone(), r.target.clone());
1497 assert!(pares.insert(par), "par duplicado encontrado");
1498 }
1499 }
1500
1501 #[test]
1502 fn merge_dedupes_by_lowercase_name() {
1503 let a = vec![ExtractedEntity {
1506 name: "Rust".to_string(),
1507 entity_type: "concept".to_string(),
1508 }];
1509 let b = vec![ExtractedEntity {
1510 name: "rust".to_string(),
1511 entity_type: "concept".to_string(),
1512 }];
1513 let merged = merge_and_deduplicate(a, b);
1514 assert_eq!(
1515 merged.len(),
1516 1,
1517 "rust and Rust with the same type are the same entity"
1518 );
1519 }
1520
1521 #[test]
1522 fn regex_extractor_implements_trait() {
1523 let extractor = RegexExtractor;
1524 let result = extractor
1525 .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
1526 .unwrap();
1527 assert!(!result.entities.is_empty());
1528 }
1529
1530 #[test]
1531 fn extract_returns_ok_without_model() {
1532 let paths = make_paths();
1534 let body = "contato: teste@exemplo.com com MAX_RETRY=3";
1535 let result = extract_graph_auto(body, &paths).unwrap();
1536 assert!(result
1537 .entities
1538 .iter()
1539 .any(|e| e.name.contains("teste@exemplo.com")));
1540 }
1541
1542 #[test]
1543 fn stopwords_filter_v1024_terms() {
1544 let body = "ACEITE ACK ACL BORDA CHECKLIST COMPLETED CONFIRME \
1547 DEVEMOS DONE FIXED NEGUE PENDING PLAN PODEMOS RECUSE TOKEN VAMOS";
1548 let ents = apply_regex_prefilter(body);
1549 let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
1550 for word in &[
1551 "ACEITE",
1552 "ACK",
1553 "ACL",
1554 "BORDA",
1555 "CHECKLIST",
1556 "COMPLETED",
1557 "CONFIRME",
1558 "DEVEMOS",
1559 "DONE",
1560 "FIXED",
1561 "NEGUE",
1562 "PENDING",
1563 "PLAN",
1564 "PODEMOS",
1565 "RECUSE",
1566 "TOKEN",
1567 "VAMOS",
1568 ] {
1569 assert!(
1570 !names.contains(word),
1571 "v1.0.24 stopword {word} should be filtered but was found in entities"
1572 );
1573 }
1574 }
1575
1576 #[test]
1577 fn dedup_normalizes_unicode_combining_marks() {
1578 let nfc = vec![ExtractedEntity {
1582 name: "Caf\u{e9}".to_string(),
1583 entity_type: "concept".to_string(),
1584 }];
1585 let nfd_name = "Cafe\u{301}".to_string();
1587 let nfd = vec![ExtractedEntity {
1588 name: nfd_name,
1589 entity_type: "concept".to_string(),
1590 }];
1591 let merged = merge_and_deduplicate(nfc, nfd);
1592 assert_eq!(
1593 merged.len(),
1594 1,
1595 "NFC 'Caf\\u{{e9}}' and NFD 'Cafe\\u{{301}}' must deduplicate to 1 entity after NFKC normalization"
1596 );
1597 }
1598
1599 #[test]
1602 fn predict_batch_output_count_matches_input() {
1603 let w1_ids: Vec<u32> = vec![101, 100, 102];
1609 let w1_tok: Vec<String> = vec!["[CLS]".into(), "hello".into(), "[SEP]".into()];
1610 let w2_ids: Vec<u32> = vec![101, 100, 200, 300, 102];
1611 let w2_tok: Vec<String> = vec![
1612 "[CLS]".into(),
1613 "world".into(),
1614 "foo".into(),
1615 "bar".into(),
1616 "[SEP]".into(),
1617 ];
1618 let windows: Vec<(Vec<u32>, Vec<String>)> =
1619 vec![(w1_ids.clone(), w1_tok), (w2_ids.clone(), w2_tok)];
1620
1621 let device = Device::Cpu;
1624 let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap();
1625 assert_eq!(max_len, 5, "max_len deve ser 5");
1626
1627 let mut padded_ids: Vec<Tensor> = Vec::new();
1628 for (ids, _) in &windows {
1629 let len = ids.len();
1630 let pad_right = max_len - len;
1631 let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
1632 let t = Tensor::from_vec(ids_i64, len, &device).unwrap();
1633 let t = t.pad_with_zeros(0, 0, pad_right).unwrap();
1634 assert_eq!(
1635 t.dims(),
1636 &[max_len],
1637 "each window must have shape (max_len,) after padding"
1638 );
1639 padded_ids.push(t);
1640 }
1641
1642 let stacked = Tensor::stack(&padded_ids, 0).unwrap();
1643 assert_eq!(
1644 stacked.dims(),
1645 &[2, max_len],
1646 "stack deve produzir (batch_size=2, max_len=5)"
1647 );
1648
1649 let fake_logits_data: Vec<f32> = vec![0.0f32; 2 * max_len * 9]; let fake_logits =
1653 Tensor::from_vec(fake_logits_data, (2usize, max_len, 9usize), &device).unwrap();
1654 for (i, (ids, _)) in windows.iter().enumerate() {
1655 let real_len = ids.len();
1656 let example = fake_logits.get(i).unwrap();
1657 let sliced = example.narrow(0, 0, real_len).unwrap();
1658 assert_eq!(
1659 sliced.dims(),
1660 &[real_len, 9],
1661 "narrow deve preservar apenas {real_len} tokens reais"
1662 );
1663 }
1664 }
1665
1666 #[test]
1667 fn predict_batch_empty_windows_returns_empty() {
1668 let windows: Vec<(Vec<u32>, Vec<String>)> = vec![];
1671 let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
1672 assert_eq!(max_len, 0, "zero windows → max_len 0");
1673 let result: Vec<Vec<String>> = if max_len == 0 {
1676 Vec::new()
1677 } else {
1678 unreachable!()
1679 };
1680 assert!(result.is_empty());
1681 }
1682
    #[test]
    // With no env override present, the NER batch size falls back to 8.
    // NOTE(review): this test mutates process-wide environment state; when the
    // harness runs tests in parallel it can race with the other env-reading
    // tests in this module (e.g. ner_batch_size_env_override_clamped) — TODO
    // confirm these run serially or guard them with a shared lock.
    fn ner_batch_size_default_is_8() {
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
        assert_eq!(crate::constants::ner_batch_size(), 8);
    }
1690
    #[test]
    // GRAPHRAG_NER_BATCH_SIZE overrides must be clamped to the [1, 32] range;
    // in-range values are honoured as-is.
    // NOTE(review): mutates process-wide env vars — can race with other tests
    // reading the same variable when the harness runs in parallel; TODO confirm
    // serialization (shared lock or --test-threads=1).
    fn ner_batch_size_env_override_clamped() {
        // Above the cap → clamped down to 32.
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "64");
        assert_eq!(crate::constants::ner_batch_size(), 32, "deve clampar em 32");

        // Zero → clamped up to the minimum of 1.
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "0");
        assert_eq!(crate::constants::ner_batch_size(), 1, "deve clampar em 1");

        // In-range value passes through unchanged.
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "4");
        assert_eq!(
            crate::constants::ner_batch_size(),
            4,
            "valid value preserved"
        );

        // Clean up so later tests see the default again.
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
    }
1709
1710 #[test]
1711 fn extraction_method_regex_only_unchanged() {
1712 let result = RegexExtractor.extract("contact: dev@acme.io").unwrap();
1715 assert_eq!(
1716 result.extraction_method, "regex-only",
1717 "RegexExtractor must return regex-only"
1718 );
1719 }
1720
1721 #[test]
1724 fn extend_suffix_pure_numeric_unchanged() {
1725 let ents = vec![ExtractedEntity {
1727 name: "GPT".to_string(),
1728 entity_type: "concept".to_string(),
1729 }];
1730 let result = extend_with_numeric_suffix(ents, "using GPT-5 in the project");
1731 assert_eq!(
1732 result[0].name, "GPT-5",
1733 "purely numeric suffix must be extended"
1734 );
1735 }
1736
1737 #[test]
1738 fn extend_suffix_alphanumeric_letter_after_digit() {
1739 let ents = vec![ExtractedEntity {
1741 name: "GPT".to_string(),
1742 entity_type: "concept".to_string(),
1743 }];
1744 let result = extend_with_numeric_suffix(ents, "using GPT-4o for advanced tasks");
1745 assert_eq!(result[0].name, "GPT-4o", "suffix '4o' must be accepted");
1746 }
1747
1748 #[test]
1749 fn extend_suffix_alphanumeric_b_suffix() {
1750 let ents = vec![ExtractedEntity {
1752 name: "Llama".to_string(),
1753 entity_type: "concept".to_string(),
1754 }];
1755 let result = extend_with_numeric_suffix(ents, "Llama-5b open-weight model");
1756 assert_eq!(result[0].name, "Llama-5b", "suffix '5b' must be accepted");
1757 }
1758
1759 #[test]
1760 fn extend_suffix_alphanumeric_x_suffix() {
1761 let ents = vec![ExtractedEntity {
1763 name: "Mistral".to_string(),
1764 entity_type: "concept".to_string(),
1765 }];
1766 let result = extend_with_numeric_suffix(ents, "testing Mistral-8x in production");
1767 assert_eq!(result[0].name, "Mistral-8x", "suffix '8x' must be accepted");
1768 }
1769
1770 #[test]
1773 fn augment_versioned_gpt4o() {
1774 let result = augment_versioned_model_names(vec![], "using GPT-4o for analysis");
1776 assert!(
1777 result.iter().any(|e| e.name == "GPT-4o"),
1778 "GPT-4o must be captured by augment, found: {:?}",
1779 result.iter().map(|e| &e.name).collect::<Vec<_>>()
1780 );
1781 }
1782
1783 #[test]
1784 fn augment_versioned_claude_4_sonnet() {
1785 let result =
1787 augment_versioned_model_names(vec![], "best model: Claude 4 Sonnet released today");
1788 assert!(
1789 result.iter().any(|e| e.name == "Claude 4 Sonnet"),
1790 "Claude 4 Sonnet must be captured, found: {:?}",
1791 result.iter().map(|e| &e.name).collect::<Vec<_>>()
1792 );
1793 }
1794
1795 #[test]
1796 fn augment_versioned_llama_3_pro() {
1797 let result =
1799 augment_versioned_model_names(vec![], "fine-tuning com Llama 3 Pro localmente");
1800 assert!(
1801 result.iter().any(|e| e.name == "Llama 3 Pro"),
1802 "Llama 3 Pro deve ser capturado, achados: {:?}",
1803 result.iter().map(|e| &e.name).collect::<Vec<_>>()
1804 );
1805 }
1806
1807 #[test]
1808 fn augment_versioned_mixtral_8x7b() {
1809 let result =
1811 augment_versioned_model_names(vec![], "executando Mixtral 8x7B no servidor local");
1812 assert!(
1813 result.iter().any(|e| e.name == "Mixtral 8x7B"),
1814 "Mixtral 8x7B deve ser capturado, achados: {:?}",
1815 result.iter().map(|e| &e.name).collect::<Vec<_>>()
1816 );
1817 }
1818
1819 #[test]
1820 fn augment_versioned_does_not_duplicate_existing() {
1821 let existing = vec![ExtractedEntity {
1823 name: "Claude 4".to_string(),
1824 entity_type: "concept".to_string(),
1825 }];
1826 let result = augment_versioned_model_names(existing, "using Claude 4 in the project");
1827 let count = result.iter().filter(|e| e.name == "Claude 4").count();
1828 assert_eq!(count, 1, "Claude 4 must not be duplicated");
1829 }
1830
1831 #[test]
1834 fn stopwords_filter_url_jwt_api_v1025() {
1835 let body = "We use URL, JWT, and API REST in our LLM-powered CLI via HTTP/HTTPS and UI.";
1837 let ents = apply_regex_prefilter(body);
1838 let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
1839 for blocked in &[
1840 "URL", "JWT", "API", "REST", "LLM", "CLI", "HTTP", "HTTPS", "UI",
1841 ] {
1842 assert!(
1843 !names.contains(blocked),
1844 "v1.0.25 stopword {blocked} leaked as entity; found names: {names:?}"
1845 );
1846 }
1847 }
1848
1849 #[test]
1852 fn section_markers_etapa_fase_filtered_v1025() {
1853 let body = "Etapa 3 do plano: implementar Fase 1 da Migra\u{e7}\u{e3}o.";
1857 let ents = apply_regex_prefilter(body);
1858 assert!(
1859 !ents
1860 .iter()
1861 .any(|e| e.name.contains("Etapa") || e.name.contains("Fase")),
1862 "section markers must be stripped; entities: {:?}",
1863 ents.iter().map(|e| &e.name).collect::<Vec<_>>()
1864 );
1865 }
1866
1867 #[test]
1868 fn section_markers_passo_secao_filtered_v1025() {
1869 let body = "Siga Passo 2 conforme Se\u{e7}\u{e3}o 3 do manual.";
1872 let ents = apply_regex_prefilter(body);
1873 assert!(
1874 !ents
1875 .iter()
1876 .any(|e| e.name.contains("Passo") || e.name.contains("Se\u{e7}\u{e3}o")),
1877 "Passo/Se\\u{{e7}}\\u{{e3}}o section markers must be stripped; entities: {:?}",
1878 ents.iter().map(|e| &e.name).collect::<Vec<_>>()
1879 );
1880 }
1881
1882 #[test]
1885 fn brand_camelcase_extracted_as_organization_v1025() {
1886 let body = "OpenAI launched GPT-4 and PostgreSQL added pgvector.";
1888 let ents = apply_regex_prefilter(body);
1889 let openai = ents.iter().find(|e| e.name == "OpenAI");
1890 assert!(
1891 openai.is_some(),
1892 "OpenAI must be extracted by CamelCase brand regex; entities: {:?}",
1893 ents.iter().map(|e| &e.name).collect::<Vec<_>>()
1894 );
1895 assert_eq!(
1896 openai.unwrap().entity_type,
1897 "organization",
1898 "brand CamelCase must map to organization (V008)"
1899 );
1900 }
1901
1902 #[test]
1903 fn brand_postgresql_extracted_as_organization_v1025() {
1904 let body = "migrating from MySQL to PostgreSQL for better performance.";
1905 let ents = apply_regex_prefilter(body);
1906 assert!(
1907 ents.iter()
1908 .any(|e| e.name == "PostgreSQL" && e.entity_type == "organization"),
1909 "PostgreSQL must be extracted as organization; entities: {:?}",
1910 ents.iter()
1911 .map(|e| (&e.name, &e.entity_type))
1912 .collect::<Vec<_>>()
1913 );
1914 }
1915
1916 #[test]
1919 fn iob_org_maps_to_organization_not_project_v1025() {
1920 let tokens = vec!["Microsoft".to_string()];
1922 let labels = vec!["B-ORG".to_string()];
1923 let ents = iob_to_entities(&tokens, &labels);
1924 assert_eq!(
1925 ents[0].entity_type, "organization",
1926 "B-ORG must map to organization (V008); got {}",
1927 ents[0].entity_type
1928 );
1929 }
1930
1931 #[test]
1932 fn iob_loc_maps_to_location_not_concept_v1025() {
1933 let tokens = vec!["S\u{e3}o".to_string(), "Paulo".to_string()];
1936 let labels = vec!["B-LOC".to_string(), "I-LOC".to_string()];
1937 let ents = iob_to_entities(&tokens, &labels);
1938 assert_eq!(
1939 ents[0].entity_type, "location",
1940 "B-LOC must map to location (V008); got {}",
1941 ents[0].entity_type
1942 );
1943 }
1944
1945 #[test]
1946 fn iob_date_maps_to_date_not_discarded_v1025() {
1947 let tokens = vec!["2025".to_string(), "-".to_string(), "12".to_string()];
1949 let labels = vec![
1950 "B-DATE".to_string(),
1951 "I-DATE".to_string(),
1952 "I-DATE".to_string(),
1953 ];
1954 let ents = iob_to_entities(&tokens, &labels);
1955 assert_eq!(
1956 ents.len(),
1957 1,
1958 "DATE entity must be emitted (V008); entities: {ents:?}"
1959 );
1960 assert_eq!(ents[0].entity_type, "date");
1961 }
1962
1963 #[test]
1966 fn pt_verb_le_filtered_as_per_v1025() {
1967 let tokens = vec!["L\u{ea}".to_string(), "o".to_string(), "livro".to_string()];
1970 let labels = vec!["B-PER".to_string(), "O".to_string(), "O".to_string()];
1971 let ents = iob_to_entities(&tokens, &labels);
1972 assert!(
1973 !ents
1974 .iter()
1975 .any(|e| e.name == "L\u{ea}" && e.entity_type == "person"),
1976 "PT verb 'L\\u{{ea}}' tagged B-PER must be filtered; entities: {ents:?}"
1977 );
1978 }
1979
1980 #[test]
1981 fn pt_verb_ver_filtered_as_per_v1025() {
1982 let tokens = vec!["Ver".to_string()];
1984 let labels = vec!["B-PER".to_string()];
1985 let ents = iob_to_entities(&tokens, &labels);
1986 assert!(
1987 ents.is_empty(),
1988 "PT verb 'Ver' tagged B-PER must be filtered; entities: {ents:?}"
1989 );
1990 }
1991
1992 fn entity(name: &str, entity_type: &str) -> ExtractedEntity {
1995 ExtractedEntity {
1996 name: name.to_string(),
1997 entity_type: entity_type.to_string(),
1998 }
1999 }
2000
2001 #[test]
2002 fn merge_resolves_sonne_vs_sonnet_keeps_longest_v1025() {
2003 let regex = vec![entity("Sonne", "concept")];
2005 let ner = vec![entity("Sonnet", "concept")];
2006 let result = merge_and_deduplicate(regex, ner);
2007 assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
2008 assert_eq!(result[0].name, "Sonnet");
2009 }
2010
2011 #[test]
2012 fn merge_resolves_open_vs_openai_keeps_longest_v1025() {
2013 let regex = vec![
2015 entity("Open", "organization"),
2016 entity("OpenAI", "organization"),
2017 ];
2018 let result = merge_and_deduplicate(regex, vec![]);
2019 assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
2020 assert_eq!(result[0].name, "OpenAI");
2021 }
2022
2023 #[test]
2024 fn merge_keeps_both_when_no_containment_v1025() {
2025 let regex = vec![entity("Alice", "person"), entity("Bob", "person")];
2027 let result = merge_and_deduplicate(regex, vec![]);
2028 assert_eq!(result.len(), 2, "expected 2 entities, got: {result:?}");
2029 }
2030
2031 #[test]
2032 fn merge_respects_entity_type_boundary_v1025() {
2033 let regex = vec![entity("Apple", "organization"), entity("Apple", "concept")];
2035 let result = merge_and_deduplicate(regex, vec![]);
2036 assert_eq!(
2037 result.len(),
2038 2,
2039 "expected 2 entities (different types), got: {result:?}"
2040 );
2041 }
2042
2043 #[test]
2044 fn merge_case_insensitive_dedup_v1025() {
2045 let regex = vec![
2047 entity("OpenAI", "organization"),
2048 entity("openai", "organization"),
2049 ];
2050 let result = merge_and_deduplicate(regex, vec![]);
2051 assert_eq!(
2052 result.len(),
2053 1,
2054 "expected 1 entity after case-insensitive dedup, got: {result:?}"
2055 );
2056 }
2057
2058 #[test]
2061 fn iob_section_marker_etapa_filtered_v1025() {
2062 let tokens = vec!["Etapa".to_string(), "3".to_string()];
2064 let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
2065 let ents = iob_to_entities(&tokens, &labels);
2066 assert!(
2067 !ents.iter().any(|e| e.name.contains("Etapa")),
2068 "section marker 'Etapa 3' from BERT must be filtered; entities: {ents:?}"
2069 );
2070 }
2071
2072 #[test]
2073 fn iob_section_marker_fase_filtered_v1025() {
2074 let tokens = vec!["Fase".to_string(), "1".to_string()];
2076 let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
2077 let ents = iob_to_entities(&tokens, &labels);
2078 assert!(
2079 !ents.iter().any(|e| e.name.contains("Fase")),
2080 "section marker 'Fase 1' from BERT must be filtered; entities: {ents:?}"
2081 );
2082 }
2083
2084 #[test]
2087 fn extract_graph_auto_handles_large_body_under_30s() {
2088 let body = "x ".repeat(40_000);
2091 let paths = make_paths();
2092 let start = std::time::Instant::now();
2093 let result = extract_graph_auto(&body, &paths).expect("extraction must not error");
2094 let elapsed = start.elapsed();
2095 assert!(
2096 elapsed.as_secs() < 30,
2097 "extract_graph_auto took {}s for 80 KB body (cap should keep it well under 30s)",
2098 elapsed.as_secs()
2099 );
2100 let _ = result.entities;
2102 }
2103
2104 #[test]
2107 fn pt_uppercase_stopwords_filtered_v1031() {
2108 let body = "Para o ADAPTER funcionar com PROJETO em modo PASSIVA, devemos usar \
2109 SOMENTE LEITURA conforme a REGRA OBRIGATORIA do EXEMPLO DEFAULT.";
2110 let ents = apply_regex_prefilter(body);
2111 let names: Vec<String> = ents.iter().map(|e| e.name.to_uppercase()).collect();
2112 for stop in &[
2113 "ADAPTER",
2114 "PROJETO",
2115 "PASSIVA",
2116 "SOMENTE",
2117 "LEITURA",
2118 "REGRA",
2119 "OBRIGATORIA",
2120 "EXEMPLO",
2121 "DEFAULT",
2122 ] {
2123 assert!(
2124 !names.contains(&stop.to_string()),
2125 "v1.0.31 A11 stoplist failed: {stop} leaked as entity; got names: {names:?}"
2126 );
2127 }
2128 }
2129
2130 #[test]
2131 fn pt_underscored_identifier_preserved_v1031() {
2132 let ents = apply_regex_prefilter("configure FLOWAIPER_API_KEY=foo and MAX_TIMEOUT=30");
2135 let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
2136 assert!(names.contains(&"FLOWAIPER_API_KEY"));
2137 assert!(names.contains(&"MAX_TIMEOUT"));
2138 }
2139
2140 #[test]
2143 fn build_relationships_by_sentence_only_links_co_occurring_entities() {
2144 let body = "Alice met Bob at the conference. Carol works alone in another room.";
2145 let entities = vec![
2146 NewEntity {
2147 name: "Alice".to_string(),
2148 entity_type: "person".to_string(),
2149 description: None,
2150 },
2151 NewEntity {
2152 name: "Bob".to_string(),
2153 entity_type: "person".to_string(),
2154 description: None,
2155 },
2156 NewEntity {
2157 name: "Carol".to_string(),
2158 entity_type: "person".to_string(),
2159 description: None,
2160 },
2161 ];
2162 let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
2163 assert!(!truncated);
2164 assert_eq!(
2165 rels.len(),
2166 1,
2167 "only Alice/Bob should pair (same sentence); Carol is isolated"
2168 );
2169 let pair = (rels[0].source.as_str(), rels[0].target.as_str());
2170 assert!(
2171 matches!(pair, ("Alice", "Bob") | ("Bob", "Alice")),
2172 "unexpected pair {pair:?}"
2173 );
2174 }
2175
2176 #[test]
2177 fn build_relationships_by_sentence_returns_empty_for_single_entity() {
2178 let body = "Alice is here.";
2179 let entities = vec![NewEntity {
2180 name: "Alice".to_string(),
2181 entity_type: "person".to_string(),
2182 description: None,
2183 }];
2184 let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
2185 assert!(rels.is_empty());
2186 assert!(!truncated);
2187 }
2188
2189 #[test]
2190 fn build_relationships_by_sentence_dedupes_pairs_across_sentences() {
2191 let body = "Alice met Bob. Bob saw Alice again.";
2192 let entities = vec![
2193 NewEntity {
2194 name: "Alice".to_string(),
2195 entity_type: "person".to_string(),
2196 description: None,
2197 },
2198 NewEntity {
2199 name: "Bob".to_string(),
2200 entity_type: "person".to_string(),
2201 description: None,
2202 },
2203 ];
2204 let (rels, _) = build_relationships_by_sentence_cooccurrence(body, &entities);
2205 assert_eq!(
2206 rels.len(),
2207 1,
2208 "Alice/Bob pair must be emitted only once even when co-occurring in multiple sentences"
2209 );
2210 }
2211
    #[test]
    // With no env override present, the extraction token budget defaults to 5000.
    // NOTE(review): mutates process-wide env state — can race with
    // extraction_max_tokens_env_override_clamped under parallel test threads;
    // TODO confirm these are serialized.
    fn extraction_max_tokens_default_is_5000() {
        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
        assert_eq!(crate::constants::extraction_max_tokens(), 5_000);
    }
2217
    #[test]
    // Out-of-range overrides (below 512 or above 100_000) fall back to the 5000
    // default; in-range values are honoured.
    // NOTE(review): mutates process-wide env state — can race with
    // extraction_max_tokens_default_is_5000 under parallel test threads;
    // TODO confirm these are serialized (shared lock or --test-threads=1).
    fn extraction_max_tokens_env_override_clamped() {
        // Too small → default.
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value below 512 must fall back to default"
        );

        // Too large → default.
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value above 100_000 must fall back to default"
        );

        // In-range value passes through unchanged.
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "8000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            8_000,
            "valid value must be honoured"
        );

        // Clean up so later tests see the default again.
        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
    }
2243}