sqlite_graphrag/
embedder.rs1use crate::constants::{
2 EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
3};
4use crate::errors::AppError;
5use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
6use ort::execution_providers::CPU;
7use std::path::Path;
8use std::sync::{Mutex, OnceLock};
9
10static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
11
12pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
15 if let Some(m) = EMBEDDER.get() {
16 return Ok(m);
17 }
18
19 let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
23
24 let model = TextEmbedding::try_new(
25 TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
26 .with_execution_providers(vec![cpu_ep])
27 .with_max_length(EMBEDDING_MAX_TOKENS)
28 .with_show_download_progress(true)
29 .with_cache_dir(models_dir.to_path_buf()),
30 )
31 .map_err(|e| AppError::Embedding(e.to_string()))?;
32 let _ = EMBEDDER.set(Mutex::new(model));
34 Ok(EMBEDDER.get().expect("just set above"))
35}
36
37pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
38 let prefixed = format!("{PASSAGE_PREFIX}{text}");
39 let results = embedder
40 .lock()
41 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
42 .embed(vec![prefixed.as_str()], Some(1))
43 .map_err(|e| AppError::Embedding(e.to_string()))?;
44 let emb = results
45 .into_iter()
46 .next()
47 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
48 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
49 Ok(emb)
50}
51
52pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
53 let prefixed = format!("{QUERY_PREFIX}{text}");
54 let results = embedder
55 .lock()
56 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
57 .embed(vec![prefixed.as_str()], Some(1))
58 .map_err(|e| AppError::Embedding(e.to_string()))?;
59 let emb = results
60 .into_iter()
61 .next()
62 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
63 Ok(emb)
64}
65
66pub fn embed_passages_batch(
67 embedder: &Mutex<TextEmbedding>,
68 texts: &[String],
69) -> Result<Vec<Vec<f32>>, AppError> {
70 let prefixed: Vec<String> = texts
71 .iter()
72 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
73 .collect();
74 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
75 let results = embedder
76 .lock()
77 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
78 .embed(strs, Some(FASTEMBED_BATCH_SIZE))
79 .map_err(|e| AppError::Embedding(e.to_string()))?;
80 for emb in &results {
81 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
82 }
83 Ok(results)
84}
85
86pub fn embed_passages_serial<'a, I>(
91 embedder: &Mutex<TextEmbedding>,
92 texts: I,
93) -> Result<Vec<Vec<f32>>, AppError>
94where
95 I: IntoIterator<Item = &'a str>,
96{
97 let iter = texts.into_iter();
98 let (lower, _) = iter.size_hint();
99 let mut results = Vec::with_capacity(lower);
100 for text in iter {
101 results.push(embed_passage(embedder, text)?);
102 }
103 Ok(results)
104}
105
/// Reinterprets a slice of `f32` as its underlying bytes without copying.
///
/// The returned slice borrows from `v`, is `4 * v.len()` bytes long, and
/// exposes each float in the platform's native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = v.len() * std::mem::size_of::<f32>();
    // SAFETY: `v` is a valid, initialised slice, so its storage is readable
    // for exactly `byte_len` bytes. `u8` has alignment 1 and every bit pattern
    // is a valid `u8`. The returned borrow shares `v`'s lifetime, so the bytes
    // cannot outlive or alias a mutable view of the floats.
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<u8>(), byte_len) }
}
112
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    /// An empty float slice yields an empty byte view.
    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let empty: [f32; 0] = [];
        assert!(f32_to_bytes(&empty).is_empty());
    }

    /// One float becomes four bytes that decode back to the same value.
    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let one = [1.0_f32];
        let bytes = f32_to_bytes(&one);
        assert_eq!(bytes.len(), 4);
        let raw: [u8; 4] = bytes.try_into().unwrap();
        assert_eq!(f32::from_le_bytes(raw), 1.0_f32);
    }

    /// The byte view is four times as long as the float slice.
    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let values = [0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&values).len(), 4 * values.len());
    }

    /// Positive zero is encoded as four zero bytes.
    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let zero = [0.0_f32];
        assert_eq!(f32_to_bytes(&zero), &[0u8; 4]);
    }

    /// First and last components of an `EMBEDDING_DIM`-sized vector survive
    /// the trip through the byte view.
    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let values: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&values);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);

        let decode = |chunk: &[u8]| f32::from_le_bytes(chunk.try_into().unwrap());

        let first = decode(&bytes[0..4]);
        assert!((first - 0.0_f32).abs() < 1e-6);

        let offset = (EMBEDDING_DIM - 1) * 4;
        let last = decode(&bytes[offset..offset + 4]);
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    /// Pins the exact passage prefix.
    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    /// Pins the exact query prefix.
    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    /// Pins the embedding dimension.
    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    /// Needs the real model, so it only runs when ignored tests are included.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let models_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(models_dir.path()).unwrap();
        let embedding = embed_passage(model, "texto de teste").unwrap();
        assert_eq!(embedding.len(), EMBEDDING_DIM);
    }

    /// Needs the real model, so it only runs when ignored tests are included.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let models_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(models_dir.path()).unwrap();
        let embedding = embed_query(model, "consulta de teste").unwrap();
        assert_eq!(embedding.len(), EMBEDDING_DIM);
    }

    /// Needs the real model, so it only runs when ignored tests are included.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let models_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(models_dir.path()).unwrap();
        let textos = ["primeiro", "segundo"].map(String::from);
        let results = embed_passages_batch(model, &textos).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|emb| emb.len() == EMBEDDING_DIM));
    }
}