Skip to main content

sqlite_graphrag/
embedder.rs

//! fastembed wrapper and process-wide embedding model singleton.
//!
//! Owns the in-process `TextEmbedding` model and exposes single, batched, and
//! token-budget-controlled passage encoding plus query encoding helpers used by
//! remember, recall, and related commands.
5
6use crate::constants::{
7    EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
8    REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
9};
10use crate::errors::AppError;
11use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
12use ort::ep::CPU;
13use std::path::Path;
14use std::sync::{Mutex, OnceLock};
15
16static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
17
18/// Returns the process-wide singleton embedder, initializing it on first call.
19/// Subsequent calls return the cached instance regardless of `models_dir`.
20pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
21    if let Some(m) = EMBEDDER.get() {
22        return Ok(m);
23    }
24
25    maybe_init_dynamic_ort(models_dir)?;
26
27    // Multi-layer mitigation of the explosive RSS observed with variable-shape
28    // payloads. The three current layers are:
29    //   1. `with_arena_allocator(false)` on the CPU execution provider (line below)
30    //   2. env var `ORT_DISABLE_CPU_MEM_ARENA=1` in `main.rs` (default since v1.0.18)
31    //   3. env var `ORT_NUM_THREADS=1` + `ORT_INTRA_OP_NUM_THREADS=1` in `main.rs`
32    // The `with_memory_pattern(false)` flag exists in ort 2.0 (`SessionBuilder`)
33    // but fastembed 5.13.2 does NOT expose access to a custom SessionBuilder via
34    // `TextInitOptions`. If RSS grows again in real corpora, the next
35    // mitigation requires one of the following paths:
36    //   - Fork fastembed to expose `SessionBuilder::with_memory_pattern(false)`
37    //   - Bypass fastembed and use ort directly with a custom SessionBuilder
38    //   - Fixed padding in `plan_controlled_batches` to eliminate variable shapes
39    // References:
40    //   https://onnxruntime.ai/docs/performance/tune-performance/memory.html
41    //   https://github.com/qdrant/fastembed/issues/570
42    let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
43
44    let model = TextEmbedding::try_new(
45        TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
46            .with_execution_providers(vec![cpu_ep])
47            .with_max_length(EMBEDDING_MAX_TOKENS)
48            .with_show_download_progress(true)
49            .with_cache_dir(models_dir.to_path_buf()),
50    )
51    .map_err(|e| AppError::Embedding(e.to_string()))?;
52    // If another thread raced and won, discard our instance and return theirs.
53    let _ = EMBEDDER.set(Mutex::new(model));
54    EMBEDDER.get().ok_or_else(|| {
55        AppError::Embedding(
56            "embedder OnceLock unexpectedly empty after set() (likely a racing initializer aborted before completion)"
57                .into(),
58        )
59    })
60}
61
/// Locates a `libonnxruntime.so` on aarch64 Linux (glibc) and loads it through
/// `ort::init_from`.
///
/// Candidates are probed in this order, and the first path that exists wins:
/// 1. the `ORT_DYLIB_PATH` environment variable, when set and non-empty;
/// 2. `libonnxruntime.so` next to the running executable;
/// 3. `lib/libonnxruntime.so` under the executable's directory;
/// 4. `libonnxruntime.so` inside `models_dir`.
///
/// The chosen path is written back to `ORT_DYLIB_PATH`. When no candidate
/// exists, nothing happens and `Ok(())` is returned.
///
/// # Errors
/// Returns `AppError::Embedding` when `ort::init_from` fails for the first
/// existing candidate. Later candidates are not tried.
#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
    const LIB_NAME: &str = "libonnxruntime.so";

    let env_override = std::env::var("ORT_DYLIB_PATH")
        .ok()
        .filter(|value| !value.is_empty())
        .map(std::path::PathBuf::from);

    let exe_dir = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(Path::to_path_buf));

    let mut candidates: Vec<std::path::PathBuf> = Vec::with_capacity(4);
    candidates.extend(env_override);
    if let Some(dir) = exe_dir {
        candidates.push(dir.join(LIB_NAME));
        candidates.push(dir.join("lib").join(LIB_NAME));
    }
    candidates.push(models_dir.join(LIB_NAME));

    let Some(found) = candidates.into_iter().find(|candidate| candidate.exists()) else {
        return Ok(());
    };

    // NOTE(review): `set_var` mutates process-global state. Concurrent environment
    // reads from other threads can race with it (it is `unsafe` in Rust 2024).
    std::env::set_var("ORT_DYLIB_PATH", &found);
    let builder = ort::init_from(&found).map_err(|e| AppError::Embedding(e.to_string()))?;
    let _ = builder.commit();
    Ok(())
}
95
96#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
97fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
98    Ok(())
99}
100
101/// Embeds a single passage using the `passage:` prefix required by E5 models.
102///
103/// # Errors
104/// Returns `Err` when the mutex is poisoned or the model returns an unexpected result.
105pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
106    let prefixed = format!("{PASSAGE_PREFIX}{text}");
107    let results = embedder
108        .lock()
109        .map_err(|e| AppError::Embedding(format!("embedder mutex poisoned: {e}")))?
110        .embed(vec![prefixed.as_str()], Some(1))
111        .map_err(|e| AppError::Embedding(e.to_string()))?;
112    let emb = results
113        .into_iter()
114        .next()
115        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
116    assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
117    Ok(emb)
118}
119
120/// Embeds a search query using the `query:` prefix required by E5 models.
121///
122/// # Errors
123/// Returns `Err` when the mutex is poisoned or the model returns an unexpected result.
124pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
125    let prefixed = format!("{QUERY_PREFIX}{text}");
126    let results = embedder
127        .lock()
128        .map_err(|e| AppError::Embedding(format!("embedder mutex poisoned: {e}")))?
129        .embed(vec![prefixed.as_str()], Some(1))
130        .map_err(|e| AppError::Embedding(e.to_string()))?;
131    let emb = results
132        .into_iter()
133        .next()
134        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
135    Ok(emb)
136}
137
138/// Embeds multiple passages in a single ONNX batch call.
139///
140/// `batch_size` is capped at `FASTEMBED_BATCH_SIZE`. All texts receive the `passage:` prefix.
141///
142/// # Errors
143/// Returns `Err` when the mutex is poisoned or the model inference fails.
144pub fn embed_passages_batch(
145    embedder: &Mutex<TextEmbedding>,
146    texts: &[&str],
147    batch_size: usize,
148) -> Result<Vec<Vec<f32>>, AppError> {
149    let prefixed: Vec<String> = texts
150        .iter()
151        .map(|t| format!("{PASSAGE_PREFIX}{t}"))
152        .collect();
153    let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
154    let results = embedder
155        .lock()
156        .map_err(|e| AppError::Embedding(format!("embedder mutex poisoned: {e}")))?
157        .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
158        .map_err(|e| AppError::Embedding(e.to_string()))?;
159    for emb in &results {
160        assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
161    }
162    Ok(results)
163}
164
165/// Returns the number of batches that [`embed_passages_controlled`] would produce
166/// for the given `token_counts` slice without running inference.
167pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
168    plan_controlled_batches(token_counts).len()
169}
170
171/// Embeds passages grouped into token-budget-aware batches to avoid OOM on variable-length inputs.
172///
173/// `texts` and `token_counts` must have the same length. Batches are planned using an
174/// internal budget algorithm and single-item batches fall back to [`embed_passage`].
175///
176/// # Errors
177/// Returns `Err` when lengths differ, the mutex is poisoned, or inference fails.
178pub fn embed_passages_controlled(
179    embedder: &Mutex<TextEmbedding>,
180    texts: &[&str],
181    token_counts: &[usize],
182) -> Result<Vec<Vec<f32>>, AppError> {
183    if texts.len() != token_counts.len() {
184        return Err(AppError::Internal(anyhow::anyhow!(
185            "texts/token_counts length mismatch in controlled embedding"
186        )));
187    }
188
189    let mut results = Vec::with_capacity(texts.len());
190    for (start, end) in plan_controlled_batches(token_counts) {
191        if end - start == 1 {
192            results.push(embed_passage(embedder, texts[start])?);
193            continue;
194        }
195
196        results.extend(embed_passages_batch(
197            embedder,
198            &texts[start..end],
199            end - start,
200        )?);
201    }
202
203    Ok(results)
204}
205
206/// Embed multiple passages one-by-one (serial ONNX inference).
207///
208/// Serialization is **intentional**: ONNX batch inference can trigger pathological
209/// runtime behaviour on real-world Markdown chunks (variable token lengths cause
210/// extreme padding overhead). Callers that need parallelism should use the rayon
211/// `ThreadPool` in `src/commands/ingest.rs::run`, which partitions work across
212/// CPU threads and calls this function per shard.
213pub fn embed_passages_serial<'a, I>(
214    embedder: &Mutex<TextEmbedding>,
215    texts: I,
216) -> Result<Vec<Vec<f32>>, AppError>
217where
218    I: IntoIterator<Item = &'a str>,
219{
220    let iter = texts.into_iter();
221    let (lower, _) = iter.size_hint();
222    let mut results = Vec::with_capacity(lower);
223    for text in iter {
224        results.push(embed_passage(embedder, text)?);
225    }
226    Ok(results)
227}
228
229fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
230    let mut batches = Vec::new();
231    let mut start = 0usize;
232
233    while start < token_counts.len() {
234        let mut end = start + 1;
235        let mut max_tokens = token_counts[start].max(1);
236
237        while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
238            let candidate_max = max_tokens.max(token_counts[end].max(1));
239            let candidate_len = end + 1 - start;
240            if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
241                break;
242            }
243            max_tokens = candidate_max;
244            end += 1;
245        }
246
247        batches.push((start, end));
248        start = end;
249    }
250
251    batches
252}
253
/// Reinterprets a `&[f32]` as its raw bytes, for storing vectors in sqlite-vec.
///
/// The returned slice borrows `v` and is exactly `4 * v.len()` bytes long.
/// Bytes are in native order, which is little-endian on the supported targets
/// (x86_64, aarch64). Big-endian `f32` storage is not supported by sqlite-vec.
///
/// # Soundness
/// - `f32` has no padding bytes, so every byte of the slice is initialized
///   (<https://doc.rust-lang.org/reference/types/numeric.html>).
/// - `u8` has alignment 1, so any `f32` pointer is suitably aligned for it.
/// - The output lifetime is tied to the input slice, so it cannot dangle.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = v.len() * std::mem::size_of::<f32>();
    let base = v.as_ptr().cast::<u8>();
    // SAFETY: `base` points to `byte_len` initialized bytes owned by `v` (no
    // padding in f32), `u8` needs no alignment, and the result borrows `v`.
    unsafe { std::slice::from_raw_parts(base, byte_len) }
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::{
        EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX, REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS,
        REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
    };

    // --- f32_to_bytes tests (pure function, no model) ---

    #[test]
    fn f32_to_bytes_empty_slice_returns_empty() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_one_element_returns_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        // Roundtrip: the 4 bytes must reconstruct the original f32.
        let recovered = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_length_is_4x_elements() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_encoded_as_4_zeros() {
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vector_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        // Reconstruct and compare the first and last elements.
        let first = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_le_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- prefixes and dimension the embedder relies on (no model) ---

    #[test]
    fn passage_prefix_is_passage_colon_space() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_is_query_colon_space() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_is_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- tests with the real model (ignored in regular CI) ---

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_passage_returns_vector_with_correct_dimension() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "test text").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_query_returns_vector_with_correct_dimension() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "test query").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_passages_batch_returns_one_vector_per_text() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let texts = ["primeiro", "segundo"];
        let results = embed_passages_batch(embedder, &texts, 2).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }

    // --- controlled batch planning (pure, no model) ---

    #[test]
    fn controlled_batch_plan_respects_budget() {
        assert_eq!(
            plan_controlled_batches(&[100, 100, 100, 100, 300, 300]),
            vec![(0, 4), (4, 5), (5, 6)]
        );
    }

    #[test]
    fn controlled_batch_count_returns_one_for_single_chunk() {
        assert_eq!(controlled_batch_count(&[350]), 1);
    }

    #[test]
    fn controlled_batch_plan_empty_input_yields_no_batches() {
        assert!(plan_controlled_batches(&[]).is_empty());
        assert_eq!(controlled_batch_count(&[]), 0);
    }

    #[test]
    fn controlled_batch_plan_covers_all_indices_within_limits() {
        let counts = [10, 500, 3, 0, 250, 250, 1, 80, 700, 40];
        let plan = plan_controlled_batches(&counts);
        let mut next = 0;
        for &(start, end) in &plan {
            assert_eq!(start, next, "batches must be contiguous");
            assert!(end > start, "batches must be non-empty");
            let len = end - start;
            assert!(len <= REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS.max(1));
            if len > 1 {
                let widest = counts[start..end].iter().map(|&c| c.max(1)).max().unwrap();
                assert!(widest * len <= REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS);
            }
            next = end;
        }
        assert_eq!(next, counts.len());
    }
}