sqlite_graphrag/
extraction.rs

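//! Entity and relationship extraction for the knowledge graph.
//!
//! Two complementary passes run over each note body: a cheap regex prefilter
//! (emails, URLs, UUIDs, ALL_CAPS identifiers) and a multilingual BERT NER
//! model executed with candle. The NER model is optional: if it cannot be
//! downloaded or loaded, extraction degrades gracefully to regex-only results.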
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;

use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use regex::Regex;
use serde::Deserialize;

use crate::paths::AppPaths;
use crate::storage::entities::{NewEntity, NewRelationship};

const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
const MAX_SEQ_LEN: usize = 512;
const STRIDE: usize = 256;
const MAX_ENTS: usize = 30;
const MAX_RELS: usize = 50;
const TOP_K_RELATIONS: usize = 5;
const DEFAULT_RELATION: &str = "mentions";
const MIN_ENTITY_CHARS: usize = 2;

static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();

fn regex_email() -> &'static Regex {
    REGEX_EMAIL
        .get_or_init(|| Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}").unwrap())
}

fn regex_url() -> &'static Regex {
    REGEX_URL.get_or_init(|| Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#).unwrap())
}

fn regex_uuid() -> &'static Regex {
    REGEX_UUID.get_or_init(|| {
        Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
            .unwrap()
    })
}

fn regex_all_caps() -> &'static Regex {
    REGEX_ALL_CAPS.get_or_init(|| Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b").unwrap())
}

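/// A candidate entity surfaced by the regex prefilter or the NER decoder,
/// before being converted into a storable `NewEntity`.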
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    pub name: String,
    pub entity_type: String,
}

#[derive(Debug, Clone)]
pub struct ExtractionResult {
    pub entities: Vec<NewEntity>,
    pub relationships: Vec<NewRelationship>,
}

pub trait Extractor: Send + Sync {
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}

#[derive(Deserialize)]
struct ModelConfig {
    #[serde(default)]
    id2label: HashMap<String, String>,
    hidden_size: usize,
}

struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    id2label: HashMap<usize, String>,
}

impl BertNerModel {
    fn load(model_dir: &Path) -> Result<Self> {
        let config_path = model_dir.join("config.json");
        let weights_path = model_dir.join("model.safetensors");

        let config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("reading config.json at {config_path:?}"))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&config_str).context("parsing NER model config.json")?;

        let id2label: HashMap<usize, String> = model_cfg
            .id2label
            .into_iter()
            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
            .collect();

        let num_labels = id2label.len().max(9);
        let hidden_size = model_cfg.hidden_size;

        // The same config.json also carries the fields BertConfig needs,
        // so reuse the string already in memory instead of re-reading the file.
        let bert_cfg: BertConfig =
            serde_json::from_str(&config_str).context("parsing BertConfig")?;

        let device = Device::Cpu;

        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
                .with_context(|| format!("mmapping {weights_path:?}"))?
        };
        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("loading BertModel")?;

        // Load the token-classification head from the checkpoint
        // (`classifier.weight` / `classifier.bias`). Zero-initializing it, as
        // the code previously did, would make every logit 0 and every
        // prediction collapse onto label index 0.
        let classifier = candle_nn::linear(hidden_size, num_labels, vb.pp("classifier"))
            .context("loading classifier head")?;

        Ok(Self {
            bert,
            classifier,
            device,
            id2label,
        })
    }

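    /// Runs the classifier over a single token window and returns one IOB
    /// label string per input token (argmax over that token's logits;
    /// unknown label ids fall back to "O").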
    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
        let len = token_ids.len();
        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
            .context("creating input_ids tensor")?;
        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
            .context("creating token_type_ids tensor")?;
        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
            .context("creating attention_mask tensor")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel forward pass")?;

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier forward pass")?;

        let logits_2d = logits.squeeze(0).context("removing batch dimension")?;

        let num_tokens = logits_2d.dim(0).context("dim(0)")?;

        let mut labels = Vec::with_capacity(num_tokens);
        for i in 0..num_tokens {
            let token_logits = logits_2d.get(i).context("getting token logits")?;
            let vec: Vec<f32> = token_logits.to_vec1().context("logits to_vec1")?;
            let argmax = vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            let label = self
                .id2label
                .get(&argmax)
                .cloned()
                .unwrap_or_else(|| "O".to_string());
            labels.push(label);
        }

        Ok(labels)
    }
}

static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();

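/// Loads the NER model once per process. Note that failure is cached too:
/// if the first load fails we log a warning and every subsequent call keeps
/// returning `None` (regex-only extraction) rather than retrying the load.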
fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
    NER_MODEL
        .get_or_init(|| match load_model(paths) {
            Ok(m) => Some(m),
            Err(e) => {
                tracing::warn!("NER model unavailable (graceful degradation): {e:#}");
                None
            }
        })
        .as_ref()
}

fn model_dir(paths: &AppPaths) -> PathBuf {
    paths.models.join("bert-multilingual-ner")
}

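/// Ensures the model files exist locally, downloading any missing ones from
/// the Hugging Face Hub on first use (~676 MB, dominated by the weights).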
fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
    let dir = model_dir(paths);
    std::fs::create_dir_all(&dir)
        .with_context(|| format!("creating model directory: {dir:?}"))?;

    let weights = dir.join("model.safetensors");
    let config = dir.join("config.json");
    let tokenizer = dir.join("tokenizer.json");

    if weights.exists() && config.exists() && tokenizer.exists() {
        return Ok(dir);
    }

    tracing::info!("Downloading NER model (first run, ~676 MB)...");
    crate::output::emit_progress_i18n(
        "Downloading NER model (first run, ~676 MB)...",
        "Baixando modelo NER (primeira execução, ~676 MB)...",
    );

    let api = hf_hub::api::sync::Api::new().context("creating HF Hub client")?;
    let repo = api.model(MODEL_ID.to_string());

    for filename in &[
        "model.safetensors",
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ] {
        let dest = dir.join(filename);
        if !dest.exists() {
            let src = repo
                .get(filename)
                .with_context(|| format!("downloading {filename} from the HF Hub"))?;
            std::fs::copy(&src, &dest)
                .with_context(|| format!("copying {filename} into the cache"))?;
        }
    }

    Ok(dir)
}

fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
    let dir = ensure_model_files(paths)?;
    BertNerModel::load(&dir)
}

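/// Cheap first pass: pulls out emails (typed as "person"), plus URLs, UUIDs
/// and ALL_CAPS identifiers (typed as "concept"), deduplicated by exact name.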
fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
    let mut entities = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let add = |entities: &mut Vec<ExtractedEntity>,
               seen: &mut std::collections::HashSet<String>,
               name: &str,
               entity_type: &str| {
        let name = name.trim().to_string();
        if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
            entities.push(ExtractedEntity {
                name,
                entity_type: entity_type.to_string(),
            });
        }
    };

    for m in regex_email().find_iter(body) {
        add(&mut entities, &mut seen, m.as_str(), "person");
    }
    for m in regex_url().find_iter(body) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_uuid().find_iter(body) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_all_caps().find_iter(body) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }

    entities
}

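/// Decodes IOB labels into entities: "B-" starts a span, "I-" extends it,
/// anything else flushes the current span. WordPiece continuations ("##"
/// sub-tokens) are re-attached to the previous fragment. DATE spans are
/// dropped, and PER/ORG/LOC are mapped onto this crate's entity types
/// (person, tool/project, concept).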
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<String> = None;

    let flush =
        |parts: &mut Vec<String>, typ: &mut Option<String>, entities: &mut Vec<ExtractedEntity>| {
            if let Some(t) = typ.take() {
                let name = parts.join(" ").trim().to_string();
                if name.len() >= MIN_ENTITY_CHARS {
                    entities.push(ExtractedEntity {
                        name,
                        entity_type: t,
                    });
                }
                parts.clear();
            }
        };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        let entity_type = match bio_type {
            "DATE" => {
                flush(&mut current_parts, &mut current_type, &mut entities);
                continue;
            }
            "PER" => "person",
            "ORG" => {
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    "tool"
                } else {
                    "project"
                }
            }
            "LOC" => "concept",
            other => other,
        };

        if prefix == "B" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type.to_string());
        } else if prefix == "I" && current_type.is_some() {
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}

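/// Links co-occurring entities pairwise with the weak default "mentions"
/// relation, capped at TOP_K_RELATIONS outgoing edges per entity and
/// MAX_RELS edges overall, over at most the first MAX_ENTS entities.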
fn build_relationships(entities: &[NewEntity]) -> Vec<NewRelationship> {
    if entities.len() < 2 {
        return Vec::new();
    }

    let n = entities.len().min(MAX_ENTS);
    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();

    'outer: for i in 0..n {
        if rels.len() >= MAX_RELS {
            break;
        }

        let mut for_entity = 0usize;
        for j in (i + 1)..n {
            if for_entity >= TOP_K_RELATIONS {
                break;
            }
            if rels.len() >= MAX_RELS {
                break 'outer;
            }

            let src = &entities[i].name;
            let tgt = &entities[j].name;
            let key = (src.clone(), tgt.clone());

            if seen.contains(&key) {
                continue;
            }
            seen.insert(key);

            rels.push(NewRelationship {
                source: src.clone(),
                target: tgt.clone(),
                relation: DEFAULT_RELATION.to_string(),
                strength: 0.5,
                description: None,
            });
            for_entity += 1;
        }
    }

    rels
}

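/// Runs NER over the whole body using overlapping windows of MAX_SEQ_LEN
/// tokens advanced by STRIDE, so an entity that crosses one window boundary
/// usually appears intact in the overlapping window. Results are
/// deduplicated by exact name across windows.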
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("loading NER tokenizer: {e}"))?;

    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("NER encoding: {e}"))?;

    let all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        let window_ids = &all_ids[start..end];
        let window_tokens = &all_tokens[start..end];
        let attention_mask: Vec<u32> = vec![1u32; window_ids.len()];

        match model.predict(window_ids, &attention_mask) {
            Ok(labels) => {
                let window_ents = iob_to_entities(window_tokens, &labels);
                for ent in window_ents {
                    if seen.insert(ent.name.clone()) {
                        entities.push(ent);
                    }
                }
            }
            Err(e) => {
                tracing::warn!("NER window failed (start={start}): {e:#}");
            }
        }

        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    Ok(entities)
}

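/// Merges both entity sources, deduplicating case-insensitively by name.
/// Regex entities come first in the chain, so their typing wins on a clash,
/// and the combined list is truncated at MAX_ENTS.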
fn merge_and_deduplicate(
    regex_ents: Vec<ExtractedEntity>,
    ner_ents: Vec<ExtractedEntity>,
) -> Vec<ExtractedEntity> {
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut result: Vec<ExtractedEntity> = Vec::new();

    for ent in regex_ents.into_iter().chain(ner_ents) {
        let key = ent.name.to_lowercase();
        if seen.insert(key) {
            result.push(ent);
        }
        if result.len() >= MAX_ENTS {
            break;
        }
    }

    result
}

fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
    extracted
        .into_iter()
        .map(|e| NewEntity {
            name: e.name,
            entity_type: e.entity_type,
            description: None,
        })
        .collect()
}

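/// Main entry point: regex prefilter plus (when available) NER, merged,
/// deduplicated, and turned into entities plus co-occurrence relationships.
/// Never fails just because the NER model is missing or broken.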
pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
    let regex_entities = apply_regex_prefilter(body);

    let ner_entities = match get_or_init_model(paths) {
        Some(model) => match run_ner_sliding_window(model, body, paths) {
            Ok(ents) => ents,
            Err(e) => {
                tracing::warn!("NER failed, falling back to regex only: {e:#}");
                Vec::new()
            }
        },
        None => Vec::new(),
    };

    let merged = merge_and_deduplicate(regex_entities, ner_entities);
    let entities = to_new_entities(merged);
    let relationships = build_relationships(&entities);

    Ok(ExtractionResult {
        entities,
        relationships,
    })
}

pub struct RegexExtractor;

impl Extractor for RegexExtractor {
    fn extract(&self, body: &str) -> Result<ExtractionResult> {
        let regex_entities = apply_regex_prefilter(body);
        let entities = to_new_entities(regex_entities);
        let relationships = build_relationships(&entities);
        Ok(ExtractionResult {
            entities,
            relationships,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        AppPaths {
            db: PathBuf::from("/tmp/test.sqlite"),
            models: PathBuf::from("/tmp/test_models"),
        }
    }

    #[test]
    fn regex_email_captura_endereco() {
        let ents = apply_regex_prefilter("contato: fulano@empresa.com.br para mais info");
        assert!(ents
            .iter()
            .any(|e| e.name == "fulano@empresa.com.br" && e.entity_type == "person"));
    }

    #[test]
    fn regex_url_captura_link() {
        let ents = apply_regex_prefilter("veja https://docs.rs/crate para detalhes");
        assert!(ents
            .iter()
            .any(|e| e.name.starts_with("https://") && e.entity_type == "concept"));
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
        assert!(ents.iter().any(|e| e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignora_palavras_curtas() {
        let ents = apply_regex_prefilter("use AI em seu projeto");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI is only 2 chars long and must be ignored"
        );
    }

    #[test]
    fn iob_decodifica_per_para_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, "person");
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_descarta_date() {
        let tokens = vec!["Janeiro".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(ents.is_empty(), "DATE spans must be discarded");
    }

    #[test]
    fn iob_mapeia_org_para_project() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "project");
    }

    #[test]
    fn iob_mapeia_org_sdk_para_tool() {
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "tool");
    }

    #[test]
    fn iob_mapeia_loc_para_concept() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "concept");
    }

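    // Added check (illustrative): WordPiece continuation sub-tokens ("##...")
    // must be glued back onto the previous fragment rather than becoming a
    // separate space-joined part.
    #[test]
    fn iob_junta_subtokens_wordpiece() {
        let tokens = vec!["To".to_string(), "##kio".to_string()];
        let labels = vec!["B-ORG".to_string(), "I-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].name, "Tokio");
    }
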
    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let rels = build_relationships(&entities);
        assert!(rels.len() <= MAX_RELS, "must respect MAX_RELS={MAX_RELS}");
    }

    #[test]
    fn build_relationships_sem_duplicatas() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let rels = build_relationships(&entities);
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "duplicate pair found");
        }
    }

    #[test]
    fn merge_deduplica_por_nome_lowercase() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: "tool".to_string(),
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(merged.len(), 1, "rust and Rust are the same entity");
    }

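    // Added check (illustrative): the merged list must also be truncated at
    // MAX_ENTS even when the sources together yield more candidates.
    #[test]
    fn merge_respeita_max_ents() {
        let many: Vec<ExtractedEntity> = (0..(MAX_ENTS * 2))
            .map(|i| ExtractedEntity {
                name: format!("entidade_{i}"),
                entity_type: "concept".to_string(),
            })
            .collect();
        let merged = merge_and_deduplicate(many, Vec::new());
        assert_eq!(merged.len(), MAX_ENTS);
    }
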
    #[test]
    fn regex_extractor_implementa_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_retorna_ok_sem_modelo() {
        // Without a downloaded model this must still return Ok with the
        // regex-only entities (NER load failures degrade gracefully).
        let paths = make_paths();
        let body = "contato: teste@exemplo.com com MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
    }
}