1use crate::constants::{
2 EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
3 REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
4};
5use crate::errors::AppError;
6use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
7use ort::execution_providers::CPU;
8use std::path::Path;
9use std::sync::{Mutex, OnceLock};
10
/// Process-wide slot for the embedding model.
///
/// Starts empty; `get_embedder` fills it on first successful initialisation and
/// hands out `'static` references to the stored mutex afterwards.
static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
12
13pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
16 if let Some(m) = EMBEDDER.get() {
17 return Ok(m);
18 }
19
20 let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
24
25 let model = TextEmbedding::try_new(
26 TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
27 .with_execution_providers(vec![cpu_ep])
28 .with_max_length(EMBEDDING_MAX_TOKENS)
29 .with_show_download_progress(true)
30 .with_cache_dir(models_dir.to_path_buf()),
31 )
32 .map_err(|e| AppError::Embedding(e.to_string()))?;
33 let _ = EMBEDDER.set(Mutex::new(model));
35 Ok(EMBEDDER.get().expect("just set above"))
36}
37
38pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
39 let prefixed = format!("{PASSAGE_PREFIX}{text}");
40 let results = embedder
41 .lock()
42 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
43 .embed(vec![prefixed.as_str()], Some(1))
44 .map_err(|e| AppError::Embedding(e.to_string()))?;
45 let emb = results
46 .into_iter()
47 .next()
48 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
49 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
50 Ok(emb)
51}
52
53pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
54 let prefixed = format!("{QUERY_PREFIX}{text}");
55 let results = embedder
56 .lock()
57 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
58 .embed(vec![prefixed.as_str()], Some(1))
59 .map_err(|e| AppError::Embedding(e.to_string()))?;
60 let emb = results
61 .into_iter()
62 .next()
63 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
64 Ok(emb)
65}
66
67pub fn embed_passages_batch(
68 embedder: &Mutex<TextEmbedding>,
69 texts: &[&str],
70 batch_size: usize,
71) -> Result<Vec<Vec<f32>>, AppError> {
72 let prefixed: Vec<String> = texts
73 .iter()
74 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
75 .collect();
76 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
77 let results = embedder
78 .lock()
79 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
80 .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
81 .map_err(|e| AppError::Embedding(e.to_string()))?;
82 for emb in &results {
83 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
84 }
85 Ok(results)
86}
87
88pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
89 plan_controlled_batches(token_counts).len()
90}
91
92pub fn embed_passages_controlled(
93 embedder: &Mutex<TextEmbedding>,
94 texts: &[&str],
95 token_counts: &[usize],
96) -> Result<Vec<Vec<f32>>, AppError> {
97 if texts.len() != token_counts.len() {
98 return Err(AppError::Internal(anyhow::anyhow!(
99 "texts/token_counts length mismatch in controlled embedding"
100 )));
101 }
102
103 let mut results = Vec::with_capacity(texts.len());
104 for (start, end) in plan_controlled_batches(token_counts) {
105 if end - start == 1 {
106 results.push(embed_passage(embedder, texts[start])?);
107 continue;
108 }
109
110 results.extend(embed_passages_batch(
111 embedder,
112 &texts[start..end],
113 end - start,
114 )?);
115 }
116
117 Ok(results)
118}
119
120pub fn embed_passages_serial<'a, I>(
125 embedder: &Mutex<TextEmbedding>,
126 texts: I,
127) -> Result<Vec<Vec<f32>>, AppError>
128where
129 I: IntoIterator<Item = &'a str>,
130{
131 let iter = texts.into_iter();
132 let (lower, _) = iter.size_hint();
133 let mut results = Vec::with_capacity(lower);
134 for text in iter {
135 results.push(embed_passage(embedder, text)?);
136 }
137 Ok(results)
138}
139
140fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
141 let mut batches = Vec::new();
142 let mut start = 0usize;
143
144 while start < token_counts.len() {
145 let mut end = start + 1;
146 let mut max_tokens = token_counts[start].max(1);
147
148 while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
149 let candidate_max = max_tokens.max(token_counts[end].max(1));
150 let candidate_len = end + 1 - start;
151 if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
152 break;
153 }
154 max_tokens = candidate_max;
155 end += 1;
156 }
157
158 batches.push((start, end));
159 start = end;
160 }
161
162 batches
163}
164
/// Views a slice of `f32` values as its underlying bytes without copying.
///
/// The returned slice borrows from `v`, is `size_of_val(v)` bytes long, and
/// exposes the values in the platform's in-memory (native-endian) layout.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    let first_byte = v.as_ptr().cast::<u8>();
    // SAFETY: `first_byte` is derived from a live `&[f32]`, so it is non-null
    // and valid for reads of `byte_len` bytes for as long as `v` is borrowed,
    // which the returned lifetime enforces. `u8` has alignment 1 and every bit
    // pattern is a valid `u8`.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
