1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct MultiVectorEmbedding {
34 embeddings: Vec<f32>,
36 num_tokens: usize,
38 dim: usize,
40}
41
42impl MultiVectorEmbedding {
43 #[must_use]
49 pub fn new(embeddings: Vec<f32>, num_tokens: usize, dim: usize) -> Self {
50 assert_eq!(
51 embeddings.len(),
52 num_tokens * dim,
53 "Embedding size mismatch: expected {} ({}×{}), got {}",
54 num_tokens * dim,
55 num_tokens,
56 dim,
57 embeddings.len()
58 );
59 Self { embeddings, num_tokens, dim }
60 }
61
62 #[must_use]
64 pub fn from_tokens(tokens: &[Vec<f32>]) -> Self {
65 if tokens.is_empty() {
66 return Self { embeddings: Vec::new(), num_tokens: 0, dim: 0 };
67 }
68
69 let dim = tokens[0].len();
70 let num_tokens = tokens.len();
71 let mut embeddings = Vec::with_capacity(num_tokens * dim);
72
73 for token in tokens {
74 assert_eq!(token.len(), dim, "All tokens must have the same dimension");
75 embeddings.extend_from_slice(token);
76 }
77
78 Self { embeddings, num_tokens, dim }
79 }
80
81 #[must_use]
83 pub fn num_tokens(&self) -> usize {
84 self.num_tokens
85 }
86
87 #[must_use]
89 pub fn dim(&self) -> usize {
90 self.dim
91 }
92
93 #[must_use]
99 pub fn token(&self, i: usize) -> &[f32] {
100 assert!(i < self.num_tokens, "Token index out of bounds");
101 let start = i * self.dim;
102 &self.embeddings[start..start + self.dim]
103 }
104
105 pub fn tokens(&self) -> impl Iterator<Item = &[f32]> {
107 self.embeddings.chunks_exact(self.dim)
108 }
109
110 #[must_use]
112 pub fn as_slice(&self) -> &[f32] {
113 &self.embeddings
114 }
115
116 pub fn as_mut_slice(&mut self) -> &mut [f32] {
118 &mut self.embeddings
119 }
120
121 #[must_use]
123 pub fn size_bytes(&self) -> usize {
124 self.embeddings.len() * size_of::<f32>()
125 }
126
127 #[must_use]
129 pub fn is_empty(&self) -> bool {
130 self.num_tokens == 0
131 }
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct WarpIndexConfig {
149 pub nbits: u8,
154
155 pub num_centroids: usize,
160
161 pub token_dim: usize,
163
164 pub min_training_samples: Option<usize>,
169
170 pub kmeans_iterations: usize,
172}
173
174impl Default for WarpIndexConfig {
175 fn default() -> Self {
176 Self {
177 nbits: 2,
178 num_centroids: 1024,
179 token_dim: 128,
180 min_training_samples: None,
181 kmeans_iterations: 20,
182 }
183 }
184}
185
186impl WarpIndexConfig {
187 #[must_use]
189 pub fn new(nbits: u8, num_centroids: usize, token_dim: usize) -> Self {
190 Self { nbits, num_centroids, token_dim, ..Default::default() }
191 }
192
193 #[must_use]
195 pub fn with_min_training_samples(mut self, samples: usize) -> Self {
196 self.min_training_samples = Some(samples);
197 self
198 }
199
200 #[must_use]
202 pub fn with_kmeans_iterations(mut self, iterations: usize) -> Self {
203 self.kmeans_iterations = iterations;
204 self
205 }
206
207 #[must_use]
209 pub fn effective_min_training_samples(&self) -> usize {
210 self.min_training_samples.unwrap_or(10 * self.num_centroids)
211 }
212
213 #[must_use]
215 pub fn packed_residual_size(&self) -> usize {
216 (self.token_dim * self.nbits as usize + 7) / 8
217 }
218
219 pub fn validate(&self) -> Result<(), &'static str> {
221 if self.nbits != 2 && self.nbits != 4 {
222 return Err("nbits must be 2 or 4");
223 }
224 if self.num_centroids == 0 {
225 return Err("num_centroids must be > 0");
226 }
227 if self.token_dim == 0 {
228 return Err("token_dim must be > 0");
229 }
230 if self.kmeans_iterations == 0 {
231 return Err("kmeans_iterations must be > 0");
232 }
233 Ok(())
234 }
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct WarpSearchConfig {
243 pub k: usize,
245
246 pub nprobe: u32,
251
252 pub bound: usize,
256
257 pub t_prime: Option<usize>,
262
263 pub centroid_score_threshold: f32,
268}
269
270impl Default for WarpSearchConfig {
271 fn default() -> Self {
272 Self { k: 10, nprobe: 4, bound: 128, t_prime: None, centroid_score_threshold: 0.4 }
273 }
274}
275
276impl WarpSearchConfig {
277 #[must_use]
279 pub fn with_k(k: usize) -> Self {
280 Self { k, ..Default::default() }
281 }
282
283 #[must_use]
285 pub fn nprobe(mut self, nprobe: u32) -> Self {
286 self.nprobe = nprobe;
287 self
288 }
289
290 #[must_use]
292 pub fn bound(mut self, bound: usize) -> Self {
293 self.bound = bound;
294 self
295 }
296
297 #[must_use]
299 pub fn t_prime(mut self, t_prime: usize) -> Self {
300 self.t_prime = Some(t_prime);
301 self
302 }
303
304 #[must_use]
306 pub fn centroid_score_threshold(mut self, threshold: f32) -> Self {
307 self.centroid_score_threshold = threshold;
308 self
309 }
310}
311
/// Unit and property tests for the embedding container and the index/search
/// configuration types defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // ---- MultiVectorEmbedding ----

    #[test]
    fn test_multivector_new() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 3);

        assert_eq!(mv.num_tokens(), 2);
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.token(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mv.token(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "Embedding size mismatch")]
    fn test_multivector_size_mismatch() {
        // Three values cannot form 2 tokens of dimension 3.
        let embeddings = vec![1.0, 2.0, 3.0];
        let _ = MultiVectorEmbedding::new(embeddings, 2, 3);
    }

    #[test]
    fn test_multivector_from_tokens() {
        let tokens = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 3);
        assert_eq!(mv.dim(), 2);
        // Tokens are concatenated in order into the flat buffer.
        assert_eq!(mv.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(mv.token(2), &[5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "All tokens must have the same dimension")]
    fn test_multivector_from_tokens_ragged() {
        let tokens = vec![vec![1.0, 2.0], vec![3.0]];
        let _ = MultiVectorEmbedding::from_tokens(&tokens);
    }

    #[test]
    fn test_multivector_from_tokens_empty() {
        let tokens: Vec<Vec<f32>> = vec![];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 0);
        assert_eq!(mv.dim(), 0);
        assert!(mv.is_empty());
    }

    #[test]
    #[should_panic(expected = "Token index out of bounds")]
    fn test_multivector_token_out_of_bounds() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0], 1, 2);
        let _ = mv.token(1);
    }

    #[test]
    fn test_multivector_tokens_iterator() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 2);

        let tokens: Vec<&[f32]> = mv.tokens().collect();
        assert_eq!(tokens.len(), 2);
        assert_eq!(tokens[0], &[1.0, 2.0]);
        assert_eq!(tokens[1], &[3.0, 4.0]);
    }

    #[test]
    fn test_multivector_size_bytes() {
        let embeddings = vec![0.0; 100];
        let mv = MultiVectorEmbedding::new(embeddings, 10, 10);

        // 100 f32 values at 4 bytes each.
        assert_eq!(mv.size_bytes(), 100 * 4);
    }

    #[test]
    fn test_multivector_as_slice() {
        let embeddings = vec![1.0, 2.0, 3.0];
        let mv = MultiVectorEmbedding::new(embeddings, 1, 3);

        assert_eq!(mv.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_multivector_as_mut_slice() {
        let mut mv = MultiVectorEmbedding::new(vec![0.0; 4], 2, 2);
        mv.as_mut_slice()[2] = 7.0;

        assert_eq!(mv.token(1), &[7.0, 0.0]);
    }

    #[test]
    fn test_multivector_serialization() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let json = serde_json::to_string(&mv).unwrap();
        let deserialized: MultiVectorEmbedding = serde_json::from_str(&json).unwrap();

        assert_eq!(mv.num_tokens(), deserialized.num_tokens());
        assert_eq!(mv.dim(), deserialized.dim());
        assert_eq!(mv.as_slice(), deserialized.as_slice());
    }

    // ---- WarpIndexConfig ----

    #[test]
    fn test_index_config_default() {
        let config = WarpIndexConfig::default();

        assert_eq!(config.nbits, 2);
        assert_eq!(config.num_centroids, 1024);
        assert_eq!(config.token_dim, 128);
        assert_eq!(config.min_training_samples, None);
        assert_eq!(config.kmeans_iterations, 20);
    }

    #[test]
    fn test_index_config_new() {
        let config = WarpIndexConfig::new(4, 256, 64);

        assert_eq!(config.nbits, 4);
        assert_eq!(config.num_centroids, 256);
        assert_eq!(config.token_dim, 64);
    }

    #[test]
    fn test_index_config_builders() {
        let config = WarpIndexConfig::new(2, 512, 128)
            .with_min_training_samples(5000)
            .with_kmeans_iterations(30);

        assert_eq!(config.min_training_samples, Some(5000));
        assert_eq!(config.kmeans_iterations, 30);
    }

    #[test]
    fn test_index_config_effective_min_samples() {
        // Without an override the minimum is 10 × num_centroids.
        let config = WarpIndexConfig::new(2, 100, 128);
        assert_eq!(config.effective_min_training_samples(), 1000);

        // An explicit override wins.
        let config = config.with_min_training_samples(500);
        assert_eq!(config.effective_min_training_samples(), 500);
    }

    #[test]
    fn test_index_config_packed_size() {
        // 128 dims × 2 bits = 256 bits = 32 bytes.
        let config = WarpIndexConfig::new(2, 1024, 128);
        assert_eq!(config.packed_residual_size(), 32);

        // 128 dims × 4 bits = 512 bits = 64 bytes.
        let config = WarpIndexConfig::new(4, 1024, 128);
        assert_eq!(config.packed_residual_size(), 64);
    }

    #[test]
    fn test_index_config_validate() {
        let config = WarpIndexConfig::default();
        assert!(config.validate().is_ok());

        let bad_nbits = WarpIndexConfig { nbits: 3, ..Default::default() };
        assert!(bad_nbits.validate().is_err());

        let bad_centroids = WarpIndexConfig { num_centroids: 0, ..Default::default() };
        assert!(bad_centroids.validate().is_err());

        let bad_dim = WarpIndexConfig { token_dim: 0, ..Default::default() };
        assert!(bad_dim.validate().is_err());

        let bad_iterations = WarpIndexConfig { kmeans_iterations: 0, ..Default::default() };
        assert!(bad_iterations.validate().is_err());
    }

    #[test]
    fn test_index_config_serialization() {
        let config = WarpIndexConfig::new(4, 512, 64);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpIndexConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.nbits, deserialized.nbits);
        assert_eq!(config.num_centroids, deserialized.num_centroids);
        assert_eq!(config.token_dim, deserialized.token_dim);
    }

    // ---- WarpSearchConfig ----

    #[test]
    fn test_search_config_default() {
        let config = WarpSearchConfig::default();

        assert_eq!(config.k, 10);
        assert_eq!(config.nprobe, 4);
        assert_eq!(config.bound, 128);
        assert!(config.t_prime.is_none());
        assert!((config.centroid_score_threshold - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_search_config_with_k() {
        let config = WarpSearchConfig::with_k(20);
        assert_eq!(config.k, 20);
    }

    #[test]
    fn test_search_config_builders() {
        let config = WarpSearchConfig::with_k(5)
            .nprobe(8)
            .bound(256)
            .t_prime(10)
            .centroid_score_threshold(0.5);

        assert_eq!(config.k, 5);
        assert_eq!(config.nprobe, 8);
        assert_eq!(config.bound, 256);
        assert_eq!(config.t_prime, Some(10));
        assert!((config.centroid_score_threshold - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_search_config_serialization() {
        let config = WarpSearchConfig::with_k(15).nprobe(6);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpSearchConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.k, deserialized.k);
        assert_eq!(config.nprobe, deserialized.nprobe);
    }

    // ---- Property tests ----

    proptest! {
        #[test]
        fn prop_multivector_tokens_count_matches(
            num_tokens in 1usize..20,
            dim in 1usize..64
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            prop_assert_eq!(mv.num_tokens(), num_tokens);
            prop_assert_eq!(mv.dim(), dim);
            prop_assert_eq!(mv.tokens().count(), num_tokens);
        }

        #[test]
        fn prop_multivector_token_slices_correct_size(
            num_tokens in 1usize..10,
            dim in 1usize..32
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            for i in 0..num_tokens {
                prop_assert_eq!(mv.token(i).len(), dim);
            }
        }

        #[test]
        fn prop_index_config_packed_size_formula(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 1usize..256
        ) {
            let config = WarpIndexConfig::new(nbits, 1024, dim);
            let expected = (dim * nbits as usize + 7) / 8;
            prop_assert_eq!(config.packed_residual_size(), expected);
        }
    }
}