1use crate::constants::{
2 EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
3 REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
4};
5use crate::errors::AppError;
6use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
7use ort::execution_providers::CPU;
8use std::path::Path;
9use std::sync::{Mutex, OnceLock};
10
11static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
12
13pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
16 if let Some(m) = EMBEDDER.get() {
17 return Ok(m);
18 }
19
20 maybe_init_dynamic_ort(models_dir)?;
21
22 let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
38
39 let model = TextEmbedding::try_new(
40 TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
41 .with_execution_providers(vec![cpu_ep])
42 .with_max_length(EMBEDDING_MAX_TOKENS)
43 .with_show_download_progress(true)
44 .with_cache_dir(models_dir.to_path_buf()),
45 )
46 .map_err(|e| AppError::Embedding(e.to_string()))?;
47 let _ = EMBEDDER.set(Mutex::new(model));
49 Ok(EMBEDDER.get().expect("just set above"))
50}
51
/// Locates a dynamically loadable ONNX Runtime library and initialises `ort`
/// from it (aarch64 / Linux / gnu builds only).
///
/// Candidates are probed in this order, and the first path that exists wins:
/// 1. the value of the `ORT_DYLIB_PATH` environment variable, if non-empty;
/// 2. `libonnxruntime.so` next to the current executable;
/// 3. `lib/libonnxruntime.so` next to the current executable;
/// 4. `libonnxruntime.so` inside `models_dir`.
///
/// When no candidate exists the function does nothing and returns `Ok(())`.
///
/// # Errors
///
/// Returns `AppError::Embedding` if `ort::init_from` fails for the selected
/// library. The value returned by `commit()` is discarded.
#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
    let mut search_paths: Vec<std::path::PathBuf> = Vec::new();

    // Explicit override from the environment takes priority.
    search_paths.extend(
        std::env::var("ORT_DYLIB_PATH")
            .ok()
            .filter(|value| !value.is_empty())
            .map(std::path::PathBuf::from),
    );

    // Locations relative to the running binary.
    let exe_path = std::env::current_exe().ok();
    if let Some(exe_dir) = exe_path.as_deref().and_then(Path::parent) {
        search_paths.push(exe_dir.join("libonnxruntime.so"));
        search_paths.push(exe_dir.join("lib").join("libonnxruntime.so"));
    }

    // Last resort: the models directory itself.
    search_paths.push(models_dir.join("libonnxruntime.so"));

    let Some(library) = search_paths.into_iter().find(|candidate| candidate.exists()) else {
        return Ok(());
    };

    // Record the chosen library in the environment before initialising `ort`.
    std::env::set_var("ORT_DYLIB_PATH", &library);
    let _ = ort::init_from(&library)
        .map_err(|e| AppError::Embedding(e.to_string()))?
        .commit();
    Ok(())
}
85
86#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
87fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
88 Ok(())
89}
90
91pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
92 let prefixed = format!("{PASSAGE_PREFIX}{text}");
93 let results = embedder
94 .lock()
95 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
96 .embed(vec![prefixed.as_str()], Some(1))
97 .map_err(|e| AppError::Embedding(e.to_string()))?;
98 let emb = results
99 .into_iter()
100 .next()
101 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
102 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
103 Ok(emb)
104}
105
106pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
107 let prefixed = format!("{QUERY_PREFIX}{text}");
108 let results = embedder
109 .lock()
110 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
111 .embed(vec![prefixed.as_str()], Some(1))
112 .map_err(|e| AppError::Embedding(e.to_string()))?;
113 let emb = results
114 .into_iter()
115 .next()
116 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
117 Ok(emb)
118}
119
120pub fn embed_passages_batch(
121 embedder: &Mutex<TextEmbedding>,
122 texts: &[&str],
123 batch_size: usize,
124) -> Result<Vec<Vec<f32>>, AppError> {
125 let prefixed: Vec<String> = texts
126 .iter()
127 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
128 .collect();
129 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
130 let results = embedder
131 .lock()
132 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
133 .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
134 .map_err(|e| AppError::Embedding(e.to_string()))?;
135 for emb in &results {
136 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
137 }
138 Ok(results)
139}
140
141pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
142 plan_controlled_batches(token_counts).len()
143}
144
145pub fn embed_passages_controlled(
146 embedder: &Mutex<TextEmbedding>,
147 texts: &[&str],
148 token_counts: &[usize],
149) -> Result<Vec<Vec<f32>>, AppError> {
150 if texts.len() != token_counts.len() {
151 return Err(AppError::Internal(anyhow::anyhow!(
152 "texts/token_counts length mismatch in controlled embedding"
153 )));
154 }
155
156 let mut results = Vec::with_capacity(texts.len());
157 for (start, end) in plan_controlled_batches(token_counts) {
158 if end - start == 1 {
159 results.push(embed_passage(embedder, texts[start])?);
160 continue;
161 }
162
163 results.extend(embed_passages_batch(
164 embedder,
165 &texts[start..end],
166 end - start,
167 )?);
168 }
169
170 Ok(results)
171}
172
173pub fn embed_passages_serial<'a, I>(
178 embedder: &Mutex<TextEmbedding>,
179 texts: I,
180) -> Result<Vec<Vec<f32>>, AppError>
181where
182 I: IntoIterator<Item = &'a str>,
183{
184 let iter = texts.into_iter();
185 let (lower, _) = iter.size_hint();
186 let mut results = Vec::with_capacity(lower);
187 for text in iter {
188 results.push(embed_passage(embedder, text)?);
189 }
190 Ok(results)
191}
192
193fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
194 let mut batches = Vec::new();
195 let mut start = 0usize;
196
197 while start < token_counts.len() {
198 let mut end = start + 1;
199 let mut max_tokens = token_counts[start].max(1);
200
201 while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
202 let candidate_max = max_tokens.max(token_counts[end].max(1));
203 let candidate_len = end + 1 - start;
204 if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
205 break;
206 }
207 max_tokens = candidate_max;
208 end += 1;
209 }
210
211 batches.push((start, end));
212 start = end;
213 }
214
215 batches
216}
217
/// Views a slice of `f32` as its underlying bytes without copying.
///
/// The returned slice borrows from `v`, is `4 * v.len()` bytes long, and
/// exposes each value in the platform's native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = v.len() * std::mem::size_of::<f32>();
    let first_byte = v.as_ptr().cast::<u8>();
    // SAFETY: `first_byte` points at the start of `v`, which is valid for
    // reads of `byte_len` bytes for as long as `v` is borrowed; `u8` has
    // alignment 1 and every bit pattern of an `f32` is a valid `[u8; 4]`.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
