Skip to main content

sqlite_graphrag/
embedder.rs

1use crate::constants::{
2    EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
3};
4use crate::errors::AppError;
5use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
6use ort::execution_providers::CPU;
7use std::path::Path;
8use std::sync::{Mutex, OnceLock};
9
10static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
11
12/// Returns the process-wide singleton embedder, initializing it on first call.
13/// Subsequent calls return the cached instance regardless of `models_dir`.
14pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
15    if let Some(m) = EMBEDDER.get() {
16        return Ok(m);
17    }
18
19    // Desabilita arena allocator da EP CPU para reduzir retenção agressiva de memória
20    // entre inferências repetidas com shapes variáveis. O fastembed já desliga
21    // memory pattern em alguns cenários, mas não desliga a CPU arena por padrão.
22    let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
23
24    let model = TextEmbedding::try_new(
25        TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
26            .with_execution_providers(vec![cpu_ep])
27            .with_max_length(EMBEDDING_MAX_TOKENS)
28            .with_show_download_progress(true)
29            .with_cache_dir(models_dir.to_path_buf()),
30    )
31    .map_err(|e| AppError::Embedding(e.to_string()))?;
32    // If another thread raced and won, discard our instance and return theirs.
33    let _ = EMBEDDER.set(Mutex::new(model));
34    Ok(EMBEDDER.get().expect("just set above"))
35}
36
37pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
38    let prefixed = format!("{PASSAGE_PREFIX}{text}");
39    let results = embedder
40        .lock()
41        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
42        .embed(vec![prefixed.as_str()], Some(1))
43        .map_err(|e| AppError::Embedding(e.to_string()))?;
44    let emb = results
45        .into_iter()
46        .next()
47        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
48    assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
49    Ok(emb)
50}
51
52pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
53    let prefixed = format!("{QUERY_PREFIX}{text}");
54    let results = embedder
55        .lock()
56        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
57        .embed(vec![prefixed.as_str()], Some(1))
58        .map_err(|e| AppError::Embedding(e.to_string()))?;
59    let emb = results
60        .into_iter()
61        .next()
62        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
63    Ok(emb)
64}
65
66pub fn embed_passages_batch(
67    embedder: &Mutex<TextEmbedding>,
68    texts: &[String],
69) -> Result<Vec<Vec<f32>>, AppError> {
70    let prefixed: Vec<String> = texts
71        .iter()
72        .map(|t| format!("{PASSAGE_PREFIX}{t}"))
73        .collect();
74    let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
75    let results = embedder
76        .lock()
77        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
78        .embed(strs, Some(FASTEMBED_BATCH_SIZE))
79        .map_err(|e| AppError::Embedding(e.to_string()))?;
80    for emb in &results {
81        assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
82    }
83    Ok(results)
84}
85
86/// Embed multiple passages serially.
87///
88/// This path intentionally avoids ONNX batch inference for robustness when
89/// real-world Markdown chunks trigger pathological runtime behavior.
90pub fn embed_passages_serial<'a, I>(
91    embedder: &Mutex<TextEmbedding>,
92    texts: I,
93) -> Result<Vec<Vec<f32>>, AppError>
94where
95    I: IntoIterator<Item = &'a str>,
96{
97    let iter = texts.into_iter();
98    let (lower, _) = iter.size_hint();
99    let mut results = Vec::with_capacity(lower);
100    for text in iter {
101        results.push(embed_passage(embedder, text)?);
102    }
103    Ok(results)
104}
105
/// Views a slice of `f32` as its raw bytes, for storing vectors in sqlite-vec.
///
/// This is a zero-copy reinterpretation. The returned slice borrows `v`, is
/// `4 * v.len()` bytes long, and holds each float in the host's native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    let first_byte = v.as_ptr().cast::<u8>();
    // SAFETY: `first_byte` points to `byte_len` initialized bytes owned by `v`, and the
    // returned slice has the same lifetime as `v`. `f32` has no padding, so every byte is
    // initialized, and `u8` has alignment 1, so any address is suitably aligned.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
112
// Unit tests for the embedder module. Most do not need the model; those that do are
// `#[ignore]`d.
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes tests (pure function, no model needed) ---

    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        // Round-trip: the 4 bytes must reconstruct the original f32.
        // NOTE(review): decodes with `from_le_bytes`, so this assumes a little-endian host.
        // `f32_to_bytes` emits the host's native byte order.
        let recovered = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        // +0.0 is all-zero bits, so this holds regardless of byte order.
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        // Reconstruct and compare the first and last elements. This uses the same
        // little-endian-host assumption as the single-element round-trip test.
        let first = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_le_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- checks on the prefixes used by the embedder (no model needed) ---

    // Despite the "not empty" name, this pins the exact prefix value.
    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    // Despite the "not empty" name, this pins the exact prefix value.
    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- tests with the real model (ignored in regular CI) ---
    // `get_embedder` returns a process-wide singleton. Within one test process, only the
    // first call's temp dir is used as the model cache; later calls reuse the cached instance.

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "texto de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "consulta de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let textos = vec!["primeiro".to_string(), "segundo".to_string()];
        let results = embed_passages_batch(embedder, &textos).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }
}