//! Batch embedding wrapper.
//!
//! Groups input texts into batches and calls the underlying Embedder
//! for batch-optimized inference, improving throughput for index building.

use crate::embedder::traits::Embedder;
use crate::error::EmbedderError;

/// Batch embedding processor with progress reporting.
///
/// Wraps an [`Embedder`] and feeds it inputs in fixed-size groups.
/// The trait bound lives on the impl block rather than here, per Rust
/// API conventions, so the type itself stays maximally usable.
pub struct BatchEmbedder<E> {
    /// The wrapped embedder that performs the actual inference.
    embedder: E,
    /// Number of texts sent per batch call; always at least 1.
    batch_size: usize,
}

15impl<E: Embedder> BatchEmbedder<E> {
16    /// Create a new BatchEmbedder wrapping the given embedder.
17    pub fn new(embedder: E, batch_size: usize) -> Self {
18        Self {
19            embedder,
20            batch_size: batch_size.max(1),
21        }
22    }
23
24    /// Get the embedding dimension.
25    pub fn dimension(&self) -> usize {
26        self.embedder.dimension()
27    }
28
29    /// Embed all texts in batches, calling the progress callback after each batch.
30    ///
31    /// `progress_fn` receives (completed_count, total_count).
32    pub fn embed_all_with_progress<F>(
33        &self,
34        texts: &[String],
35        mut progress_fn: F,
36    ) -> Result<Vec<Vec<f32>>, EmbedderError>
37    where
38        F: FnMut(usize, usize),
39    {
40        let total = texts.len();
41        let mut all_embeddings = Vec::with_capacity(total);
42        let mut completed = 0;
43
44        for chunk in texts.chunks(self.batch_size) {
45            let refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
46            let batch_result = self.embedder.embed_batch(&refs)?;
47            all_embeddings.extend(batch_result);
48            completed += chunk.len();
49            progress_fn(completed, total);
50        }
51
52        Ok(all_embeddings)
53    }
54
55    /// Embed all texts in batches without progress reporting.
56    pub fn embed_all(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedderError> {
57        self.embed_all_with_progress(texts, |_, _| {})
58    }
59
60    /// Get a reference to the inner embedder.
61    pub fn inner(&self) -> &E {
62        &self.embedder
63    }
64}
65
/// A dummy embedder for testing that produces random-like but deterministic vectors.
pub struct DummyEmbedder {
    // Dimensionality of every vector this embedder produces.
    dim: usize,
}

impl DummyEmbedder {
    /// Create a new dummy embedder with the given dimension.
    pub fn new(dim: usize) -> Self {
        DummyEmbedder { dim }
    }
}

78impl Embedder for DummyEmbedder {
79    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
80        // Generate a deterministic pseudo-random embedding based on text content
81        let mut embedding = vec![0.0f32; self.dim];
82        let mut hash: u64 = 5381;
83
84        for byte in text.bytes() {
85            hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
86        }
87
88        for (i, val) in embedding.iter_mut().enumerate() {
89            hash = hash.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
90            *val = ((hash >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
91            // Mix in position
92            let _ = i; // suppress unused warning, position affects hash via iteration
93        }
94
95        // L2 normalize
96        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
97        if norm > 0.0 {
98            for x in &mut embedding {
99                *x /= norm;
100            }
101        }
102
103        Ok(embedding)
104    }
105
106    fn dimension(&self) -> usize {
107        self.dim
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_embedder() {
        let d = DummyEmbedder::new(384);
        let v = d.embed("hello world").unwrap();
        assert_eq!(v.len(), 384);

        // The output vector must have unit length.
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01, "Embedding should be L2 normalized");
    }

    #[test]
    fn test_dummy_embedder_deterministic() {
        let d = DummyEmbedder::new(384);
        assert_eq!(
            d.embed("test").unwrap(),
            d.embed("test").unwrap(),
            "Same input should produce same embedding"
        );
    }

    #[test]
    fn test_dummy_embedder_different_inputs() {
        let d = DummyEmbedder::new(384);
        assert_ne!(
            d.embed("hello").unwrap(),
            d.embed("world").unwrap(),
            "Different inputs should produce different embeddings"
        );
    }

    #[test]
    fn test_batch_embedder() {
        let batch = BatchEmbedder::new(DummyEmbedder::new(128), 2);
        let texts: Vec<String> = ["hello", "world", "foo", "bar", "baz"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let mut calls = Vec::new();
        let vectors = batch
            .embed_all_with_progress(&texts, |done, total| calls.push((done, total)))
            .unwrap();

        assert_eq!(vectors.len(), 5);
        assert_eq!(vectors[0].len(), 128);

        // Five texts in batches of 2 => batches of 2, 2, 1 => three reports.
        assert_eq!(calls, vec![(2, 5), (4, 5), (5, 5)]);
    }
}