traitclaw_rag/
embedding.rs1use async_trait::async_trait;
37use traitclaw_core::{Error, Result};
38
39use crate::{Document, Retriever};
40
/// Source of embedding vectors for text.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync + 'static {
    /// Embeds every string in `texts`.
    ///
    /// Callers in this file expect one vector per input, in input order:
    /// `add_documents` rejects a result whose length differs from the number
    /// of documents and pairs vectors with documents positionally, and
    /// `retrieve` uses the first returned vector as the query embedding.
    ///
    /// # Errors
    ///
    /// Implementation-defined. Callers in this file propagate the error
    /// unchanged with `?`.
    async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f64>>>;
}
52
/// A stored document together with the embedding computed for it.
struct VectorEntry {
    /// Vector the provider returned for this document's `content`.
    embedding: Vec<f64>,
    /// The document exactly as it was handed to `add_documents`.
    document: Document,
}
58
/// In-memory [`Retriever`] that ranks stored documents by the cosine
/// similarity between their embeddings and the embedding of the query.
pub struct EmbeddingRetriever<P: EmbeddingProvider> {
    /// Produces embeddings for both indexed documents and queries.
    provider: P,
    /// Indexed documents with their embeddings, in insertion order.
    store: Vec<VectorEntry>,
    /// Minimum similarity a document needs to be returned by `retrieve`.
    similarity_threshold: f64,
}
68
69impl<P: EmbeddingProvider> EmbeddingRetriever<P> {
70 #[must_use]
72 pub fn new(provider: P) -> Self {
73 Self {
74 provider,
75 store: Vec::new(),
76 similarity_threshold: 0.0,
77 }
78 }
79
80 #[must_use]
98 pub fn with_similarity_threshold(mut self, threshold: f64) -> Self {
99 self.similarity_threshold = threshold;
100 self
101 }
102
103 pub async fn add_documents(&mut self, documents: Vec<Document>) -> Result<()> {
112 if documents.is_empty() {
113 return Ok(());
114 }
115
116 let texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
117 let embeddings = self.provider.embed(&texts).await?;
118
119 if embeddings.len() != documents.len() {
120 return Err(Error::Runtime(format!(
121 "EmbeddingProvider returned {} embeddings for {} documents",
122 embeddings.len(),
123 documents.len()
124 )));
125 }
126
127 for (doc, emb) in documents.into_iter().zip(embeddings) {
128 self.store.push(VectorEntry {
129 embedding: emb,
130 document: doc,
131 });
132 }
133
134 Ok(())
135 }
136
137 #[must_use]
139 pub fn len(&self) -> usize {
140 self.store.len()
141 }
142
143 #[must_use]
145 pub fn is_empty(&self) -> bool {
146 self.store.is_empty()
147 }
148}
149
150#[async_trait]
151impl<P: EmbeddingProvider> Retriever for EmbeddingRetriever<P> {
152 async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<Document>> {
154 if self.store.is_empty() {
155 return Ok(Vec::new());
156 }
157
158 let query_embs = self.provider.embed(&[query]).await?;
159 let query_emb = query_embs
160 .into_iter()
161 .next()
162 .ok_or_else(|| Error::Runtime("EmbeddingProvider returned empty for query".into()))?;
163
164 let mut scored: Vec<(f64, &Document)> = self
165 .store
166 .iter()
167 .map(|entry| {
168 let sim = cosine_similarity(&query_emb, &entry.embedding);
169 (sim, &entry.document)
170 })
171 .filter(|(sim, _)| *sim >= self.similarity_threshold)
172 .collect();
173
174 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
176 scored.truncate(limit);
177
178 let results = scored
179 .into_iter()
180 .map(|(sim, doc)| {
181 let mut d = doc.clone();
182 d.score = sim;
183 d
184 })
185 .collect();
186
187 Ok(results)
188 }
189}
190
/// Cosine similarity of `a` and `b`: their dot product divided by the product
/// of their Euclidean norms.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either has zero magnitude. Otherwise the result is clamped to
/// `[-1.0, 1.0]`; a NaN result (from NaN inputs) is returned as NaN.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();

    // A zero vector has no direction; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the quotient marginally outside the mathematical
    // range, so pin it back to [-1, 1].
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
210
#[cfg(test)]
pub(crate) mod test_helpers {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use super::*;

    /// Test provider that records how many times `embed` was invoked.
    pub struct CountingEmbedder {
        /// Shared counter, incremented once per `embed` call.
        pub call_count: Arc<AtomicUsize>,
        /// Stored but not read by `embed`, which always emits 3 components.
        #[allow(dead_code)]
        pub dim: usize,
    }

    impl CountingEmbedder {
        /// Creates an embedder whose call counter starts at zero.
        pub fn new(dim: usize) -> Self {
            let call_count = Arc::new(AtomicUsize::new(0));
            Self { call_count, dim }
        }
    }

    #[async_trait]
    impl EmbeddingProvider for CountingEmbedder {
        /// Bumps the counter, then derives a 3-component vector from each
        /// text's byte length.
        async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f64>>> {
            self.call_count.fetch_add(1, Ordering::Relaxed);

            let mut vectors = Vec::with_capacity(texts.len());
            for text in texts {
                let base = (text.len() % 10) as f64 / 10.0;
                vectors.push(vec![base, 1.0 - base, 0.5]);
            }
            Ok(vectors)
        }
    }