171
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes -------------------------------------------------------

    /// An empty input slice maps to an empty byte slice.
    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let no_values: Vec<f32> = Vec::new();
        let no_bytes: &[u8] = &[];
        assert_eq!(f32_to_bytes(&no_values), no_bytes);
    }

    /// One `f32` becomes four bytes that decode (little-endian) to the value.
    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let single = vec![1.0_f32];
        let raw = f32_to_bytes(&single);
        assert_eq!(raw.len(), 4);
        let decoded = f32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]);
        assert_eq!(decoded, 1.0_f32);
    }

    /// The byte length is four times the element count.
    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let values = vec![0.0_f32, 1.0, 2.0, 3.0];
        let raw = f32_to_bytes(&values);
        assert_eq!(raw.len(), values.len() * 4);
    }

    /// Positive zero is encoded as four zero bytes.
    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let zero = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&zero), &[0u8; 4]);
    }

    /// A vector of `EMBEDDING_DIM` values keeps its first and last element.
    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let values: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let raw = f32_to_bytes(&values);
        assert_eq!(raw.len(), EMBEDDING_DIM * 4);

        let head = f32::from_le_bytes(raw[..4].try_into().unwrap());
        assert!((head - 0.0_f32).abs() < 1e-6);

        let tail_offset = (EMBEDDING_DIM - 1) * 4;
        let tail = f32::from_le_bytes(raw[tail_offset..tail_offset + 4].try_into().unwrap());
        let expected_tail = (EMBEDDING_DIM - 1) as f32 * 0.001;
        assert!((tail - expected_tail).abs() < 1e-4);
    }

    // --- constants ----------------------------------------------------------

    /// Pins the exact passage prefix string.
    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    /// Pins the exact query prefix string.
    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    /// Pins the embedding dimension.
    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- model-backed tests (ignored by default) ----------------------------

    /// A single passage yields one vector of `EMBEDDING_DIM` components.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let cache_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(cache_dir.path()).unwrap();
        let vector = embed_passage(model, "texto de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    /// A single query yields one vector of `EMBEDDING_DIM` components.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let cache_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(cache_dir.path()).unwrap();
        let vector = embed_query(model, "consulta de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    /// A batch of two texts yields two vectors of `EMBEDDING_DIM` components.
    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let cache_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(cache_dir.path()).unwrap();
        let inputs = ["primeiro", "segundo"];
        let vectors = embed_passages_batch(model, &inputs, 2).unwrap();
        assert_eq!(vectors.len(), 2);
        for vector in &vectors {
            assert_eq!(vector.len(), EMBEDDING_DIM);
        }
    }

    // --- batch planning -----------------------------------------------------

    /// Four short chunks share a batch; each long chunk gets its own.
    #[test]
    fn controlled_batch_plan_respeita_orcamento() {
        let plan = plan_controlled_batches(&[100, 100, 100, 100, 300, 300]);
        assert_eq!(plan, vec![(0, 4), (4, 5), (5, 6)]);
    }

    /// A single chunk always forms exactly one batch.
    #[test]
    fn controlled_batch_count_retorna_um_para_chunk_unico() {
        let count = controlled_batch_count(&[350]);
        assert_eq!(count, 1);
    }
}