// seekr_code/embedder/batch.rs
use crate::embedder::traits::Embedder;
use crate::error::EmbedderError;
8
9pub struct BatchEmbedder<E: Embedder> {
11 embedder: E,
12 batch_size: usize,
13}
14
15impl<E: Embedder> BatchEmbedder<E> {
16 pub fn new(embedder: E, batch_size: usize) -> Self {
18 Self {
19 embedder,
20 batch_size: batch_size.max(1),
21 }
22 }
23
24 pub fn dimension(&self) -> usize {
26 self.embedder.dimension()
27 }
28
29 pub fn embed_all_with_progress<F>(
33 &self,
34 texts: &[String],
35 mut progress_fn: F,
36 ) -> Result<Vec<Vec<f32>>, EmbedderError>
37 where
38 F: FnMut(usize, usize),
39 {
40 let total = texts.len();
41 let mut all_embeddings = Vec::with_capacity(total);
42 let mut completed = 0;
43
44 for chunk in texts.chunks(self.batch_size) {
45 let refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
46 let batch_result = self.embedder.embed_batch(&refs)?;
47 all_embeddings.extend(batch_result);
48 completed += chunk.len();
49 progress_fn(completed, total);
50 }
51
52 Ok(all_embeddings)
53 }
54
55 pub fn embed_all(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedderError> {
57 self.embed_all_with_progress(texts, |_, _| {})
58 }
59
60 pub fn inner(&self) -> &E {
62 &self.embedder
63 }
64}
65
/// Embedder that derives its vectors from a hash of the input text instead of
/// a real model.
pub struct DummyEmbedder {
    /// Number of components in every embedding it produces.
    dim: usize,
}
70
71impl DummyEmbedder {
72 pub fn new(dim: usize) -> Self {
74 Self { dim }
75 }
76}
77
78impl Embedder for DummyEmbedder {
79 fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
80 let mut embedding = vec![0.0f32; self.dim];
82 let mut hash: u64 = 5381;
83
84 for byte in text.bytes() {
85 hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
86 }
87
88 for (i, val) in embedding.iter_mut().enumerate() {
89 hash = hash.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
90 *val = ((hash >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
91 let _ = i; }
94
95 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
97 if norm > 0.0 {
98 for x in &mut embedding {
99 *x /= norm;
100 }
101 }
102
103 Ok(embedding)
104 }
105
106 fn dimension(&self) -> usize {
107 self.dim
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// The dummy embedder returns a unit-length vector of the requested size.
    #[test]
    fn test_dummy_embedder() {
        let vector = DummyEmbedder::new(384).embed("hello world").unwrap();
        assert_eq!(vector.len(), 384);

        let deviation = (l2_norm(&vector) - 1.0).abs();
        assert!(deviation < 0.01, "Embedding should be L2 normalized");
    }

    /// Embedding the same text twice gives identical vectors.
    #[test]
    fn test_dummy_embedder_deterministic() {
        let embedder = DummyEmbedder::new(384);
        let first = embedder.embed("test").unwrap();
        let second = embedder.embed("test").unwrap();
        assert_eq!(first, second, "Same input should produce same embedding");
    }

    /// Distinct texts give distinct vectors.
    #[test]
    fn test_dummy_embedder_different_inputs() {
        let embedder = DummyEmbedder::new(384);
        let hello = embedder.embed("hello").unwrap();
        let world = embedder.embed("world").unwrap();
        assert_ne!(hello, world, "Different inputs should produce different embeddings");
    }

    /// Five texts with a batch size of two are embedded in three batches, and
    /// progress is reported after each one.
    #[test]
    fn test_batch_embedder() {
        let batch = BatchEmbedder::new(DummyEmbedder::new(128), 2);

        let texts: Vec<String> = ["hello", "world", "foo", "bar", "baz"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let mut progress_calls = Vec::new();
        let results = batch
            .embed_all_with_progress(&texts, |completed, total| {
                progress_calls.push((completed, total));
            })
            .unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].len(), 128);

        // Batches of 2, 2 and 1 texts.
        assert_eq!(progress_calls, vec![(2, 5), (4, 5), (5, 5)]);
    }
}