1use crate::constants::{
7 EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
8 REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
9};
10use crate::errors::AppError;
11use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
12use ort::execution_providers::CPU;
13use std::path::Path;
14use std::sync::{Mutex, OnceLock};
15
16static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
17
18pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
21 if let Some(m) = EMBEDDER.get() {
22 return Ok(m);
23 }
24
25 maybe_init_dynamic_ort(models_dir)?;
26
27 let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
43
44 let model = TextEmbedding::try_new(
45 TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
46 .with_execution_providers(vec![cpu_ep])
47 .with_max_length(EMBEDDING_MAX_TOKENS)
48 .with_show_download_progress(true)
49 .with_cache_dir(models_dir.to_path_buf()),
50 )
51 .map_err(|e| AppError::Embedding(e.to_string()))?;
52 let _ = EMBEDDER.set(Mutex::new(model));
54 Ok(EMBEDDER.get().expect("just set above"))
55}
56
/// Points ONNX Runtime at the first `libonnxruntime.so` found on disk (aarch64 Linux, glibc).
///
/// Candidates are probed in this order:
/// 1. the `ORT_DYLIB_PATH` environment variable, when set and non-empty;
/// 2. `libonnxruntime.so` next to the current executable;
/// 3. `lib/libonnxruntime.so` next to the current executable;
/// 4. `libonnxruntime.so` inside `models_dir`.
///
/// The first existing candidate is exported as `ORT_DYLIB_PATH` and passed to
/// `ort::init_from`. When no candidate exists, nothing is done and `Ok(())` is returned.
///
/// # Errors
///
/// Returns [`AppError::Embedding`] if `ort::init_from` fails for the first existing
/// candidate; later candidates are not tried.
#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
    const LIB_NAME: &str = "libonnxruntime.so";

    let from_env = std::env::var("ORT_DYLIB_PATH")
        .ok()
        .filter(|value| !value.is_empty())
        .map(std::path::PathBuf::from);

    let beside_exe = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(Path::to_path_buf))
        .map(|dir| [dir.join(LIB_NAME), dir.join("lib").join(LIB_NAME)])
        .into_iter()
        .flatten();

    let candidates: Vec<std::path::PathBuf> = from_env
        .into_iter()
        .chain(beside_exe)
        .chain(std::iter::once(models_dir.join(LIB_NAME)))
        .collect();

    let Some(found) = candidates.into_iter().find(|path| path.exists()) else {
        return Ok(());
    };

    // NOTE(review): `std::env::set_var` is not thread-safe on POSIX when other threads read
    // the environment concurrently (it is `unsafe` in Rust 2024) — confirm this runs before
    // any such threads exist.
    std::env::set_var("ORT_DYLIB_PATH", &found);
    let _ = ort::init_from(&found)
        .map_err(|e| AppError::Embedding(e.to_string()))?
        .commit();
    Ok(())
}
90
91#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
92fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
93 Ok(())
94}
95
96pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
97 let prefixed = format!("{PASSAGE_PREFIX}{text}");
98 let results = embedder
99 .lock()
100 .map_err(|e| AppError::Embedding(format!("mutex do embedder corrompido: {e}")))?
101 .embed(vec![prefixed.as_str()], Some(1))
102 .map_err(|e| AppError::Embedding(e.to_string()))?;
103 let emb = results
104 .into_iter()
105 .next()
106 .ok_or_else(|| AppError::Embedding("resultado de embedding vazio".into()))?;
107 assert_eq!(emb.len(), EMBEDDING_DIM, "dimensão de embedding inesperada");
108 Ok(emb)
109}
110
111pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
112 let prefixed = format!("{QUERY_PREFIX}{text}");
113 let results = embedder
114 .lock()
115 .map_err(|e| AppError::Embedding(format!("mutex do embedder corrompido: {e}")))?
116 .embed(vec![prefixed.as_str()], Some(1))
117 .map_err(|e| AppError::Embedding(e.to_string()))?;
118 let emb = results
119 .into_iter()
120 .next()
121 .ok_or_else(|| AppError::Embedding("resultado de embedding vazio".into()))?;
122 Ok(emb)
123}
124
125pub fn embed_passages_batch(
126 embedder: &Mutex<TextEmbedding>,
127 texts: &[&str],
128 batch_size: usize,
129) -> Result<Vec<Vec<f32>>, AppError> {
130 let prefixed: Vec<String> = texts
131 .iter()
132 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
133 .collect();
134 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
135 let results = embedder
136 .lock()
137 .map_err(|e| AppError::Embedding(format!("mutex do embedder corrompido: {e}")))?
138 .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
139 .map_err(|e| AppError::Embedding(e.to_string()))?;
140 for emb in &results {
141 assert_eq!(emb.len(), EMBEDDING_DIM, "dimensão de embedding inesperada");
142 }
143 Ok(results)
144}
145
146pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
147 plan_controlled_batches(token_counts).len()
148}
149
150pub fn embed_passages_controlled(
151 embedder: &Mutex<TextEmbedding>,
152 texts: &[&str],
153 token_counts: &[usize],
154) -> Result<Vec<Vec<f32>>, AppError> {
155 if texts.len() != token_counts.len() {
156 return Err(AppError::Internal(anyhow::anyhow!(
157 "comprimento de texts/token_counts diverge no embedding controlado"
158 )));
159 }
160
161 let mut results = Vec::with_capacity(texts.len());
162 for (start, end) in plan_controlled_batches(token_counts) {
163 if end - start == 1 {
164 results.push(embed_passage(embedder, texts[start])?);
165 continue;
166 }
167
168 results.extend(embed_passages_batch(
169 embedder,
170 &texts[start..end],
171 end - start,
172 )?);
173 }
174
175 Ok(results)
176}
177
178pub fn embed_passages_serial<'a, I>(
183 embedder: &Mutex<TextEmbedding>,
184 texts: I,
185) -> Result<Vec<Vec<f32>>, AppError>
186where
187 I: IntoIterator<Item = &'a str>,
188{
189 let iter = texts.into_iter();
190 let (lower, _) = iter.size_hint();
191 let mut results = Vec::with_capacity(lower);
192 for text in iter {
193 results.push(embed_passage(embedder, text)?);
194 }
195 Ok(results)
196}
197
198fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
199 let mut batches = Vec::new();
200 let mut start = 0usize;
201
202 while start < token_counts.len() {
203 let mut end = start + 1;
204 let mut max_tokens = token_counts[start].max(1);
205
206 while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
207 let candidate_max = max_tokens.max(token_counts[end].max(1));
208 let candidate_len = end + 1 - start;
209 if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
210 break;
211 }
212 max_tokens = candidate_max;
213 end += 1;
214 }
215
216 batches.push((start, end));
217 start = end;
218 }
219
220 batches
221}
222
/// Views a slice of `f32` as its raw bytes without copying.
///
/// The result borrows `v`, is `4 * v.len()` bytes long, and stores each value in the target's
/// native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    let base = v.as_ptr().cast::<u8>();
    // SAFETY: `base..base + byte_len` is exactly the memory of `v`: initialized, with no padding
    // (f32 is plain 4-byte data and every byte is a valid `u8`), and `u8` needs no alignment.
    // `byte_len` fits in `isize` because it is the size of an existing slice. The returned
    // lifetime is tied to `v` by elision and the view is read-only, so no aliasing rule is broken.
    unsafe { std::slice::from_raw_parts(base, byte_len) }
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::{
        EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX, REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS,
        REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
    };

    // --- f32_to_bytes ---
    // NOTE(review): these tests decode with `from_le_bytes`, so they assume a little-endian
    // target; `f32_to_bytes` itself emits native byte order.

    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        let recovered = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        let first = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_le_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- constants the embedding code relies on ---

    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- model-backed tests (ignored by default: they load the real model) ---

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "texto de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "consulta de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let textos = ["primeiro", "segundo"];
        let results = embed_passages_batch(embedder, &textos, 2).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }

    // --- batch planning ---

    #[test]
    fn controlled_batch_plan_respeita_orcamento() {
        assert_eq!(
            plan_controlled_batches(&[100, 100, 100, 100, 300, 300]),
            vec![(0, 4), (4, 5), (5, 6)]
        );
    }

    #[test]
    fn controlled_batch_count_retorna_um_para_chunk_unico() {
        assert_eq!(controlled_batch_count(&[350]), 1);
    }

    /// Empty input must produce no batches (and therefore no model calls).
    #[test]
    fn controlled_batch_plan_entrada_vazia_nao_gera_lotes() {
        assert!(plan_controlled_batches(&[]).is_empty());
        assert_eq!(controlled_batch_count(&[]), 0);
    }

    /// A chunk larger than the whole padded budget is never merged with neighbours.
    #[test]
    fn controlled_batch_plan_chunk_acima_do_orcamento_fica_sozinho() {
        let big = REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS + 1;
        assert_eq!(
            plan_controlled_batches(&[big, 1, big]),
            vec![(0, 1), (1, 2), (2, 3)]
        );
    }

    /// Batches are contiguous, non-empty, cover every index in order, and respect both limits
    /// (the padded-token budget may only be exceeded by a single-chunk batch).
    #[test]
    fn controlled_batch_plan_cobre_todos_os_indices_em_ordem() {
        let counts = [0, 5, 120, 7, 300, 1, 1, 64, 250, 3];
        let plan = plan_controlled_batches(&counts);
        let mut expected_start = 0;
        for &(start, end) in &plan {
            assert_eq!(start, expected_start);
            assert!(end > start);
            assert!(end - start <= REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS.max(1));
            let widest = counts[start..end].iter().map(|&c| c.max(1)).max().unwrap();
            assert!(
                end - start == 1
                    || widest * (end - start) <= REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS
            );
            expected_start = end;
        }
        assert_eq!(expected_start, counts.len());
    }
}