trueno_rag/multivector/
embedder.rs1use crate::multivector::MultiVectorEmbedding;
7use crate::Result;
8
9pub trait MultiVectorEmbedder: Send + Sync {
27 fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding>;
37
38 fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
43 texts.iter().map(|t| self.embed_tokens(t)).collect()
44 }
45
46 fn token_dimension(&self) -> usize;
48
49 fn max_tokens(&self) -> usize;
51
52 fn model_id(&self) -> &str;
54}
55
/// Deterministic, model-free [`MultiVectorEmbedder`] test double.
///
/// Text is split on whitespace, and each of the first `max_tokens` tokens is
/// mapped to a pseudo-random unit vector of length `dim`. That vector is derived
/// from the token's bytes, its position, and `seed`, so identical configurations
/// always produce identical embeddings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockMultiVectorEmbedder {
    /// Length of every per-token vector.
    dim: usize,
    /// Maximum number of tokens embedded per text; further tokens are dropped.
    max_tokens: usize,
    /// Seed mixed into every token hash; different seeds give different vectors.
    seed: u64,
}
81
82impl MockMultiVectorEmbedder {
83 #[must_use]
90 pub fn new(dim: usize, max_tokens: usize) -> Self {
91 Self { dim, max_tokens, seed: 42 }
92 }
93
94 #[must_use]
96 pub fn with_seed(dim: usize, max_tokens: usize, seed: u64) -> Self {
97 Self { dim, max_tokens, seed }
98 }
99
100 fn generate_unit_vector(&self, seed: u64) -> Vec<f32> {
102 let mut vec = Vec::with_capacity(self.dim);
103 let mut rng = seed;
104
105 for _ in 0..self.dim {
106 rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
107 let val = ((rng >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
108 vec.push(val);
109 }
110
111 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
113 if norm > 0.0 {
114 for v in &mut vec {
115 *v /= norm;
116 }
117 }
118
119 vec
120 }
121
122 fn hash_token(&self, token: &str, index: usize) -> u64 {
124 let mut hash = self.seed;
125 for byte in token.bytes() {
126 hash = hash.wrapping_mul(31).wrapping_add(u64::from(byte));
127 }
128 hash = hash.wrapping_mul(31).wrapping_add(index as u64);
129 hash
130 }
131}
132
133impl MultiVectorEmbedder for MockMultiVectorEmbedder {
134 fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding> {
135 let tokens: Vec<&str> = text.split_whitespace().collect();
136 let num_tokens = tokens.len().min(self.max_tokens);
137
138 if num_tokens == 0 {
139 return Ok(MultiVectorEmbedding::new(Vec::new(), 0, self.dim));
140 }
141
142 let mut embeddings = Vec::with_capacity(num_tokens * self.dim);
143
144 for (i, token) in tokens.iter().take(num_tokens).enumerate() {
145 let token_seed = self.hash_token(token, i);
146 embeddings.extend(self.generate_unit_vector(token_seed));
147 }
148
149 Ok(MultiVectorEmbedding::new(embeddings, num_tokens, self.dim))
150 }
151
152 fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
153 texts.iter().map(|t| self.embed_tokens(t)).collect()
154 }
155
156 fn token_dimension(&self) -> usize {
157 self.dim
158 }
159
160 fn max_tokens(&self) -> usize {
161 self.max_tokens
162 }
163
164 fn model_id(&self) -> &str {
165 "mock-multivector"
166 }
167}
168
169impl<E: MultiVectorEmbedder + ?Sized> MultiVectorEmbedder for Box<E> {
171 fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding> {
172 (**self).embed_tokens(text)
173 }
174
175 fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
176 (**self).embed_tokens_batch(texts)
177 }
178
179 fn token_dimension(&self) -> usize {
180 (**self).token_dimension()
181 }
182
183 fn max_tokens(&self) -> usize {
184 (**self).max_tokens()
185 }
186
187 fn model_id(&self) -> &str {
188 (**self).model_id()
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_mock_embedder_new() {
        let embedder = MockMultiVectorEmbedder::new(128, 512);

        assert_eq!(embedder.token_dimension(), 128);
        assert_eq!(embedder.max_tokens(), 512);
        assert_eq!(embedder.model_id(), "mock-multivector");
    }

    #[test]
    fn test_mock_embedder_with_seed() {
        let embedder1 = MockMultiVectorEmbedder::with_seed(128, 512, 123);
        let embedder2 = MockMultiVectorEmbedder::with_seed(128, 512, 456);

        let emb1 = embedder1.embed_tokens("test").unwrap();
        let emb2 = embedder2.embed_tokens("test").unwrap();

        assert_ne!(emb1.as_slice(), emb2.as_slice());
    }

    #[test]
    fn test_mock_embedder_deterministic() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb1 = embedder.embed_tokens("hello world").unwrap();
        let emb2 = embedder.embed_tokens("hello world").unwrap();

        assert_eq!(emb1.num_tokens(), emb2.num_tokens());
        assert_eq!(emb1.as_slice(), emb2.as_slice());
    }

    #[test]
    fn test_mock_embedder_token_count() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("one two three four five").unwrap();

        assert_eq!(emb.num_tokens(), 5);
        assert_eq!(emb.dim(), 64);
    }

    #[test]
    fn test_mock_embedder_max_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 3);

        let emb = embedder.embed_tokens("one two three four five six").unwrap();

        assert_eq!(emb.num_tokens(), 3);
    }

    #[test]
    fn test_mock_embedder_zero_max_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 0);

        let emb = embedder.embed_tokens("one two three").unwrap();

        assert_eq!(emb.num_tokens(), 0);
    }

    #[test]
    fn test_mock_embedder_truncation_keeps_leading_tokens() {
        // Truncation must drop trailing tokens only; the kept prefix is embedded
        // exactly as if the text had ended there.
        let truncated = MockMultiVectorEmbedder::new(32, 2)
            .embed_tokens("alpha beta gamma")
            .unwrap();
        let prefix = MockMultiVectorEmbedder::new(32, 256)
            .embed_tokens("alpha beta")
            .unwrap();

        assert_eq!(truncated.as_slice(), prefix.as_slice());
    }

    #[test]
    fn test_mock_embedder_empty_text() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("").unwrap();

        assert_eq!(emb.num_tokens(), 0);
        assert!(emb.is_empty());
    }

    #[test]
    fn test_mock_embedder_whitespace_only() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("   \t\n  ").unwrap();

        assert_eq!(emb.num_tokens(), 0);
    }

    #[test]
    fn test_mock_embedder_unit_vectors() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("test token").unwrap();

        for token_emb in emb.tokens() {
            let norm: f32 = token_emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 0.001, "Token not unit length: norm = {}", norm);
        }
    }

    #[test]
    fn test_mock_embedder_different_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("hello world").unwrap();

        let token0 = emb.token(0);
        let token1 = emb.token(1);

        assert_ne!(token0, token1);
    }

    #[test]
    fn test_mock_embedder_repeated_token_is_position_dependent() {
        // The token position is mixed into the hash, so a repeated word gets a
        // distinct vector at each position.
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("echo echo").unwrap();

        assert_eq!(emb.num_tokens(), 2);
        assert_ne!(emb.token(0), emb.token(1));
    }

    #[test]
    fn test_mock_embedder_batch() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let texts = ["hello", "world", "test"];
        let embeddings = embedder.embed_tokens_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert_eq!(embeddings[0].num_tokens(), 1);
        assert_eq!(embeddings[1].num_tokens(), 1);
        assert_eq!(embeddings[2].num_tokens(), 1);
    }

    #[test]
    fn test_mock_embedder_batch_empty() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let embeddings = embedder.embed_tokens_batch(&[]).unwrap();

        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_mock_embedder_batch_consistency() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let texts = ["hello", "world"];
        let batch_result = embedder.embed_tokens_batch(&texts).unwrap();

        let single1 = embedder.embed_tokens("hello").unwrap();
        let single2 = embedder.embed_tokens("world").unwrap();

        assert_eq!(batch_result[0].as_slice(), single1.as_slice());
        assert_eq!(batch_result[1].as_slice(), single2.as_slice());
    }

    #[test]
    fn test_boxed_embedder() {
        let embedder: Box<dyn MultiVectorEmbedder> =
            Box::new(MockMultiVectorEmbedder::new(64, 256));

        let emb = embedder.embed_tokens("test").unwrap();

        assert_eq!(emb.num_tokens(), 1);
        assert_eq!(embedder.token_dimension(), 64);
    }

    #[test]
    fn test_boxed_embedder_batch_matches_inner() {
        let inner = MockMultiVectorEmbedder::new(64, 256);
        let boxed: Box<dyn MultiVectorEmbedder> = Box::new(inner.clone());

        let texts = ["hello world", "test"];
        let direct = inner.embed_tokens_batch(&texts).unwrap();
        let via_box = boxed.embed_tokens_batch(&texts).unwrap();

        assert_eq!(direct.len(), via_box.len());
        for (a, b) in direct.iter().zip(&via_box) {
            assert_eq!(a.as_slice(), b.as_slice());
        }
        assert_eq!(boxed.max_tokens(), 256);
        assert_eq!(boxed.model_id(), "mock-multivector");
    }

    proptest! {
        #[test]
        fn prop_embed_produces_correct_dimensions(
            dim in 16usize..256,
            text in "[a-z ]{1,100}"
        ) {
            let embedder = MockMultiVectorEmbedder::new(dim, 512);
            let emb = embedder.embed_tokens(&text).unwrap();

            prop_assert_eq!(emb.dim(), dim);
            if emb.num_tokens() > 0 {
                prop_assert_eq!(emb.token(0).len(), dim);
            }
        }

        #[test]
        fn prop_embed_respects_max_tokens(
            max_tokens in 1usize..10,
            words in 1usize..20
        ) {
            let text: String = (0..words)
                .map(|i| format!("word{}", i))
                .collect::<Vec<_>>()
                .join(" ");
            let embedder = MockMultiVectorEmbedder::new(64, max_tokens);

            let emb = embedder.embed_tokens(&text).unwrap();

            prop_assert!(emb.num_tokens() <= max_tokens);
        }

        #[test]
        fn prop_embed_is_deterministic(
            seed in 0u64..10000,
            text in "[a-z ]{1,50}"
        ) {
            let embedder = MockMultiVectorEmbedder::with_seed(64, 256, seed);

            let emb1 = embedder.embed_tokens(&text).unwrap();
            let emb2 = embedder.embed_tokens(&text).unwrap();

            prop_assert_eq!(emb1.as_slice(), emb2.as_slice());
        }

        #[test]
        fn prop_tokens_are_approximately_unit_length(
            dim in 32usize..128,
            text in "[a-z]{3,10}"
        ) {
            let embedder = MockMultiVectorEmbedder::new(dim, 256);
            let emb = embedder.embed_tokens(&text).unwrap();

            for token_emb in emb.tokens() {
                let norm: f32 = token_emb.iter().map(|x| x * x).sum::<f32>().sqrt();
                prop_assert!((norm - 1.0).abs() < 0.01);
            }
        }
    }
}