// sqlite_graphrag/embedder.rs
1use crate::constants::{
2    EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
3    REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
4};
5use crate::errors::AppError;
6use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
7use ort::execution_providers::CPU;
8use std::path::Path;
9use std::sync::{Mutex, OnceLock};
10
// Process-wide singleton: the ONNX model is large and expensive to load, so
// it is initialized at most once and shared behind a `Mutex`.
static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();

/// Returns the process-wide singleton embedder, initializing it on first call.
/// Subsequent calls return the cached instance regardless of `models_dir`.
pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
    if let Some(m) = EMBEDDER.get() {
        return Ok(m);
    }

    // On aarch64-linux-gnu this loads a dynamic libonnxruntime.so first;
    // elsewhere it is a no-op (see the cfg-gated implementations below).
    maybe_init_dynamic_ort(models_dir)?;

    // Multi-layer mitigation for the explosive RSS growth observed with
    // variable-shape payloads. The three current layers are:
    //   1. `with_arena_allocator(false)` on the CPU execution provider (line below)
    //   2. env var `ORT_DISABLE_CPU_MEM_ARENA=1` in `main.rs` (default since v1.0.18)
    //   3. env vars `ORT_NUM_THREADS=1` + `ORT_INTRA_OP_NUM_THREADS=1` in `main.rs`
    // The `with_memory_pattern(false)` flag exists in ort 2.0 (`SessionBuilder`)
    // but fastembed 5.13.2 does NOT expose access to a custom SessionBuilder via
    // `TextInitOptions`. Should RSS grow again on real corpora, the next
    // mitigation requires one of the following paths:
    //   - Fork fastembed to expose `SessionBuilder::with_memory_pattern(false)`
    //   - Bypass fastembed and use ort directly with a custom SessionBuilder
    //   - Fixed padding in `plan_controlled_batches` to eliminate variable shapes
    // References:
    //   https://onnxruntime.ai/docs/performance/tune-performance/memory.html
    //   https://github.com/qdrant/fastembed/issues/570
    let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();

    let model = TextEmbedding::try_new(
        TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
            .with_execution_providers(vec![cpu_ep])
            .with_max_length(EMBEDDING_MAX_TOKENS)
            .with_show_download_progress(true)
            .with_cache_dir(models_dir.to_path_buf()),
    )
    .map_err(|e| AppError::Embedding(e.to_string()))?;
    // If another thread raced and won, discard our instance and return theirs.
    let _ = EMBEDDER.set(Mutex::new(model));
    Ok(EMBEDDER.get().expect("just set above"))
}
51
/// On aarch64-linux-gnu, try to initialize ort from a dynamically loaded
/// `libonnxruntime.so` found in (priority order): `ORT_DYLIB_PATH`, next to
/// the executable, `<exe dir>/lib/`, or `models_dir`.
///
/// Returns `Ok(())` without doing anything when no candidate file exists.
///
/// # Errors
/// Returns `AppError::Embedding` if the runtime is found but loading or
/// committing the ort environment fails.
#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
    let mut candidates = Vec::new();

    if let Ok(path) = std::env::var("ORT_DYLIB_PATH") {
        if !path.is_empty() {
            candidates.push(std::path::PathBuf::from(path));
        }
    }

    if let Ok(exe) = std::env::current_exe() {
        if let Some(dir) = exe.parent() {
            candidates.push(dir.join("libonnxruntime.so"));
            candidates.push(dir.join("lib").join("libonnxruntime.so"));
        }
    }

    candidates.push(models_dir.join("libonnxruntime.so"));

    for path in candidates {
        if !path.exists() {
            continue;
        }

        // NOTE(review): `set_var` is not thread-safe while other threads read
        // the environment; assumed to run before worker threads start — confirm.
        std::env::set_var("ORT_DYLIB_PATH", &path);
        // Propagate failures from BOTH stages. The previous code discarded the
        // `commit()` result (`let _ = …`), so a runtime that loaded but failed
        // to register globally was silently reported as success.
        ort::init_from(&path)
            .map_err(|e| AppError::Embedding(e.to_string()))?
            .commit()
            .map_err(|e| AppError::Embedding(e.to_string()))?;
        return Ok(());
    }

    // No candidate dylib found: leave ort to its default linking strategy.
    Ok(())
}
85
/// Fallback for all other targets: ONNX Runtime is linked the default way,
/// so there is nothing to initialize dynamically. Always succeeds.
#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
    Ok(())
}
90
91pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
92    let prefixed = format!("{PASSAGE_PREFIX}{text}");
93    let results = embedder
94        .lock()
95        .map_err(|e| AppError::Embedding(format!("mutex do embedder corrompido: {e}")))?
96        .embed(vec![prefixed.as_str()], Some(1))
97        .map_err(|e| AppError::Embedding(e.to_string()))?;
98    let emb = results
99        .into_iter()
100        .next()
101        .ok_or_else(|| AppError::Embedding("resultado de embedding vazio".into()))?;
102    assert_eq!(emb.len(), EMBEDDING_DIM, "dimensão de embedding inesperada");
103    Ok(emb)
104}
105
106pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
107    let prefixed = format!("{QUERY_PREFIX}{text}");
108    let results = embedder
109        .lock()
110        .map_err(|e| AppError::Embedding(format!("mutex do embedder corrompido: {e}")))?
111        .embed(vec![prefixed.as_str()], Some(1))
112        .map_err(|e| AppError::Embedding(e.to_string()))?;
113    let emb = results
114        .into_iter()
115        .next()
116        .ok_or_else(|| AppError::Embedding("resultado de embedding vazio".into()))?;
117    Ok(emb)
118}
119
120pub fn embed_passages_batch(
121    embedder: &Mutex<TextEmbedding>,
122    texts: &[&str],
123    batch_size: usize,
124) -> Result<Vec<Vec<f32>>, AppError> {
125    let prefixed: Vec<String> = texts
126        .iter()
127        .map(|t| format!("{PASSAGE_PREFIX}{t}"))
128        .collect();
129    let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
130    let results = embedder
131        .lock()
132        .map_err(|e| AppError::Embedding(format!("mutex do embedder corrompido: {e}")))?
133        .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
134        .map_err(|e| AppError::Embedding(e.to_string()))?;
135    for emb in &results {
136        assert_eq!(emb.len(), EMBEDDING_DIM, "dimensão de embedding inesperada");
137    }
138    Ok(results)
139}
140
141pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
142    plan_controlled_batches(token_counts).len()
143}
144
145pub fn embed_passages_controlled(
146    embedder: &Mutex<TextEmbedding>,
147    texts: &[&str],
148    token_counts: &[usize],
149) -> Result<Vec<Vec<f32>>, AppError> {
150    if texts.len() != token_counts.len() {
151        return Err(AppError::Internal(anyhow::anyhow!(
152            "comprimento de texts/token_counts diverge no embedding controlado"
153        )));
154    }
155
156    let mut results = Vec::with_capacity(texts.len());
157    for (start, end) in plan_controlled_batches(token_counts) {
158        if end - start == 1 {
159            results.push(embed_passage(embedder, texts[start])?);
160            continue;
161        }
162
163        results.extend(embed_passages_batch(
164            embedder,
165            &texts[start..end],
166            end - start,
167        )?);
168    }
169
170    Ok(results)
171}
172
173/// Embed multiple passages serially.
174///
175/// This path intentionally avoids ONNX batch inference for robustness when
176/// real-world Markdown chunks trigger pathological runtime behavior.
177pub fn embed_passages_serial<'a, I>(
178    embedder: &Mutex<TextEmbedding>,
179    texts: I,
180) -> Result<Vec<Vec<f32>>, AppError>
181where
182    I: IntoIterator<Item = &'a str>,
183{
184    let iter = texts.into_iter();
185    let (lower, _) = iter.size_hint();
186    let mut results = Vec::with_capacity(lower);
187    for text in iter {
188        results.push(embed_passage(embedder, text)?);
189    }
190    Ok(results)
191}
192
193fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
194    let mut batches = Vec::new();
195    let mut start = 0usize;
196
197    while start < token_counts.len() {
198        let mut end = start + 1;
199        let mut max_tokens = token_counts[start].max(1);
200
201        while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
202            let candidate_max = max_tokens.max(token_counts[end].max(1));
203            let candidate_len = end + 1 - start;
204            if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
205                break;
206            }
207            max_tokens = candidate_max;
208            end += 1;
209        }
210
211        batches.push((start, end));
212        start = end;
213    }
214
215    batches
216}
217
/// Reinterpret `&[f32]` as its raw (native-endian) byte representation for
/// sqlite-vec BLOB storage. Zero-copy: the returned slice borrows `v`.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    // SAFETY: `v` is a valid slice, so its pointer is non-null and aligned;
    // `f32` has no padding bytes and every byte of it is a valid `u8`; the
    // length covers exactly the same memory region the slice owns.
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<u8>(), byte_len) }
}
224
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes tests (pure function, no model required) ---

    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        // Roundtrip: the 4 bytes must reconstruct the original f32.
        let recovered = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        // Reconstruct and compare the first and last elements.
        let first = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_le_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- verify the prefixes the embedder relies on (no model required) ---

    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- tests requiring the real model (ignored in normal CI runs) ---

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "texto de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "consulta de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let textos = ["primeiro", "segundo"];
        let results = embed_passages_batch(embedder, &textos, 2).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }

    #[test]
    fn controlled_batch_plan_respeita_orcamento() {
        assert_eq!(
            plan_controlled_batches(&[100, 100, 100, 100, 300, 300]),
            vec![(0, 4), (4, 5), (5, 6)]
        );
    }

    #[test]
    fn controlled_batch_count_retorna_um_para_chunk_unico() {
        assert_eq!(controlled_batch_count(&[350]), 1);
    }
}
335}