//! Multi-vector embedder trait and implementations
//!
//! This module defines the trait for models that produce token-level embeddings
//! (like ColBERT) and provides a mock implementation for testing.

use crate::multivector::MultiVectorEmbedding;
use crate::Result;

9/// Trait for models that produce token-level embeddings.
10///
11/// Unlike single-vector embedders (which produce one embedding per text),
12/// multi-vector embedders produce one embedding per token, enabling
13/// fine-grained late interaction scoring.
14///
15/// # Example
16///
17/// ```ignore
18/// use trueno_rag::multivector::{MultiVectorEmbedder, MockMultiVectorEmbedder};
19///
20/// let embedder = MockMultiVectorEmbedder::new(128, 512);
21/// let embedding = embedder.embed_tokens("hello world").unwrap();
22///
23/// assert_eq!(embedding.num_tokens(), 2);
24/// assert_eq!(embedding.dim(), 128);
25/// ```
26pub trait MultiVectorEmbedder: Send + Sync {
27    /// Embed text into token-level vectors.
28    ///
29    /// # Arguments
30    ///
31    /// * `text` - Input text to embed
32    ///
33    /// # Returns
34    ///
35    /// A `MultiVectorEmbedding` containing one vector per token.
36    fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding>;
37
38    /// Batch embed multiple texts.
39    ///
40    /// The default implementation calls `embed_tokens` sequentially.
41    /// Implementations may override for more efficient batching.
42    fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
43        texts.iter().map(|t| self.embed_tokens(t)).collect()
44    }
45
46    /// Get the token embedding dimension.
47    fn token_dimension(&self) -> usize;
48
49    /// Get the maximum tokens per document.
50    fn max_tokens(&self) -> usize;
51
52    /// Get the model identifier.
53    fn model_id(&self) -> &str;
54}
55
/// Mock multi-vector embedder for testing.
///
/// Generates deterministic pseudo-random embeddings based on token content.
/// Useful for testing the retrieval pipeline without requiring a real model.
///
/// # Example
///
/// ```
/// use trueno_rag::multivector::MockMultiVectorEmbedder;
/// use trueno_rag::multivector::MultiVectorEmbedder;
///
/// let embedder = MockMultiVectorEmbedder::new(128, 512);
///
/// let emb1 = embedder.embed_tokens("hello world").unwrap();
/// let emb2 = embedder.embed_tokens("hello world").unwrap();
///
/// // Same input produces same output
/// assert_eq!(emb1.as_slice(), emb2.as_slice());
/// ```
#[derive(Debug, Clone)]
pub struct MockMultiVectorEmbedder {
    // Token embedding dimension (length of each per-token vector).
    dim: usize,
    // Maximum number of tokens embedded per document; extra tokens are dropped.
    max_tokens: usize,
    // Base seed mixed into every token hash, so different seeds give
    // different (but still deterministic) embedding streams.
    seed: u64,
}