    /// Test provider that hands out a fixed list of vectors, cycling through
    /// it by input position.
    pub struct FixedEmbedder(pub Vec<Vec<f64>>);

    #[async_trait]
    impl EmbeddingProvider for FixedEmbedder {
        /// Returns, for the text at position `i`, a clone of the vector at
        /// index `i % number_of_vectors`; the text itself is ignored.
        async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f64>>> {
            let FixedEmbedder(vectors) = self;
            let cycled: Vec<Vec<f64>> = (0..texts.len())
                .map(|idx| vectors[idx % vectors.len()].clone())
                .collect();
            Ok(cycled)
        }
    }
}
268
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;
    use std::sync::Arc;

    use super::test_helpers::*;
    use super::*;
    use crate::Document;

    /// Builds `n` documents from the strings `doc{i}` and
    /// `document content {i}` for `i` in `0..n`.
    fn make_docs(n: usize) -> Vec<Document> {
        let mut docs = Vec::with_capacity(n);
        for i in 0..n {
            docs.push(Document::new(
                format!("doc{i}"),
                format!("document content {i}"),
            ));
        }
        docs
    }

    /// A batch of documents must be embedded with a single provider call.
    #[tokio::test]
    async fn test_add_documents_calls_embed_once() {
        let embedder = CountingEmbedder::new(3);
        // Keep a handle on the counter before the embedder is moved.
        let calls = Arc::clone(&embedder.call_count);
        let mut retriever = EmbeddingRetriever::new(embedder);
        retriever.add_documents(make_docs(10)).await.unwrap();

        let observed = calls.load(Ordering::Relaxed);
        assert_eq!(observed, 1, "embed should be called exactly once");
        assert_eq!(retriever.len(), 10);
    }

    /// `retrieve` never returns more than `limit` documents.
    #[tokio::test]
    async fn test_retrieve_returns_at_most_limit() {
        let mut retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        retriever.add_documents(make_docs(10)).await.unwrap();

        let hits = retriever.retrieve("content", 3).await.unwrap();
        let returned = hits.len();
        assert!(returned <= 3, "expected ≤3 results, got {}", returned);
    }

    /// Scores in the result must be non-increasing.
    #[tokio::test]
    async fn test_retrieve_sorted_by_similarity_desc() {
        let mut retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        retriever.add_documents(make_docs(5)).await.unwrap();

        let results = retriever.retrieve("query", 5).await.unwrap();
        for pair in results.windows(2) {
            let (earlier, later) = (pair[0].score, pair[1].score);
            assert!(
                earlier >= later,
                "results not sorted: {} < {}",
                earlier,
                later
            );
        }
    }

    /// A high threshold must not let through more results than a low one.
    #[tokio::test]
    async fn test_similarity_threshold_filters_results() {
        let vecs = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.7, 0.7, 0.0],
        ];

        let mut permissive =
            EmbeddingRetriever::new(FixedEmbedder(vecs.clone())).with_similarity_threshold(0.0);
        permissive.add_documents(make_docs(3)).await.unwrap();
        let low_count = permissive.retrieve("doc", 10).await.unwrap().len();

        let mut strict =
            EmbeddingRetriever::new(FixedEmbedder(vecs)).with_similarity_threshold(0.95);
        strict.add_documents(make_docs(3)).await.unwrap();
        let high_count = strict.retrieve("doc", 10).await.unwrap().len();

        assert!(
            high_count < low_count || high_count <= 1,
            "high threshold should filter more: low={}, high={}",
            low_count,
            high_count
        );
    }

    /// Querying before anything was indexed yields no results.
    #[tokio::test]
    async fn test_empty_store_returns_empty() {
        let retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        let hits = retriever.retrieve("any query", 10).await.unwrap();
        assert!(hits.is_empty());
    }

    /// Adding an empty batch succeeds and leaves the store empty.
    #[tokio::test]
    async fn test_add_empty_documents() {
        let mut retriever = EmbeddingRetriever::new(CountingEmbedder::new(3));
        retriever.add_documents(Vec::new()).await.unwrap();
        assert!(retriever.is_empty());
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0, 2.0, 3.0];
        let delta = (cosine_similarity(&v, &v) - 1.0).abs();
        assert!(delta < 1e-9);
    }

    /// Orthogonal vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let (a, b) = ([1.0, 0.0], [0.0, 1.0]);
        assert!(cosine_similarity(&a, &b).abs() < 1e-9);
    }

    /// A zero vector yields similarity 0 rather than a division by zero.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let (a, b) = ([0.0, 0.0], [1.0, 0.0]);
        assert!(cosine_similarity(&a, &b).abs() < f64::EPSILON);
    }

    /// Compile-time check that the retriever can be used as `dyn Retriever`.
    #[test]
    fn test_embedding_retriever_is_retriever_trait_object() {
        let _: Arc<dyn Retriever> = Arc::new(EmbeddingRetriever::new(CountingEmbedder::new(3)));
    }
}