Skip to main content

sqlite_graphrag/
embedder.rs

1//! fastembed wrapper and per-process embedding cache.
2//!
3//! Owns the in-process `TextEmbedding` model and exposes batch encode/query
4//! helpers used by remember, recall, and related commands.
5// Workload: CPU-bound (ONNX inference, matrix multiplication via fastembed)
6
7use crate::constants::{
8    EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
9    REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
10};
11use crate::errors::AppError;
12use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
13use ort::ep::CPU;
14use parking_lot::Mutex;
15use std::path::Path;
16use std::sync::OnceLock;
17
18/// Process-wide singleton embedding model behind a `Mutex`.
19///
20/// ONNX Runtime's `Session` is not guaranteed thread-safe for concurrent
21/// inference; `Mutex` serialises all embedding calls.  This is correct by
22/// design — without the daemon, embedding throughput is intentionally serial.
23///
24/// For parallel workloads (enrich, ingest) start the daemon first:
25/// `sqlite-graphrag daemon` — the model is loaded once and served via UDS,
26/// eliminating Mutex contention across CLI invocations.
27static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
28
29/// Returns the process-wide singleton embedder, initializing it on first call.
30/// Subsequent calls return the cached instance regardless of `models_dir`.
31///
32/// # Errors
33///
34/// - [`AppError::Embedding`] — ONNX model load failure or runtime initialisation error.
35/// - [`AppError::Io`] — cache directory is inaccessible or cannot be created.
36pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
37    if let Some(m) = EMBEDDER.get() {
38        return Ok(m);
39    }
40
41    maybe_init_dynamic_ort(models_dir)?;
42
43    // Multi-layer mitigation of the explosive RSS observed with variable-shape
44    // payloads. The three current layers are:
45    //   1. `with_arena_allocator(false)` on the CPU execution provider (line below)
46    //   2. env var `ORT_DISABLE_CPU_MEM_ARENA=1` in `main.rs` (default since v1.0.18)
47    //   3. env var `ORT_NUM_THREADS=1` + `ORT_INTRA_OP_NUM_THREADS=1` in `main.rs`
48    // The `with_memory_pattern(false)` flag exists in ort 2.0 (`SessionBuilder`)
49    // but fastembed 5.13.2 does NOT expose access to a custom SessionBuilder via
50    // `TextInitOptions`. If RSS grows again in real corpora, the next
51    // mitigation requires one of the following paths:
52    //   - Fork fastembed to expose `SessionBuilder::with_memory_pattern(false)`
53    //   - Bypass fastembed and use ort directly with a custom SessionBuilder
54    //   - Fixed padding in `plan_controlled_batches` to eliminate variable shapes
55    // References:
56    //   https://onnxruntime.ai/docs/performance/tune-performance/memory.html
57    //   https://github.com/qdrant/fastembed/issues/570
58    let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
59
60    let model = TextEmbedding::try_new(
61        TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
62            .with_execution_providers(vec![cpu_ep])
63            .with_max_length(EMBEDDING_MAX_TOKENS)
64            .with_show_download_progress(true)
65            .with_cache_dir(models_dir.to_path_buf()),
66    )
67    .map_err(|e| AppError::Embedding(e.to_string()))?;
68    // If another thread raced and won, discard our instance and return theirs.
69    let _ = EMBEDDER.set(Mutex::new(model));
70    EMBEDDER.get().ok_or_else(|| {
71        AppError::Embedding(
72            "embedder OnceLock unexpectedly empty after set() (likely a racing initializer aborted before completion)"
73                .into(),
74        )
75    })
76}
77
78#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
79fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
80    let mut candidates = Vec::with_capacity(4);
81
82    if let Ok(path) = std::env::var("ORT_DYLIB_PATH") {
83        if !path.is_empty() {
84            candidates.push(std::path::PathBuf::from(path));
85        }
86    }
87
88    if let Ok(exe) = std::env::current_exe() {
89        if let Some(dir) = exe.parent() {
90            candidates.push(dir.join("libonnxruntime.so"));
91            candidates.push(dir.join("lib").join("libonnxruntime.so"));
92        }
93    }
94
95    candidates.push(models_dir.join("libonnxruntime.so"));
96
97    for path in candidates {
98        if !path.exists() {
99            continue;
100        }
101
102        std::env::set_var("ORT_DYLIB_PATH", &path);
103        let _ = ort::init_from(&path)
104            .map_err(|e| AppError::Embedding(e.to_string()))?
105            .commit();
106        return Ok(());
107    }
108
109    Ok(())
110}
111
112#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
113fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
114    Ok(())
115}
116
117/// Embeds a single passage using the `passage:` prefix required by E5 models.
118///
119/// # Errors
120/// Returns `Err` when the model returns an unexpected result.
121#[tracing::instrument(skip(embedder, text), fields(text_len = text.len()))]
122pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
123    let prefixed = format!("{PASSAGE_PREFIX}{text}");
124    let results = embedder
125        .lock()
126        .embed(vec![prefixed.as_str()], Some(1))
127        .map_err(|e| AppError::Embedding(e.to_string()))?;
128    let emb = results
129        .into_iter()
130        .next()
131        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
132    assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
133    Ok(emb)
134}
135
136/// Embeds a search query using the `query:` prefix required by E5 models.
137///
138/// # Errors
139/// Returns `Err` when the model returns an unexpected result.
140#[tracing::instrument(skip(embedder, text), fields(text_len = text.len()))]
141pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
142    let prefixed = format!("{QUERY_PREFIX}{text}");
143    let results = embedder
144        .lock()
145        .embed(vec![prefixed.as_str()], Some(1))
146        .map_err(|e| AppError::Embedding(e.to_string()))?;
147    let emb = results
148        .into_iter()
149        .next()
150        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
151    Ok(emb)
152}
153
154/// Embeds multiple passages in a single ONNX batch call.
155///
156/// `batch_size` is capped at `FASTEMBED_BATCH_SIZE`. All texts receive the `passage:` prefix.
157///
158/// # Errors
159/// Returns `Err` when the model inference fails.
160#[tracing::instrument(skip(embedder, texts), fields(batch_size = texts.len()))]
161pub fn embed_passages_batch(
162    embedder: &Mutex<TextEmbedding>,
163    texts: &[&str],
164    batch_size: usize,
165) -> Result<Vec<Vec<f32>>, AppError> {
166    let prefixed: Vec<String> = texts
167        .iter()
168        .map(|t| format!("{PASSAGE_PREFIX}{t}"))
169        .collect();
170    let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
171    let results = embedder
172        .lock()
173        .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
174        .map_err(|e| AppError::Embedding(e.to_string()))?;
175    for emb in &results {
176        assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
177    }
178    Ok(results)
179}
180
181/// Returns the number of batches that [`embed_passages_controlled`] would produce
182/// for the given `token_counts` slice without running inference.
183pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
184    plan_controlled_batches(token_counts).len()
185}
186
187/// Embeds passages grouped into token-budget-aware batches to avoid OOM on variable-length inputs.
188///
189/// `texts` and `token_counts` must have the same length. Batches are planned using an
190/// internal budget algorithm and single-item batches fall back to [`embed_passage`].
191///
192/// # Errors
193/// Returns `Err` when lengths differ, the mutex is poisoned, or inference fails.
194pub fn embed_passages_controlled(
195    embedder: &Mutex<TextEmbedding>,
196    texts: &[&str],
197    token_counts: &[usize],
198) -> Result<Vec<Vec<f32>>, AppError> {
199    if texts.len() != token_counts.len() {
200        return Err(AppError::Internal(anyhow::anyhow!(
201            "texts/token_counts length mismatch in controlled embedding"
202        )));
203    }
204
205    let mut results = Vec::with_capacity(texts.len());
206    for (start, end) in plan_controlled_batches(token_counts) {
207        if end - start == 1 {
208            results.push(embed_passage(embedder, texts[start])?);
209            continue;
210        }
211
212        results.extend(embed_passages_batch(
213            embedder,
214            &texts[start..end],
215            end - start,
216        )?);
217    }
218
219    Ok(results)
220}
221
222/// Embed multiple passages one-by-one (serial ONNX inference).
223///
224/// Serialization is **intentional**: ONNX batch inference can trigger pathological
225/// runtime behaviour on real-world Markdown chunks (variable token lengths cause
226/// extreme padding overhead). Callers that need parallelism should use the rayon
227/// `ThreadPool` in `src/commands/ingest.rs::run`, which partitions work across
228/// CPU threads and calls this function per shard.
229///
230/// # Errors
231///
232/// Returns [`AppError::Embedding`] when the ONNX encoder fails on any passage.
233pub fn embed_passages_serial<'a, I>(
234    embedder: &Mutex<TextEmbedding>,
235    texts: I,
236) -> Result<Vec<Vec<f32>>, AppError>
237where
238    I: IntoIterator<Item = &'a str>,
239{
240    let iter = texts.into_iter();
241    let (lower, _) = iter.size_hint();
242    let mut results = Vec::with_capacity(lower);
243    for text in iter {
244        results.push(embed_passage(embedder, text)?);
245    }
246    Ok(results)
247}
248
249fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
250    let mut batches =
251        Vec::with_capacity((token_counts.len() / REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS).max(1));
252    let mut start = 0usize;
253
254    while start < token_counts.len() {
255        let mut end = start + 1;
256        let mut max_tokens = token_counts[start].max(1);
257
258        while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
259            let candidate_max = max_tokens.max(token_counts[end].max(1));
260            let candidate_len = end + 1 - start;
261            if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
262                break;
263            }
264            max_tokens = candidate_max;
265            end += 1;
266        }
267
268        batches.push((start, end));
269        start = end;
270    }
271
272    batches
273}
274
// sqlite-vec expects little-endian f32 bytes, and `f32_to_bytes` emits native
// byte order, so big-endian targets are rejected at compile time.
#[cfg(target_endian = "big")]
compile_error!(
    "sqlite-graphrag requires little-endian f32 layout for sqlite-vec compatibility. \
     Big-endian targets (PPC64, S390x) are not supported."
);