82impl MockMultiVectorEmbedder {
83    /// Create a new mock embedder.
84    ///
85    /// # Arguments
86    ///
87    /// * `dim` - Token embedding dimension (e.g., 128 for ColBERT)
88    /// * `max_tokens` - Maximum tokens per document
89    #[must_use]
90    pub fn new(dim: usize, max_tokens: usize) -> Self {
91        Self { dim, max_tokens, seed: 42 }
92    }
93
94    /// Create with a custom seed for different random sequences.
95    #[must_use]
96    pub fn with_seed(dim: usize, max_tokens: usize, seed: u64) -> Self {
97        Self { dim, max_tokens, seed }
98    }
99
100    /// Generate a deterministic unit vector from a seed.
101    fn generate_unit_vector(&self, seed: u64) -> Vec<f32> {
102        let mut vec = Vec::with_capacity(self.dim);
103        let mut rng = seed;
104
105        for _ in 0..self.dim {
106            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
107            let val = ((rng >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
108            vec.push(val);
109        }
110
111        // Normalize to unit length
112        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
113        if norm > 0.0 {
114            for v in &mut vec {
115                *v /= norm;
116            }
117        }
118
119        vec
120    }
121
122    /// Hash a token to a seed value.
123    fn hash_token(&self, token: &str, index: usize) -> u64 {
124        let mut hash = self.seed;
125        for byte in token.bytes() {
126            hash = hash.wrapping_mul(31).wrapping_add(u64::from(byte));
127        }
128        hash = hash.wrapping_mul(31).wrapping_add(index as u64);
129        hash
130    }
131}
132
133impl MultiVectorEmbedder for MockMultiVectorEmbedder {
134    fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding> {
135        let tokens: Vec<&str> = text.split_whitespace().collect();
136        let num_tokens = tokens.len().min(self.max_tokens);
137
138        if num_tokens == 0 {
139            return Ok(MultiVectorEmbedding::new(Vec::new(), 0, self.dim));
140        }
141
142        let mut embeddings = Vec::with_capacity(num_tokens * self.dim);
143
144        for (i, token) in tokens.iter().take(num_tokens).enumerate() {
145            let token_seed = self.hash_token(token, i);
146            embeddings.extend(self.generate_unit_vector(token_seed));
147        }
148
149        Ok(MultiVectorEmbedding::new(embeddings, num_tokens, self.dim))
150    }
151
152    fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
153        texts.iter().map(|t| self.embed_tokens(t)).collect()
154    }
155
156    fn token_dimension(&self) -> usize {
157        self.dim
158    }
159
160    fn max_tokens(&self) -> usize {
161        self.max_tokens
162    }
163
164    fn model_id(&self) -> &str {
165        "mock-multivector"
166    }
167}
168
169/// Trait implementation for boxed embedders.
170impl<E: MultiVectorEmbedder + ?Sized> MultiVectorEmbedder for Box<E> {
171    fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding> {
172        (**self).embed_tokens(text)
173    }
174
175    fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
176        (**self).embed_tokens_batch(texts)
177    }
178
179    fn token_dimension(&self) -> usize {
180        (**self).token_dimension()
181    }
182
183    fn max_tokens(&self) -> usize {
184        (**self).max_tokens()
185    }
186
187    fn model_id(&self) -> &str {
188        (**self).model_id()
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    // ============ MockMultiVectorEmbedder Tests ============

    #[test]
    fn test_mock_embedder_new() {
        let embedder = MockMultiVectorEmbedder::new(128, 512);

        assert_eq!(embedder.token_dimension(), 128);
        assert_eq!(embedder.max_tokens(), 512);
        assert_eq!(embedder.model_id(), "mock-multivector");
    }

    #[test]
    fn test_mock_embedder_with_seed() {
        let embedder1 = MockMultiVectorEmbedder::with_seed(128, 512, 123);
        let embedder2 = MockMultiVectorEmbedder::with_seed(128, 512, 456);

        let emb1 = embedder1.embed_tokens("test").unwrap();
        let emb2 = embedder2.embed_tokens("test").unwrap();

        // Different seeds should produce different embeddings
        assert_ne!(emb1.as_slice(), emb2.as_slice());
    }

    #[test]
    fn test_mock_embedder_deterministic() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb1 = embedder.embed_tokens("hello world").unwrap();
        let emb2 = embedder.embed_tokens("hello world").unwrap();

        assert_eq!(emb1.num_tokens(), emb2.num_tokens());
        assert_eq!(emb1.as_slice(), emb2.as_slice());
    }

    #[test]
    fn test_mock_embedder_token_count() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("one two three four five").unwrap();

        assert_eq!(emb.num_tokens(), 5);
        assert_eq!(emb.dim(), 64);
    }

    #[test]
    fn test_mock_embedder_max_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 3);

        let emb = embedder.embed_tokens("one two three four five six").unwrap();

        assert_eq!(emb.num_tokens(), 3); // Capped at max_tokens
    }

    #[test]
    fn test_mock_embedder_empty_text() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("").unwrap();

        assert_eq!(emb.num_tokens(), 0);
        assert!(emb.is_empty());
    }

    #[test]
    fn test_mock_embedder_whitespace_only() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("   \t\n   ").unwrap();

        assert_eq!(emb.num_tokens(), 0);
    }

    #[test]
    fn test_mock_embedder_unit_vectors() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("test token").unwrap();

        // Each token should be approximately unit length
        for token_emb in emb.tokens() {
            let norm: f32 = token_emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 0.001, "Token not unit length: norm = {}", norm);
        }
    }

    #[test]
    fn test_mock_embedder_different_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let emb = embedder.embed_tokens("hello world").unwrap();

        // Different tokens should have different embeddings
        let token0 = emb.token(0);
        let token1 = emb.token(1);

        assert_ne!(token0, token1);
    }

    // ============ Batch Embedding Tests ============

    #[test]
    fn test_mock_embedder_batch() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let texts = ["hello", "world", "test"];
        let embeddings = embedder.embed_tokens_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert_eq!(embeddings[0].num_tokens(), 1);
        assert_eq!(embeddings[1].num_tokens(), 1);
        assert_eq!(embeddings[2].num_tokens(), 1);
    }

    #[test]
    fn test_mock_embedder_batch_consistency() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);

        let texts = ["hello", "world"];
        let batch_result = embedder.embed_tokens_batch(&texts).unwrap();

        let single1 = embedder.embed_tokens("hello").unwrap();
        let single2 = embedder.embed_tokens("world").unwrap();

        assert_eq!(batch_result[0].as_slice(), single1.as_slice());
        assert_eq!(batch_result[1].as_slice(), single2.as_slice());
    }

    // ============ Box<dyn MultiVectorEmbedder> Tests ============

    #[test]
    fn test_boxed_embedder() {
        let embedder: Box<dyn MultiVectorEmbedder> =
            Box::new(MockMultiVectorEmbedder::new(64, 256));

        let emb = embedder.embed_tokens("test").unwrap();

        assert_eq!(emb.num_tokens(), 1);
        assert_eq!(embedder.token_dimension(), 64);
    }

    // ============ Property-Based Tests ============

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_embed_produces_correct_dimensions(
            dim in 16usize..256,
            text in "[a-z ]{1,100}"
        ) {
            let embedder = MockMultiVectorEmbedder::new(dim, 512);
            let emb = embedder.embed_tokens(&text).unwrap();

            prop_assert_eq!(emb.dim(), dim);
            if emb.num_tokens() > 0 {
                prop_assert_eq!(emb.token(0).len(), dim);
            }
        }

        #[test]
        fn prop_embed_respects_max_tokens(
            max_tokens in 1usize..10,
            words in 1usize..20
        ) {
            let text: String = (0..words).map(|i| format!("word{}", i)).collect::<Vec<_>>().join(" ");
            let embedder = MockMultiVectorEmbedder::new(64, max_tokens);

            let emb = embedder.embed_tokens(&text).unwrap();

            prop_assert!(emb.num_tokens() <= max_tokens);
        }

        #[test]
        fn prop_embed_is_deterministic(
            seed in 0u64..10000,
            text in "[a-z ]{1,50}"
        ) {
            let embedder = MockMultiVectorEmbedder::with_seed(64, 256, seed);

            let emb1 = embedder.embed_tokens(&text).unwrap();
            let emb2 = embedder.embed_tokens(&text).unwrap();

            prop_assert_eq!(emb1.as_slice(), emb2.as_slice());
        }

        #[test]
        fn prop_tokens_are_approximately_unit_length(
            dim in 32usize..128,
            text in "[a-z]{3,10}"
        ) {
            let embedder = MockMultiVectorEmbedder::new(dim, 256);
            let emb = embedder.embed_tokens(&text).unwrap();

            for token_emb in emb.tokens() {
                let norm: f32 = token_emb.iter().map(|x| x * x).sum::<f32>().sqrt();
                prop_assert!((norm - 1.0).abs() < 0.01);
            }
        }
    }
}