1#[cfg(feature = "nemotron")]
4mod nemotron;
5#[cfg(feature = "nemotron")]
6pub use nemotron::{NemotronConfig, NemotronEmbedder};
7
8use crate::{Chunk, Error, Result};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum PoolingStrategy {
15 Cls,
17 Mean,
19 WeightedMean,
21 LastToken,
23}
24
25impl Default for PoolingStrategy {
26 fn default() -> Self {
27 Self::Mean
28 }
29}
30
/// Configuration shared by embedder implementations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// L2-normalize output vectors when `true`.
    pub normalize: bool,
    /// Optional prefix prepended to query text before embedding.
    pub query_prefix: Option<String>,
    /// Optional prefix prepended to document text before embedding.
    pub document_prefix: Option<String>,
    /// Maximum input length (units depend on the embedder — presumably
    /// tokens; confirm against the concrete implementation).
    pub max_length: usize,
    /// Pooling strategy for combining token embeddings.
    pub pooling: PoolingStrategy,
}
45
46impl Default for EmbeddingConfig {
47 fn default() -> Self {
48 Self {
49 normalize: true,
50 query_prefix: None,
51 document_prefix: None,
52 max_length: 512,
53 pooling: PoolingStrategy::Mean,
54 }
55 }
56}
57
/// Core embedding interface implemented by all embedders in this module.
///
/// NOTE(review): `#[async_trait]` is applied but every method here is
/// synchronous — the macro currently does nothing. Confirm whether async
/// methods are planned; otherwise the attribute can likely be dropped.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Embeds a single text into a vector.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds a batch of texts; result order matches input order.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Dimensionality of the vectors this embedder produces.
    fn dimension(&self) -> usize;

    /// Identifier of the underlying model.
    fn model_id(&self) -> &str;

    /// Embeds a query. Default: same as [`Embedder::embed`]; implementations
    /// may override to apply query-specific prefixes.
    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        self.embed(query)
    }

    /// Embeds a document. Default: same as [`Embedder::embed`].
    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
        self.embed(document)
    }

    /// Embeds each chunk's content and stores the vector back on the chunk.
    ///
    /// Relies on `embed_batch` returning exactly one embedding per input,
    /// in order — the results are zipped back onto `chunks` positionally.
    fn embed_chunks(&self, chunks: &mut [Chunk]) -> Result<()> {
        let texts: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
        let embeddings = self.embed_batch(&texts)?;

        for (chunk, embedding) in chunks.iter_mut().zip(embeddings) {
            chunk.set_embedding(embedding);
        }

        Ok(())
    }
}
95
96impl Embedder for Box<dyn Embedder> {
99 fn embed(&self, text: &str) -> Result<Vec<f32>> {
100 (**self).embed(text)
101 }
102
103 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
104 (**self).embed_batch(texts)
105 }
106
107 fn dimension(&self) -> usize {
108 (**self).dimension()
109 }
110
111 fn model_id(&self) -> &str {
112 (**self).model_id()
113 }
114
115 fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
116 (**self).embed_query(query)
117 }
118
119 fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
120 (**self).embed_document(document)
121 }
122
123 fn embed_chunks(&self, chunks: &mut [Chunk]) -> Result<()> {
124 (**self).embed_chunks(chunks)
125 }
126}
127
/// Deterministic hash-based embedder for tests: same text always maps to
/// the same pseudo-random vector, with no model download required.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    // Output vector dimensionality.
    dimension: usize,
    // Reported model identifier (defaults to "mock-embedder").
    model_id: String,
    // Controls normalization and query/document prefixes.
    config: EmbeddingConfig,
}
135
136impl MockEmbedder {
137 #[must_use]
139 pub fn new(dimension: usize) -> Self {
140 Self {
141 dimension,
142 model_id: "mock-embedder".to_string(),
143 config: EmbeddingConfig::default(),
144 }
145 }
146
147 #[must_use]
149 pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
150 self.model_id = model_id.into();
151 self
152 }
153
154 #[must_use]
156 pub fn with_config(mut self, config: EmbeddingConfig) -> Self {
157 self.config = config;
158 self
159 }
160
161 fn hash_to_vector(&self, text: &str) -> Vec<f32> {
162 use std::collections::hash_map::DefaultHasher;
163 use std::hash::{Hash, Hasher};
164
165 let mut vector = Vec::with_capacity(self.dimension);
166 let mut hasher = DefaultHasher::new();
167
168 for i in 0..self.dimension {
169 text.hash(&mut hasher);
170 i.hash(&mut hasher);
171 let hash = hasher.finish();
172 let value = (hash as f32 / u64::MAX as f32) * 2.0 - 1.0;
174 vector.push(value);
175 }
176
177 if self.config.normalize {
178 Self::normalize_vector(&mut vector);
179 }
180
181 vector
182 }
183
184 fn normalize_vector(vector: &mut [f32]) {
185 let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
186 if norm > 0.0 {
187 for x in vector.iter_mut() {
188 *x /= norm;
189 }
190 }
191 }
192}
193
194impl Embedder for MockEmbedder {
195 fn embed(&self, text: &str) -> Result<Vec<f32>> {
196 if text.is_empty() {
197 return Err(Error::EmptyDocument("empty text for embedding".to_string()));
198 }
199
200 let prefixed = if let Some(prefix) = &self.config.document_prefix {
201 format!("{prefix}{text}")
202 } else {
203 text.to_string()
204 };
205
206 Ok(self.hash_to_vector(&prefixed))
207 }
208
209 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
210 texts.iter().map(|t| self.embed(t)).collect()
211 }
212
213 fn dimension(&self) -> usize {
214 self.dimension
215 }
216
217 fn model_id(&self) -> &str {
218 &self.model_id
219 }
220
221 fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
222 if query.is_empty() {
223 return Err(Error::Query("empty query".to_string()));
224 }
225
226 let prefixed = if let Some(prefix) = &self.config.query_prefix {
227 format!("{prefix}{query}")
228 } else {
229 query.to_string()
230 };
231
232 Ok(self.hash_to_vector(&prefixed))
233 }
234}
235
/// Simple TF-IDF embedder; must be trained with [`TfIdfEmbedder::fit`]
/// before it can embed.
#[derive(Debug, Clone)]
pub struct TfIdfEmbedder {
    // Output dimension; also caps vocabulary size.
    dimension: usize,
    // Term -> vector-index mapping built by `fit`.
    vocabulary: std::collections::HashMap<String, usize>,
    // Inverse document frequency per vocabulary slot, parallel to the
    // vocabulary indices.
    idf: Vec<f32>,
}
243
244impl TfIdfEmbedder {
245 #[must_use]
247 pub fn new(dimension: usize) -> Self {
248 Self { dimension, vocabulary: std::collections::HashMap::new(), idf: Vec::new() }
249 }
250
251 pub fn fit(&mut self, documents: &[&str]) {
253 use std::collections::{HashMap, HashSet};
254
255 let mut doc_freq: HashMap<String, usize> = HashMap::new();
256 let mut all_terms: HashSet<String> = HashSet::new();
257
258 for doc in documents {
259 let terms: HashSet<String> = doc.split_whitespace().map(|s| s.to_lowercase()).collect();
260
261 for term in &terms {
262 *doc_freq.entry(term.clone()).or_insert(0) += 1;
263 all_terms.insert(term.clone());
264 }
265 }
266
267 let mut terms: Vec<_> = all_terms.into_iter().collect();
269 terms.sort_by_key(|t| std::cmp::Reverse(doc_freq.get(t).copied().unwrap_or(0)));
270 terms.truncate(self.dimension);
271
272 self.vocabulary = terms.iter().enumerate().map(|(i, t)| (t.clone(), i)).collect();
273
274 let n = documents.len() as f32;
276 self.idf = terms
277 .iter()
278 .map(|t| {
279 let df = doc_freq.get(t).copied().unwrap_or(1) as f32;
280 (n / df).max(f32::EPSILON).ln() + 1.0
281 })
282 .collect();
283 }
284
285 fn compute_tf(&self, text: &str) -> Vec<f32> {
286 let mut tf = vec![0.0f32; self.dimension];
287 let terms: Vec<String> = text.split_whitespace().map(|s| s.to_lowercase()).collect();
288 let total = terms.len() as f32;
289
290 for term in terms {
291 if let Some(&idx) = self.vocabulary.get(&term) {
292 tf[idx] += 1.0 / total;
293 }
294 }
295
296 tf
297 }
298}
299
300impl Embedder for TfIdfEmbedder {
301 fn embed(&self, text: &str) -> Result<Vec<f32>> {
302 if text.is_empty() {
303 return Err(Error::EmptyDocument("empty text".to_string()));
304 }
305
306 if self.vocabulary.is_empty() {
307 return Err(Error::InvalidConfig("embedder not trained".to_string()));
308 }
309
310 let tf = self.compute_tf(text);
311 let mut tfidf: Vec<f32> = tf.iter().zip(self.idf.iter()).map(|(t, i)| t * i).collect();
312
313 let norm: f32 = tfidf.iter().map(|x| x * x).sum::<f32>().sqrt();
315 if norm > 0.0 {
316 for x in &mut tfidf {
317 *x /= norm;
318 }
319 }
320
321 tfidf.resize(self.dimension, 0.0);
323 Ok(tfidf)
324 }
325
326 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
327 texts.iter().map(|t| self.embed(t)).collect()
328 }
329
330 fn dimension(&self) -> usize {
331 self.dimension
332 }
333
334 fn model_id(&self) -> &str {
335 "tfidf"
336 }
337}
338
/// Euclidean (L2) norm of `v`; `0.0` for an empty slice.
#[must_use]
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares = v.iter().fold(0.0_f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
345
/// Division that maps a zero denominator to `0.0` instead of `inf`/`NaN`.
fn safe_divide(numerator: f32, denominator: f32) -> f32 {
    match denominator {
        d if d == 0.0 => 0.0,
        d => numerator / d,
    }
}
354
355pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
357 if a.len() != b.len() {
358 return 0.0;
359 }
360 safe_divide(dot_product(a, b), l2_norm(a) * l2_norm(b))
361}
362
/// Dot product of `a` and `b`, truncated to the shorter slice.
#[must_use]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
368
/// Euclidean distance between `a` and `b`, truncated to the shorter slice.
#[must_use]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();
    sum.sqrt()
}
374
/// Supported fastembed model variants.
///
/// Defaults to [`EmbeddingModelType::AllMiniLmL6V2`] via the derived
/// `Default` (replaces the previous hand-written `impl Default` — same
/// behavior, less code).
#[cfg(feature = "embeddings")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingModelType {
    /// all-MiniLM-L6-v2 (384 dims, the default).
    #[default]
    AllMiniLmL6V2,
    /// all-MiniLM-L12-v2 (384 dims).
    AllMiniLmL12V2,
    /// bge-small-en-v1.5 (384 dims).
    BgeSmallEnV15,
    /// bge-base-en-v1.5 (768 dims).
    BgeBaseEnV15,
    /// nomic-embed-text-v1 (768 dims).
    NomicEmbedTextV1,
}
401
#[cfg(feature = "embeddings")]
impl EmbeddingModelType {
    /// Maps this variant to the corresponding fastembed model.
    fn to_fastembed_model(self) -> fastembed::EmbeddingModel {
        use fastembed::EmbeddingModel as Fe;
        match self {
            Self::AllMiniLmL6V2 => Fe::AllMiniLML6V2,
            Self::AllMiniLmL12V2 => Fe::AllMiniLML12V2,
            Self::BgeSmallEnV15 => Fe::BGESmallENV15,
            Self::BgeBaseEnV15 => Fe::BGEBaseENV15,
            Self::NomicEmbedTextV1 => Fe::NomicEmbedTextV1,
        }
    }

