1use crate::multivector::{
12 codec::ResidualCodec,
13 search::{CandidateScorer, CentroidSelector, ScoreMerger},
14 types::{MultiVectorEmbedding, WarpIndexConfig, WarpSearchConfig},
15};
16use crate::{Chunk, ChunkId, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
/// Centroid-grouped index over residual-compressed token embeddings.
///
/// Lifecycle, as enforced by the methods on this type: [`WarpIndex::train`] fits the codec,
/// [`WarpIndex::insert`] queues chunks with their embeddings, [`WarpIndex::build`] compresses
/// the queue into the flat arrays below, and [`WarpIndex::search`] reads those arrays.
///
/// `chunks` and `pending` are marked `#[serde(skip)]`, so they are not serialized and come
/// back empty (their `Default`) after deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarpIndex {
    /// Index-level configuration supplied to [`WarpIndex::new`].
    config: WarpIndexConfig,
    /// Trained residual codec; `None` until [`WarpIndex::train`] succeeds.
    codec: Option<ResidualCodec>,
    /// Number of indexed tokens assigned to each centroid, indexed by centroid id.
    sizes: Vec<usize>,
    /// Exclusive prefix sum of `sizes`: the position of each centroid's first token in
    /// `chunk_ids` / `token_indices`.
    offsets: Vec<usize>,
    /// Owning chunk of every indexed token, laid out centroid by centroid.
    chunk_ids: Vec<ChunkId>,
    /// Position of every indexed token inside its source embedding; parallel to `chunk_ids`.
    token_indices: Vec<u16>,
    /// Compressed residual bytes of every indexed token, concatenated in `chunk_ids` order.
    // NOTE(review): presumably `config.packed_residual_size()` bytes per token — confirm
    // against `ResidualCodec::compress`.
    residuals: Vec<u8>,
    /// Chunk payloads keyed by id. Not serialized.
    #[serde(skip)]
    chunks: HashMap<ChunkId, Chunk>,
    /// Embeddings queued by `insert` and not yet compressed by `build`. Not serialized.
    #[serde(skip)]
    pending: Vec<(ChunkId, MultiVectorEmbedding)>,
    /// Set by `build`, reset by `clear_index`; `search` refuses to run while it is `false`.
    is_built: bool,
}
68
69impl WarpIndex {
70 #[must_use]
72 pub fn new(config: WarpIndexConfig) -> Self {
73 Self {
74 config,
75 codec: None,
76 sizes: Vec::new(),
77 offsets: Vec::new(),
78 chunk_ids: Vec::new(),
79 token_indices: Vec::new(),
80 residuals: Vec::new(),
81 chunks: HashMap::new(),
82 pending: Vec::new(),
83 is_built: false,
84 }
85 }
86
87 #[must_use]
89 pub fn config(&self) -> &WarpIndexConfig {
90 &self.config
91 }
92
93 #[must_use]
95 pub fn codec(&self) -> Option<&ResidualCodec> {
96 self.codec.as_ref()
97 }
98
99 #[must_use]
101 pub fn is_trained(&self) -> bool {
102 self.codec.is_some()
103 }
104
105 #[must_use]
107 pub fn is_built(&self) -> bool {
108 self.is_built
109 }
110
111 #[must_use]
113 pub fn num_chunks(&self) -> usize {
114 self.chunks.len()
115 }
116
117 #[must_use]
119 pub fn num_tokens(&self) -> usize {
120 self.chunk_ids.len()
121 }
122
123 #[must_use]
125 pub fn is_empty(&self) -> bool {
126 self.chunks.is_empty()
127 }
128
129 #[must_use]
131 pub fn get_chunk(&self, id: &ChunkId) -> Option<&Chunk> {
132 self.chunks.get(id)
133 }
134
135 #[must_use]
137 pub fn memory_usage(&self) -> usize {
138 let codec_size = self
139 .codec
140 .as_ref()
141 .map(|c| {
142 c.centroids().len() * 4 + c.dim() * ((1 << c.nbits()) - 1) * 4 + c.dim() * (1 << c.nbits()) * 4 })
146 .unwrap_or(0);
147
148 let index_size = self.chunk_ids.len() * size_of::<ChunkId>()
149 + self.token_indices.len() * size_of::<u16>()
150 + self.residuals.len()
151 + self.sizes.len() * size_of::<usize>()
152 + self.offsets.len() * size_of::<usize>();
153
154 codec_size + index_size
155 }
156
157 pub fn train(&mut self, samples: &[MultiVectorEmbedding]) -> Result<()> {
169 let total_tokens: usize = samples.iter().map(|s| s.num_tokens()).sum();
171 let min_samples = self.config.effective_min_training_samples();
172
173 if total_tokens < min_samples {
174 return Err(crate::Error::InvalidInput(format!(
175 "Insufficient training tokens: {total_tokens} < {min_samples} required"
176 )));
177 }
178
179 let mut all_embeddings = Vec::with_capacity(total_tokens * self.config.token_dim);
181 for sample in samples {
182 all_embeddings.extend_from_slice(sample.as_slice());
183 }
184
185 let codec = ResidualCodec::train(
187 &all_embeddings,
188 self.config.token_dim,
189 self.config.num_centroids,
190 self.config.nbits,
191 self.config.kmeans_iterations,
192 )?;
193
194 self.codec = Some(codec);
195 Ok(())
196 }
197
198 pub fn insert(&mut self, chunk: Chunk, embedding: MultiVectorEmbedding) -> Result<()> {
208 if self.codec.is_none() {
209 return Err(crate::Error::InvalidInput(
210 "Codec not trained - call train() first".to_string(),
211 ));
212 }
213
214 if self.is_built {
215 return Err(crate::Error::InvalidInput(
216 "Index already built - cannot insert".to_string(),
217 ));
218 }
219
220 let chunk_id = chunk.id;
221 self.chunks.insert(chunk_id, chunk);
222 self.pending.push((chunk_id, embedding));
223
224 Ok(())
225 }
226
227 pub fn build(&mut self) -> Result<()> {
236 let codec = self.codec.as_ref().ok_or_else(|| {
237 crate::Error::InvalidInput("Codec not trained - call train() first".to_string())
238 })?;
239
240 let mut centroid_assignments: Vec<Vec<(ChunkId, u16, Vec<u8>)>> =
242 vec![Vec::new(); self.config.num_centroids];
243
244 for (chunk_id, embedding) in &self.pending {
245 for (token_idx, token) in embedding.tokens().enumerate() {
246 let (centroid_id, residual) = codec.compress(token);
247 centroid_assignments[centroid_id].push((*chunk_id, token_idx as u16, residual));
248 }
249 }
250
251 let bytes_per_residual = self.config.packed_residual_size();
253
254 self.sizes = centroid_assignments.iter().map(|v| v.len()).collect();
255 self.offsets = self
256 .sizes
257 .iter()
258 .scan(0, |acc, &size| {
259 let offset = *acc;
260 *acc += size;
261 Some(offset)
262 })
263 .collect();
264
265 let total_tokens: usize = self.sizes.iter().sum();
266 self.chunk_ids = Vec::with_capacity(total_tokens);
267 self.token_indices = Vec::with_capacity(total_tokens);
268 self.residuals = Vec::with_capacity(total_tokens * bytes_per_residual);
269
270 for assignments in centroid_assignments {
271 for (chunk_id, token_idx, residual) in assignments {
272 self.chunk_ids.push(chunk_id);
273 self.token_indices.push(token_idx);
274 self.residuals.extend(residual);
275 }
276 }
277
278 self.pending.clear();
279 self.is_built = true;
280
281 Ok(())
282 }
283
284 pub fn clear_index(&mut self) {
289 self.sizes.clear();
290 self.offsets.clear();
291 self.chunk_ids.clear();
292 self.token_indices.clear();
293 self.residuals.clear();
294 self.is_built = false;
295 }
296
297 pub fn search(
312 &self,
313 query: &MultiVectorEmbedding,
314 search_config: &WarpSearchConfig,
315 ) -> Result<Vec<(ChunkId, f32)>> {
316 let codec = self
317 .codec
318 .as_ref()
319 .ok_or_else(|| crate::Error::InvalidInput("Codec not trained".to_string()))?;
320
321 if !self.is_built {
322 return Err(crate::Error::InvalidInput(
323 "Index not built - call build() first".to_string(),
324 ));
325 }
326
327 let selected_centroids = CentroidSelector::select(
329 query,
330 codec.centroids(),
331 self.config.token_dim,
332 search_config,
333 );
334
335 let mut total_centroids = 0;
337 let max_tokens = search_config.t_prime.unwrap_or(usize::MAX);
338 let bounded_centroids: Vec<Vec<(usize, f32)>> = selected_centroids
339 .into_iter()
340 .take(max_tokens)
341 .map(|centroids| {
342 let take =
343 (search_config.bound.saturating_sub(total_centroids)).min(centroids.len());
344 total_centroids += take;
345 centroids.into_iter().take(take).collect()
346 })
347 .collect();
348
349 let bytes_per_residual = self.config.packed_residual_size();
351
352 let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = bounded_centroids
353 .into_iter()
354 .enumerate()
355 .map(|(query_token_idx, centroids)| {
356 let query_token = query.token(query_token_idx);
357
358 centroids
359 .into_iter()
360 .flat_map(|(centroid_id, centroid_score)| {
361 CandidateScorer::score(
362 query_token,
363 centroid_id,
364 centroid_score,
365 codec,
366 &self.sizes,
367 &self.offsets,
368 &self.chunk_ids,
369 &self.token_indices,
370 &self.residuals,
371 bytes_per_residual,
372 )
373 })
374 .collect()
375 })
376 .collect();
377
378 Ok(ScoreMerger::merge(token_scores, search_config.k))
380 }
381
382 #[must_use]
384 pub fn centroid_size(&self, centroid_id: usize) -> usize {
385 self.sizes.get(centroid_id).copied().unwrap_or(0)
386 }
387
388 #[must_use]
390 pub fn centroid_offset(&self, centroid_id: usize) -> usize {
391 self.offsets.get(centroid_id).copied().unwrap_or(0)
392 }
393}
394
395#[cfg(test)]
396mod tests {
397 use super::*;
398 use crate::DocumentId;
399
400 fn create_test_chunk(content: &str) -> Chunk {
401 Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
402 }
403
404 fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
405 let mut embeddings = Vec::with_capacity(num_tokens * dim);
406 let mut rng = seed;
407
408 for _ in 0..(num_tokens * dim) {
409 rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
410 let val = ((rng >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
411 embeddings.push(val);
412 }
413
414 MultiVectorEmbedding::new(embeddings, num_tokens, dim)
415 }
416
417 #[test]
420 fn test_index_new() {
421 let config = WarpIndexConfig::new(2, 16, 32);
422 let index = WarpIndex::new(config);
423
424 assert!(!index.is_trained());
425 assert!(!index.is_built());
426 assert!(index.is_empty());
427 }
428
429 #[test]
430 fn test_index_config() {
431 let config = WarpIndexConfig::new(4, 32, 64);
432 let index = WarpIndex::new(config);
433
434 assert_eq!(index.config().nbits, 4);
435 assert_eq!(index.config().num_centroids, 32);
436 assert_eq!(index.config().token_dim, 64);
437 }
438
439 #[test]
442 fn test_index_train() {
443 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
444 let mut index = WarpIndex::new(config);
445
446 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
448
449 index.train(&samples).unwrap();
450
451 assert!(index.is_trained());
452 assert!(index.codec().is_some());
453 }
454
455 #[test]
456 fn test_index_train_insufficient_samples() {
457 let config = WarpIndexConfig::new(2, 100, 16); let mut index = WarpIndex::new(config);
459
460 let samples: Vec<_> = (0..5).map(|i| generate_embedding(10, 16, i)).collect();
461
462 let result = index.train(&samples);
463 assert!(result.is_err());
464 }
465
466 #[test]
469 fn test_index_insert() {
470 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
471 let mut index = WarpIndex::new(config);
472
473 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
475 index.train(&samples).unwrap();
476
477 let chunk = create_test_chunk("test content");
479 let embedding = generate_embedding(5, 16, 999);
480 index.insert(chunk, embedding).unwrap();
481
482 assert_eq!(index.num_chunks(), 1);
483 }
484
485 #[test]
486 fn test_index_insert_without_training() {
487 let config = WarpIndexConfig::new(2, 8, 16);
488 let mut index = WarpIndex::new(config);
489
490 let chunk = create_test_chunk("test");
491 let embedding = generate_embedding(5, 16, 0);
492
493 let result = index.insert(chunk, embedding);
494 assert!(result.is_err());
495 }
496
497 #[test]
500 fn test_index_build() {
501 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
502 let mut index = WarpIndex::new(config);
503
504 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
506 index.train(&samples).unwrap();
507
508 for i in 0..10 {
510 let chunk = create_test_chunk(&format!("document {}", i));
511 let embedding = generate_embedding(5, 16, 1000 + i);
512 index.insert(chunk, embedding).unwrap();
513 }
514
515 index.build().unwrap();
517
518 assert!(index.is_built());
519 assert_eq!(index.num_chunks(), 10);
520 assert_eq!(index.num_tokens(), 50); }
522
523 #[test]
524 fn test_index_cannot_insert_after_build() {
525 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
526 let mut index = WarpIndex::new(config);
527
528 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
529 index.train(&samples).unwrap();
530
531 let chunk = create_test_chunk("test");
532 let embedding = generate_embedding(5, 16, 0);
533 index.insert(chunk, embedding).unwrap();
534
535 index.build().unwrap();
536
537 let chunk2 = create_test_chunk("test2");
539 let embedding2 = generate_embedding(5, 16, 1);
540 let result = index.insert(chunk2, embedding2);
541
542 assert!(result.is_err());
543 }
544
545 #[test]
548 fn test_index_search() {
549 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
550 let mut index = WarpIndex::new(config);
551
552 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
554 index.train(&samples).unwrap();
555
556 for i in 0..20 {
558 let chunk = create_test_chunk(&format!("document {}", i));
559 let embedding = generate_embedding(5, 16, 1000 + i);
560 index.insert(chunk, embedding).unwrap();
561 }
562
563 index.build().unwrap();
565
566 let query = generate_embedding(3, 16, 9999);
568 let search_config = WarpSearchConfig::with_k(5);
569 let results = index.search(&query, &search_config).unwrap();
570
571 assert!(results.len() <= 5);
572 assert!(!results.is_empty());
573
574 for i in 1..results.len() {
576 assert!(results[i - 1].1 >= results[i].1);
577 }
578 }
579
580 #[test]
581 fn test_index_search_without_build() {
582 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
583 let mut index = WarpIndex::new(config);
584
585 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
586 index.train(&samples).unwrap();
587
588 let query = generate_embedding(3, 16, 0);
589 let search_config = WarpSearchConfig::with_k(5);
590 let result = index.search(&query, &search_config);
591
592 assert!(result.is_err());
593 }
594
595 #[test]
598 fn test_index_memory_usage() {
599 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
600 let mut index = WarpIndex::new(config);
601
602 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
603 index.train(&samples).unwrap();
604
605 for i in 0..10 {
606 let chunk = create_test_chunk(&format!("doc {}", i));
607 let embedding = generate_embedding(5, 16, 1000 + i);
608 index.insert(chunk, embedding).unwrap();
609 }
610
611 index.build().unwrap();
612
613 let memory = index.memory_usage();
614 assert!(memory > 0);
615 }
616
617 #[test]
618 fn test_index_centroid_stats() {
619 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
620 let mut index = WarpIndex::new(config);
621
622 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
623 index.train(&samples).unwrap();
624
625 for i in 0..10 {
626 let chunk = create_test_chunk(&format!("doc {}", i));
627 let embedding = generate_embedding(5, 16, 1000 + i);
628 index.insert(chunk, embedding).unwrap();
629 }
630
631 index.build().unwrap();
632
633 let total: usize = (0..8).map(|c| index.centroid_size(c)).sum();
635 assert_eq!(total, index.num_tokens());
636 }
637
638 #[test]
641 fn test_index_clear_and_rebuild() {
642 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
643 let mut index = WarpIndex::new(config);
644
645 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
646 index.train(&samples).unwrap();
647
648 let chunk = create_test_chunk("test");
649 let embedding = generate_embedding(5, 16, 0);
650 index.insert(chunk, embedding).unwrap();
651 index.build().unwrap();
652
653 assert!(index.is_built());
654
655 index.clear_index();
656
657 assert!(!index.is_built());
658 assert_eq!(index.num_tokens(), 0);
659 assert_eq!(index.num_chunks(), 1);
661 }
662
663 #[test]
666 fn test_index_get_chunk() {
667 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
668 let mut index = WarpIndex::new(config);
669
670 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
671 index.train(&samples).unwrap();
672
673 let chunk = create_test_chunk("test content");
674 let chunk_id = chunk.id;
675 let embedding = generate_embedding(5, 16, 0);
676 index.insert(chunk, embedding).unwrap();
677
678 let retrieved = index.get_chunk(&chunk_id);
679 assert!(retrieved.is_some());
680 assert_eq!(retrieved.unwrap().content, "test content");
681 }
682
683 use proptest::prelude::*;
686
687 proptest! {
688 #[test]
689 fn prop_search_returns_at_most_k(k in 1usize..20) {
690 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
691 let mut index = WarpIndex::new(config);
692
693 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
694 index.train(&samples).unwrap();
695
696 for i in 0..30 {
697 let chunk = create_test_chunk(&format!("doc {}", i));
698 let embedding = generate_embedding(5, 16, 1000 + i as u64);
699 index.insert(chunk, embedding).unwrap();
700 }
701
702 index.build().unwrap();
703
704 let query = generate_embedding(3, 16, 9999);
705 let search_config = WarpSearchConfig::with_k(k);
706 let results = index.search(&query, &search_config).unwrap();
707
708 prop_assert!(results.len() <= k);
709 }
710
711 #[test]
712 fn prop_search_results_sorted_descending(seed in 0u64..1000) {
713 let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
714 let mut index = WarpIndex::new(config);
715
716 let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, 16, i)).collect();
717 index.train(&samples).unwrap();
718
719 for i in 0..20 {
720 let chunk = create_test_chunk(&format!("doc {}", i));
721 let embedding = generate_embedding(5, 16, seed + i as u64);
722 index.insert(chunk, embedding).unwrap();
723 }
724
725 index.build().unwrap();
726
727 let query = generate_embedding(3, 16, seed + 1000);
728 let search_config = WarpSearchConfig::with_k(10);
729 let results = index.search(&query, &search_config).unwrap();
730
731 for i in 1..results.len() {
732 prop_assert!(results[i - 1].1 >= results[i].1);
733 }
734 }
735 }
736}