seekr_code/embedder/
batch.rs1use crate::embedder::traits::Embedder;
7use crate::error::EmbedderError;
8
9pub struct BatchEmbedder<E: Embedder> {
11 embedder: E,
12 batch_size: usize,
13}
14
15impl<E: Embedder> BatchEmbedder<E> {
16 pub fn new(embedder: E, batch_size: usize) -> Self {
18 Self {
19 embedder,
20 batch_size: batch_size.max(1),
21 }
22 }
23
24 pub fn dimension(&self) -> usize {
26 self.embedder.dimension()
27 }
28
29 pub fn embed_all_with_progress<F>(
33 &self,
34 texts: &[String],
35 mut progress_fn: F,
36 ) -> Result<Vec<Vec<f32>>, EmbedderError>
37 where
38 F: FnMut(usize, usize),
39 {
40 let total = texts.len();
41 let mut all_embeddings = Vec::with_capacity(total);
42 let mut completed = 0;
43
44 for chunk in texts.chunks(self.batch_size) {
45 let refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
46 let batch_result = self.embedder.embed_batch(&refs)?;
47 all_embeddings.extend(batch_result);
48 completed += chunk.len();
49 progress_fn(completed, total);
50 }
51
52 Ok(all_embeddings)
53 }
54
55 pub fn embed_all(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedderError> {
57 self.embed_all_with_progress(texts, |_, _| {})
58 }
59
60 pub fn inner(&self) -> &E {
62 &self.embedder
63 }
64}
65
66pub struct DummyEmbedder {
68 dim: usize,
69}
70
71impl DummyEmbedder {
72 pub fn new(dim: usize) -> Self {
74 Self { dim }
75 }
76}
77
78impl Embedder for DummyEmbedder {
79 fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
80 let mut embedding = vec![0.0f32; self.dim];
82 let mut hash: u64 = 5381;
83
84 for byte in text.bytes() {
85 hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
86 }
87
88 for (i, val) in embedding.iter_mut().enumerate() {
89 hash = hash
90 .wrapping_mul(6364136223846793005)
91 .wrapping_add(1442695040888963407);
92 *val = ((hash >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
93 let _ = i; }
96
97 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
99 if norm > 0.0 {
100 for x in &mut embedding {
101 *x /= norm;
102 }
103 }
104
105 Ok(embedding)
106 }
107
108 fn dimension(&self) -> usize {
109 self.dim
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_dummy_embedder() {
        let embedding = DummyEmbedder::new(384).embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);
        assert!(
            (l2_norm(&embedding) - 1.0).abs() < 0.01,
            "Embedding should be L2 normalized"
        );
    }

    #[test]
    fn test_dummy_embedder_deterministic() {
        let embedder = DummyEmbedder::new(384);
        let first = embedder.embed("test").unwrap();
        let second = embedder.embed("test").unwrap();
        assert_eq!(first, second, "Same input should produce same embedding");
    }

    #[test]
    fn test_dummy_embedder_different_inputs() {
        let embedder = DummyEmbedder::new(384);
        let hello = embedder.embed("hello").unwrap();
        let world = embedder.embed("world").unwrap();
        assert_ne!(
            hello, world,
            "Different inputs should produce different embeddings"
        );
    }

    #[test]
    fn test_batch_embedder() {
        let batch = BatchEmbedder::new(DummyEmbedder::new(128), 2);
        let texts: Vec<String> = ["hello", "world", "foo", "bar", "baz"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let mut progress_calls = Vec::new();
        let results = batch
            .embed_all_with_progress(&texts, |completed, total| {
                progress_calls.push((completed, total))
            })
            .unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].len(), 128);
        // Five texts in batches of two: two full batches, then a final one-text batch.
        assert_eq!(progress_calls, vec![(2, 5), (4, 5), (5, 5)]);
    }
}