/// Reinterprets `v` as its raw bytes for sqlite-vec storage, without copying.
///
/// The returned slice borrows from `v` and is exactly `size_of_val(v)` bytes
/// long (four per element). Big-endian targets fail to compile, so the bytes are
/// always each `f32`'s little-endian encoding, in element order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    // SAFETY:
    // - `v.as_ptr()` comes from a live `&[f32]`, so it is valid for reads of
    //   `size_of_val(v)` bytes, and that size cannot exceed `isize::MAX`.
    // - `u8` has alignment 1, so the pointer is suitably aligned.
    // - `f32` has no padding bytes and any byte value is a valid `u8`, so every
    //   byte in the range is initialised and valid.
    // - The output lifetime is tied to `v` by the signature, and `v` is a shared
    //   borrow, so the data cannot be mutated or freed while the bytes are alive.
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<u8>(), std::mem::size_of_val(v)) }
}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes tests (pure function, no model) ---

    #[test]
    fn f32_to_bytes_empty_slice_returns_empty() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_one_element_returns_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        // Roundtrip: the 4 bytes must reconstruct the original f32.
        let recovered = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_length_is_4x_elements() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_encoded_as_4_zeros() {
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vector_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        // Reconstruct and compare the first and last elements.
        let first = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_le_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- constants the embedder relies on (no model) ---

    #[test]
    fn passage_prefix_matches_e5_convention() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_matches_e5_convention() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_is_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- controlled batch planning (pure function, no model) ---

    #[test]
    fn controlled_batch_plan_respects_budget() {
        assert_eq!(
            plan_controlled_batches(&[100, 100, 100, 100, 300, 300]),
            vec![(0, 4), (4, 5), (5, 6)]
        );
    }

    #[test]
    fn controlled_batch_plan_empty_input_yields_no_batches() {
        assert!(plan_controlled_batches(&[]).is_empty());
        assert_eq!(controlled_batch_count(&[]), 0);
    }

    #[test]
    fn controlled_batch_plan_gives_oversized_item_its_own_batch() {
        assert_eq!(plan_controlled_batches(&[100_000, 1]), vec![(0, 1), (1, 2)]);
    }

    #[test]
    fn controlled_batch_plan_treats_zero_tokens_as_one() {
        assert_eq!(plan_controlled_batches(&[0, 0]), vec![(0, 2)]);
    }

    #[test]
    fn controlled_batch_count_returns_one_for_single_chunk() {
        assert_eq!(controlled_batch_count(&[350]), 1);
    }

    #[test]
    fn controlled_batch_count_matches_plan_length() {
        let counts = [100, 100, 100, 100, 300, 300];
        assert_eq!(
            controlled_batch_count(&counts),
            plan_controlled_batches(&counts).len()
        );
    }

    // --- tests with the real model (ignored in normal CI runs) ---
    // `get_embedder` caches a process-wide singleton, so once one of these tests
    // has initialised it, the others reuse that model and ignore their own
    // temporary directory.

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_passage_returns_vector_with_correct_dimension() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "test text").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_query_returns_vector_with_correct_dimension() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "test query").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_passages_batch_returns_one_vector_per_text() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let texts = ["primeiro", "segundo"];
        let results = embed_passages_batch(embedder, &texts, 2).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }
}