    /// Output dimensionality of this model.
    #[must_use]
    pub const fn dimension(self) -> usize {
        match self {
            Self::AllMiniLmL6V2 | Self::AllMiniLmL12V2 | Self::BgeSmallEnV15 => 384,
            Self::BgeBaseEnV15 | Self::NomicEmbedTextV1 => 768,
        }
    }

    /// Canonical Hugging Face model identifier.
    #[must_use]
    pub const fn model_name(self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMiniLmL12V2 => "sentence-transformers/all-MiniLM-L12-v2",
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
            Self::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
        }
    }
}
436
/// Embedder backed by a fastembed ONNX model.
///
/// The model is wrapped in `Arc<Mutex<_>>` so clones share one loaded
/// model and access is serialized across threads.
#[cfg(feature = "embeddings")]
#[derive(Clone)]
pub struct FastEmbedder {
    // Shared, lock-guarded fastembed model instance.
    model: std::sync::Arc<std::sync::Mutex<fastembed::TextEmbedding>>,
    // Which model variant was loaded; drives `dimension`/`model_id`.
    model_type: EmbeddingModelType,
}
456
// Manual Debug: `fastembed::TextEmbedding` does not implement Debug, so
// only the model metadata is printed.
#[cfg(feature = "embeddings")]
impl std::fmt::Debug for FastEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("FastEmbedder");
        dbg.field("model_type", &self.model_type);
        dbg.field("dimension", &self.model_type.dimension());
        dbg.finish_non_exhaustive()
    }
}
466
#[cfg(feature = "embeddings")]
impl FastEmbedder {
    /// Loads (downloading if necessary) the given fastembed model.
    ///
    /// # Errors
    /// Returns `Error::InvalidConfig` if model initialization fails.
    pub fn new(model_type: EmbeddingModelType) -> Result<Self> {
        let init = fastembed::InitOptions::new(model_type.to_fastembed_model())
            .with_show_download_progress(true);

        match fastembed::TextEmbedding::try_new(init) {
            Ok(model) => Ok(Self {
                model: std::sync::Arc::new(std::sync::Mutex::new(model)),
                model_type,
            }),
            Err(e) => Err(Error::InvalidConfig(format!(
                "Failed to initialize embedding model: {e}"
            ))),
        }
    }