224
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // ---- f32_to_bytes -----------------------------------------------------

    /// An empty slice maps to an empty byte slice.
    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let empty: Vec<f32> = Vec::new();
        assert!(f32_to_bytes(&empty).is_empty());
    }

    /// One float becomes exactly four bytes that decode back to the same value.
    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let input = [1.0_f32];
        let raw = f32_to_bytes(&input);
        assert_eq!(raw.len(), 4);
        let word: [u8; 4] = [raw[0], raw[1], raw[2], raw[3]];
        assert_eq!(f32::from_le_bytes(word), 1.0_f32);
    }

    /// Output length is four bytes per element.
    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let input = [0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&input).len(), input.len() * 4);
    }

    /// Positive zero is encoded as four zero bytes.
    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let input = [0.0_f32];
        assert_eq!(f32_to_bytes(&input), &[0u8; 4]);
    }

    /// A full-size embedding vector keeps its first and last values.
    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let values: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let raw = f32_to_bytes(&values);
        assert_eq!(raw.len(), EMBEDDING_DIM * 4);

        // Decodes the float stored at the given byte offset.
        let decode_at =
            |offset: usize| f32::from_le_bytes(raw[offset..offset + 4].try_into().unwrap());

        let first = decode_at(0);
        assert!((first - 0.0_f32).abs() < 1e-6);

        let last = decode_at((EMBEDDING_DIM - 1) * 4);
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // ---- constants ----------------------------------------------------------

    /// Pins the exact passage prefix.
    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    /// Pins the exact query prefix.
    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    /// Pins the embedding dimension.
    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // ---- model-backed tests (ignored by default) ----------------------------

    /// A single passage embeds to a vector of `EMBEDDING_DIM` components.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let cache = tempfile::tempdir().unwrap();
        let model = get_embedder(cache.path()).unwrap();
        let vector = embed_passage(model, "texto de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    /// A single query embeds to a vector of `EMBEDDING_DIM` components.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let cache = tempfile::tempdir().unwrap();
        let model = get_embedder(cache.path()).unwrap();
        let vector = embed_query(model, "consulta de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    /// A batch returns one correctly sized vector per input text.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let cache = tempfile::tempdir().unwrap();
        let model = get_embedder(cache.path()).unwrap();
        let inputs = ["primeiro", "segundo"];
        let vectors = embed_passages_batch(model, &inputs, 2).unwrap();
        assert_eq!(vectors.len(), 2);
        assert!(vectors.iter().all(|v| v.len() == EMBEDDING_DIM));
    }

    // ---- batch planning -----------------------------------------------------

    /// Small chunks are grouped together; large chunks are kept apart.
    #[test]
    fn controlled_batch_plan_respeita_orcamento() {
        let plan = plan_controlled_batches(&[100, 100, 100, 100, 300, 300]);
        assert_eq!(plan, vec![(0, 4), (4, 5), (5, 6)]);
    }

    /// A single chunk always yields exactly one batch.
    #[test]
    fn controlled_batch_count_retorna_um_para_chunk_unico() {
        assert_eq!(controlled_batch_count(&[350]), 1);
    }
}