    /// Convenience constructor for the default model type.
    pub fn default_model() -> Result<Self> {
        Self::new(EmbeddingModelType::default())
    }

    /// The model variant this embedder was built with.
    #[must_use]
    pub fn model_type(&self) -> EmbeddingModelType {
        self.model_type
    }
}
500
#[cfg(feature = "embeddings")]
impl Embedder for FastEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        if text.is_empty() {
            return Err(Error::EmptyDocument("empty text for embedding".to_string()));
        }

        let mut model =
            self.model.lock().map_err(|e| Error::Embedding(format!("lock failed: {e}")))?;

        let embeddings = model
            .embed(vec![text], None)
            .map_err(|e| Error::Embedding(format!("embedding failed: {e}")))?;

        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| Error::Embedding("no embedding returned".to_string()))
    }

    /// Embeds every text in `texts`; the result has exactly one embedding
    /// per input, in order.
    ///
    /// # Errors
    /// Fails if any text is empty, if the model lock is poisoned, or if the
    /// underlying model fails.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // BUG FIX: the previous version silently *filtered out* empty texts,
        // so the returned Vec could be shorter than `texts`. That misaligns
        // results with inputs and corrupts the default `embed_chunks`, which
        // zips the embeddings back onto chunks positionally. Reject empty
        // texts instead, consistent with `embed`.
        if let Some(pos) = texts.iter().position(|t| t.is_empty()) {
            return Err(Error::EmptyDocument(format!("empty text at index {pos} in batch")));
        }

        let mut model =
            self.model.lock().map_err(|e| Error::Embedding(format!("lock failed: {e}")))?;

        model
            .embed(texts.to_vec(), None)
            .map_err(|e| Error::Embedding(format!("batch embedding failed: {e}")))
    }

    fn dimension(&self) -> usize {
        self.model_type.dimension()
    }

    fn model_id(&self) -> &str {
        self.model_type.model_name()
    }

    fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        self.embed(query)
    }

    fn embed_document(&self, document: &str) -> Result<Vec<f32>> {
        self.embed(document)
    }
}
557
// Unit tests live in the adjacent `tests.rs` file.
#[cfg(test)]
mod tests;