1use std::collections::BinaryHeap;
51use std::sync::atomic::{AtomicU64, Ordering};
52use std::time::{Duration, Instant};
53
54use crate::cost_model::CostTracker;
55
/// Budgets and feature switches that bound a single rerank pass.
///
/// Start from [`RerankConfig::default`] and adjust it with the chained
/// setters below, or fill the public fields directly.
#[derive(Debug, Clone)]
pub struct RerankConfig {
    /// Upper bound on the number of storage reads issued per rerank call.
    pub max_io_ops: u32,

    /// Upper bound on the total number of bytes read per rerank call.
    pub max_io_bytes: u64,

    /// Wall-clock budget for one rerank call.
    pub max_latency: Duration,

    /// Gap, in bytes, up to which neighbouring reads are merged into one.
    pub coalesce_threshold: u64,

    /// Minimum number of candidates to rerank.
    ///
    /// NOTE(review): presumably a floor for early termination; no visible code
    /// path changes its result based on this value — confirm intended use.
    pub min_rerank_candidates: usize,

    /// Whether the executor keeps an in-memory vector cache.
    pub enable_cache: bool,

    /// Maximum number of vectors held by the cache when it is enabled.
    pub cache_size: usize,

    /// NOTE(review): not read by the code visible in this section of the
    /// file; presumably the depth of an I/O submission queue — confirm.
    pub io_queue_depth: u32,

    /// NOTE(review): not read by the code visible in this section of the
    /// file; presumably how far ahead reads are prefetched — confirm.
    pub prefetch_distance: usize,
}

impl Default for RerankConfig {
    /// Returns the stock configuration: 100 reads, 16 MiB, 50 ms, 4 KiB
    /// coalescing gap, and a 10 000-entry cache.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;

        RerankConfig {
            max_io_ops: 100,
            max_io_bytes: 16 * MIB,
            max_latency: Duration::from_millis(50),
            coalesce_threshold: 4096,
            min_rerank_candidates: 10,
            enable_cache: true,
            cache_size: 10000,
            io_queue_depth: 64,
            prefetch_distance: 4,
        }
    }
}

impl RerankConfig {
    /// Returns the configuration with `max_io_ops` replaced by `max_ops`.
    pub fn io_budget(self, max_ops: u32) -> Self {
        Self {
            max_io_ops: max_ops,
            ..self
        }
    }

    /// Returns the configuration with `coalesce_threshold` replaced by `bytes`.
    pub fn coalesce_threshold(self, bytes: u64) -> Self {
        Self {
            coalesce_threshold: bytes,
            ..self
        }
    }

    /// Returns the configuration with `max_latency` replaced by `latency`.
    pub fn max_latency(self, latency: Duration) -> Self {
        Self {
            max_latency: latency,
            ..self
        }
    }
}
127
/// A contiguous byte range to read from storage, together with the
/// candidates whose vectors lie inside it.
#[derive(Debug, Clone)]
pub struct IoRange {
    /// Byte offset of the start of the range.
    pub offset: u64,
    /// Number of bytes covered by the range.
    pub length: u64,
    /// Indices of the candidates served by this range, in the order they
    /// were merged in.
    pub candidate_indices: Vec<usize>,
}

impl IoRange {
    /// Creates a range that covers exactly one candidate.
    pub fn single(offset: u64, length: u64, candidate_idx: usize) -> Self {
        Self {
            offset,
            length,
            candidate_indices: vec![candidate_idx],
        }
    }

    /// Merges `other` into `self` when the two ranges overlap or are separated
    /// by at most `threshold` bytes.
    ///
    /// On success `self` grows to the union of both ranges, `other`'s
    /// candidate indices are appended, and `true` is returned. Otherwise
    /// `self` is left untouched and `false` is returned.
    ///
    /// All end/threshold sums saturate at `u64::MAX`, so extreme offsets or a
    /// `threshold` of `u64::MAX` cannot overflow.
    pub fn try_merge(&mut self, other: &IoRange, threshold: u64) -> bool {
        let self_end = self.end();
        let other_end = other.end();

        // Symmetric proximity test: each range must start no later than
        // `threshold` bytes past the end of the other.
        let within_reach = other.offset <= self_end.saturating_add(threshold)
            && self.offset <= other_end.saturating_add(threshold);
        if !within_reach {
            return false;
        }

        let new_start = self.offset.min(other.offset);
        let new_end = self_end.max(other_end);

        self.offset = new_start;
        // `new_end >= new_start` because every end is at least its own offset.
        self.length = new_end - new_start;
        self.candidate_indices
            .extend_from_slice(&other.candidate_indices);
        true
    }

    /// Returns the exclusive end offset (`offset + length`), saturating at
    /// `u64::MAX`.
    pub fn end(&self) -> u64 {
        self.offset.saturating_add(self.length)
    }
}
179
/// A candidate awaiting exact scoring: its id, its approximate score, and
/// where its full vector lives on storage.
#[derive(Debug, Clone)]
pub struct RerankCandidate {
    /// Identifier of the candidate vector.
    pub id: u32,
    /// Approximate score assigned before reranking.
    pub proxy_score: f32,
    /// Byte offset of the candidate's vector in storage.
    pub disk_offset: u64,
    /// Size of the stored vector in bytes.
    pub vector_size: u32,
}

impl RerankCandidate {
    /// Bundles the four fields into a candidate; performs no validation.
    pub fn new(id: u32, proxy_score: f32, disk_offset: u64, vector_size: u32) -> Self {
        RerankCandidate {
            id,
            proxy_score,
            disk_offset,
            vector_size,
        }
    }
}
208
/// Outcome of exactly scoring one candidate.
///
/// # Comparison semantics
///
/// * Equality is identity based: two results are equal when their `id`s match,
///   regardless of score.
/// * Ordering is by `true_score`, **reversed**: the result with the lower score
///   compares as greater, so a `BinaryHeap<RerankResult>` pops the lowest score
///   first. Scores are compared with [`f32::total_cmp`], which gives NaN a
///   well-defined position instead of comparing it equal to everything.
///
/// NOTE(review): equality (by id) and ordering (by score) are intentionally
/// left independent for backward compatibility, so `a == b` does not imply
/// `a.cmp(&b) == Equal`. Avoid using this type as a key in ordered sets/maps.
#[derive(Debug, Clone)]
pub struct RerankResult {
    /// Identifier of the scored candidate.
    pub id: u32,
    /// Score returned by the executor's distance function.
    pub true_score: f32,
    /// Whether the vector was served from the in-memory cache.
    pub from_cache: bool,
}

impl PartialEq for RerankResult {
    /// Two results are equal when they refer to the same candidate id.
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

impl Eq for RerankResult {}

impl PartialOrd for RerankResult {
    /// Always `Some`; delegates to [`Ord::cmp`] so both orderings agree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for RerankResult {
    /// Reversed total order on `true_score` (lower score compares greater).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose to invert the natural float order.
        other.true_score.total_cmp(&self.true_score)
    }
}
244
/// Counters and timings collected during one rerank call.
#[derive(Debug, Clone, Default)]
pub struct RerankStats {
    /// Number of candidates passed to the rerank call.
    pub candidates_requested: usize,
    /// Number of candidates that were actually scored.
    pub candidates_reranked: usize,
    /// Number of successful storage reads.
    pub io_ops: u32,
    /// Total bytes requested by the successful storage reads.
    pub io_bytes: u64,
    /// Number of coalesced ranges planned for the uncached candidates.
    pub coalesced_ranges: usize,
    /// Candidates scored from the cache.
    pub cache_hits: usize,
    /// Candidates scored from storage.
    pub cache_misses: usize,
    /// Set when a budget stopped the pass before all ranges were read.
    pub budget_exhausted: bool,
    /// Machine-readable reason the pass ended (e.g. `"complete"`).
    pub stop_reason: String,
    /// Wall-clock time spent in the rerank call.
    pub duration: Duration,
}

impl RerankStats {
    /// Vector dimensionality assumed by [`RerankStats::io_amplification`].
    const ASSUMED_DIM: usize = 768;

    /// Ratio of bytes read to bytes strictly needed, assuming every reranked
    /// candidate is a 768-dimensional `f32` vector.
    ///
    /// Returns `0.0` when nothing was reranked. Use
    /// [`RerankStats::io_amplification_for_dim`] for other dimensionalities.
    pub fn io_amplification(&self) -> f32 {
        self.io_amplification_for_dim(Self::ASSUMED_DIM)
    }

    /// Ratio of bytes read to bytes strictly needed for `dim`-dimensional
    /// `f32` vectors (4 bytes per component).
    ///
    /// Returns `0.0` when nothing was reranked or when `dim` is zero, so the
    /// result is never NaN or infinite because of an empty denominator.
    pub fn io_amplification_for_dim(&self, dim: usize) -> f32 {
        if self.candidates_reranked == 0 || dim == 0 {
            return 0.0;
        }
        self.io_bytes as f32 / (self.candidates_reranked as f32 * 4.0 * dim as f32)
    }

    /// Fraction of scored candidates that came from the cache, or `0.0` when
    /// there were neither hits nor misses.
    pub fn cache_hit_ratio(&self) -> f32 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            total => self.cache_hits as f32 / total as f32,
        }
    }
}
294
295pub struct IoCoalescer {
301 threshold: u64,
302}
303
304impl IoCoalescer {
305 pub fn new(threshold: u64) -> Self {
307 Self { threshold }
308 }
309
310 pub fn coalesce(&self, candidates: &[RerankCandidate]) -> Vec<IoRange> {
317 if candidates.is_empty() {
318 return Vec::new();
319 }
320
321 let mut indexed: Vec<(usize, &RerankCandidate)> = candidates.iter().enumerate().collect();
323 indexed.sort_by_key(|(_, c)| c.disk_offset);
324
325 let mut ranges: Vec<IoRange> = Vec::with_capacity(candidates.len());
326
327 let (first_idx, first) = indexed[0];
329 let mut current = IoRange::single(first.disk_offset, first.vector_size as u64, first_idx);
330
331 for (idx, candidate) in indexed.iter().skip(1) {
333 let new_range =
334 IoRange::single(candidate.disk_offset, candidate.vector_size as u64, *idx);
335
336 if !current.try_merge(&new_range, self.threshold) {
337 ranges.push(current);
338 current = new_range;
339 }
340 }
341
342 ranges.push(current);
343 ranges
344 }
345
346 pub fn coalesce_stats(&self, candidates: &[RerankCandidate]) -> CoalesceStats {
348 let ranges = self.coalesce(candidates);
349
350 let total_raw_bytes: u64 = candidates.iter().map(|c| c.vector_size as u64).sum();
351
352 let total_coalesced_bytes: u64 = ranges.iter().map(|r| r.length).sum();
353
354 CoalesceStats {
355 n_candidates: candidates.len(),
356 n_ranges: ranges.len(),
357 raw_bytes: total_raw_bytes,
358 coalesced_bytes: total_coalesced_bytes,
359 reduction_ratio: total_coalesced_bytes as f32 / total_raw_bytes.max(1) as f32,
360 }
361 }
362}
363
/// Summary of a coalescing plan, as produced by `IoCoalescer::coalesce_stats`.
#[derive(Debug, Clone)]
pub struct CoalesceStats {
    /// Number of candidates that were planned.
    pub n_candidates: usize,
    /// Number of merged read ranges in the plan.
    pub n_ranges: usize,
    /// Sum of the candidates' individual vector sizes, in bytes.
    pub raw_bytes: u64,
    /// Sum of the merged ranges' lengths, in bytes.
    pub coalesced_bytes: u64,
    /// `coalesced_bytes` divided by `raw_bytes` (divisor floored at 1).
    pub reduction_ratio: f32,
}
373
374pub struct VectorCache {
380 cache: parking_lot::RwLock<std::collections::HashMap<u32, (Vec<f32>, u64)>>,
382 max_size: usize,
384 access_counter: AtomicU64,
386}
387
388impl VectorCache {
389 pub fn new(max_size: usize) -> Self {
391 Self {
392 cache: parking_lot::RwLock::new(std::collections::HashMap::with_capacity(max_size)),
393 max_size,
394 access_counter: AtomicU64::new(0),
395 }
396 }
397
398 pub fn get(&self, id: u32) -> Option<Vec<f32>> {
400 let mut cache = self.cache.write();
401 if let Some((vec, access)) = cache.get_mut(&id) {
402 *access = self.access_counter.fetch_add(1, Ordering::Relaxed);
403 Some(vec.clone())
404 } else {
405 None
406 }
407 }
408
409 pub fn insert(&self, id: u32, vector: Vec<f32>) {
411 let mut cache = self.cache.write();
412
413 if cache.len() >= self.max_size {
415 let lru_id = cache
417 .iter()
418 .min_by_key(|(_, (_, access))| *access)
419 .map(|(id, _)| *id);
420
421 if let Some(lru_id) = lru_id {
422 cache.remove(&lru_id);
423 }
424 }
425
426 let access = self.access_counter.fetch_add(1, Ordering::Relaxed);
427 cache.insert(id, (vector, access));
428 }
429
430 pub fn contains(&self, id: u32) -> bool {
432 self.cache.read().contains_key(&id)
433 }
434
435 pub fn len(&self) -> usize {
437 self.cache.read().len()
438 }
439
440 pub fn clear(&self) {
442 self.cache.write().clear();
443 }
444}
445
/// Scoring callback used by the executor.
///
/// It is invoked as `(query, candidate_vector)` and its return value becomes
/// the candidate's `true_score`.
pub type DistanceFn = dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync;

/// Storage access callback used by the executor.
///
/// It is invoked as `(offset, length)` and returns the bytes read. The executor
/// skips candidates whose bytes are missing from a shorter-than-requested
/// buffer, and skips the whole range when the call returns an error.
pub type StorageReader = dyn Fn(u64, u64) -> std::io::Result<Vec<u8>> + Send + Sync;
455
456pub struct RerankExecutor {
458 config: RerankConfig,
459 coalescer: IoCoalescer,
460 cache: Option<VectorCache>,
461 distance_fn: Box<DistanceFn>,
463 reader: Box<StorageReader>,
465 dim: usize,
467}
468
469impl RerankExecutor {
470 pub fn new<D, R>(config: RerankConfig, distance_fn: D, reader: R, dim: usize) -> Self
472 where
473 D: Fn(&[f32], &[f32]) -> f32 + Send + Sync + 'static,
474 R: Fn(u64, u64) -> std::io::Result<Vec<u8>> + Send + Sync + 'static,
475 {
476 let cache = if config.enable_cache {
477 Some(VectorCache::new(config.cache_size))
478 } else {
479 None
480 };
481
482 Self {
483 coalescer: IoCoalescer::new(config.coalesce_threshold),
484 cache,
485 config,
486 distance_fn: Box::new(distance_fn),
487 reader: Box::new(reader),
488 dim,
489 }
490 }
491
492 pub fn rerank(
494 &self,
495 candidates: &[RerankCandidate],
496 query: &[f32],
497 k: usize,
498 ) -> (Vec<RerankResult>, RerankStats) {
499 self.rerank_with_tracker(candidates, query, k, None)
500 }
501
502 pub fn rerank_with_tracker(
504 &self,
505 candidates: &[RerankCandidate],
506 query: &[f32],
507 k: usize,
508 cost_tracker: Option<&CostTracker>,
509 ) -> (Vec<RerankResult>, RerankStats) {
510 let start = Instant::now();
511 let mut stats = RerankStats {
512 candidates_requested: candidates.len(),
513 ..Default::default()
514 };
515
516 let (cached_ids, uncached): (Vec<_>, Vec<_>) =
518 candidates.iter().enumerate().partition(|(_, c)| {
519 self.cache
520 .as_ref()
521 .map(|cache| cache.contains(c.id))
522 .unwrap_or(false)
523 });
524
525 let mut results: BinaryHeap<RerankResult> = BinaryHeap::new();
527
528 for (_idx, candidate) in cached_ids {
529 if let Some(ref cache) = self.cache {
530 if let Some(vector) = cache.get(candidate.id) {
531 let score = (self.distance_fn)(query, &vector);
532 results.push(RerankResult {
533 id: candidate.id,
534 true_score: score,
535 from_cache: true,
536 });
537 stats.cache_hits += 1;
538 stats.candidates_reranked += 1;
539 }
540 }
541 }
542
543 let uncached_candidates: Vec<_> = uncached.iter().map(|(_, c)| (*c).clone()).collect();
545 let ranges = self.coalescer.coalesce(&uncached_candidates);
546 stats.coalesced_ranges = ranges.len();
547
548 let mut io_ops = 0u32;
550 let mut io_bytes = 0u64;
551
552 for range in &ranges {
554 if io_ops >= self.config.max_io_ops {
556 stats.budget_exhausted = true;
557 stats.stop_reason = "io_ops_exceeded".to_string();
558 break;
559 }
560
561 if io_bytes + range.length > self.config.max_io_bytes {
562 stats.budget_exhausted = true;
563 stats.stop_reason = "io_bytes_exceeded".to_string();
564 break;
565 }
566
567 if start.elapsed() > self.config.max_latency {
568 stats.budget_exhausted = true;
569 stats.stop_reason = "latency_exceeded".to_string();
570 break;
571 }
572
573 if let Some(tracker) = cost_tracker {
575 if !tracker.add_ssd_sequential_bytes(range.length) {
576 stats.budget_exhausted = true;
577 stats.stop_reason = "cost_budget_exhausted".to_string();
578 break;
579 }
580 }
581
582 let data = match (self.reader)(range.offset, range.length) {
584 Ok(data) => data,
585 Err(_) => continue, };
587
588 io_ops += 1;
589 io_bytes += range.length;
590
591 for &candidate_idx in &range.candidate_indices {
593 let candidate = &uncached_candidates[candidate_idx];
594
595 let offset_in_range = candidate.disk_offset - range.offset;
597 let start_byte = offset_in_range as usize;
598 let end_byte = start_byte + candidate.vector_size as usize;
599
600 if end_byte > data.len() {
601 continue; }
603
604 let vector_bytes = &data[start_byte..end_byte];
606 let vector: Vec<f32> = vector_bytes
607 .chunks(4)
608 .map(|chunk| {
609 let arr: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
610 f32::from_le_bytes(arr)
611 })
612 .collect();
613
614 let score = (self.distance_fn)(query, &vector);
616
617 results.push(RerankResult {
619 id: candidate.id,
620 true_score: score,
621 from_cache: false,
622 });
623
624 if let Some(ref cache) = self.cache {
626 cache.insert(candidate.id, vector);
627 }
628
629 stats.cache_misses += 1;
630 stats.candidates_reranked += 1;
631
632 if results.len() >= k * 2
634 && stats.candidates_reranked >= self.config.min_rerank_candidates
635 {
636 }
638 }
639 }
640
641 stats.io_ops = io_ops;
642 stats.io_bytes = io_bytes;
643 stats.duration = start.elapsed();
644
645 if stats.stop_reason.is_empty() {
646 stats.stop_reason = "complete".to_string();
647 }
648
649 let mut top_k: Vec<RerankResult> = Vec::with_capacity(k);
651 while top_k.len() < k && !results.is_empty() {
652 if let Some(result) = results.pop() {
653 top_k.push(result);
654 }
655 }
656
657 top_k.sort_by(|a, b| b.true_score.partial_cmp(&a.true_score).unwrap());
659
660 (top_k, stats)
661 }
662
663 pub fn config(&self) -> &RerankConfig {
665 &self.config
666 }
667
668 pub fn cache_stats(&self) -> Option<usize> {
670 self.cache.as_ref().map(|c| c.len())
671 }
672}
673
/// In-memory stand-in for vector storage, filled with deterministic values.
///
/// Vectors are laid out back to back as little-endian `f32`s, so vector `id`
/// of dimensionality `dim` starts at byte `id * dim * 4`.
pub struct MockStorage {
    /// Raw bytes of every stored vector, concatenated.
    data: Vec<u8>,
}

impl MockStorage {
    /// Builds storage holding `n_vectors` vectors of `dim` components each.
    ///
    /// Component `j` of vector `i` is `(i + j) / (n_vectors + dim)`.
    pub fn new(n_vectors: usize, dim: usize) -> Self {
        let mut data = Vec::with_capacity(n_vectors * dim * 4);

        for i in 0..n_vectors {
            for j in 0..dim {
                let val = (i + j) as f32 / (n_vectors + dim) as f32;
                data.extend_from_slice(&val.to_le_bytes());
            }
        }

        Self { data }
    }

    /// Returns a reader closure over this storage, borrowing it for as long
    /// as the closure lives.
    ///
    /// The closure takes `(offset, length)` and always returns `Ok`. The
    /// requested window is clamped to the stored data, so a read that runs
    /// past the end is truncated and a read that starts past the end yields an
    /// empty buffer instead of panicking.
    pub fn reader(&self) -> impl Fn(u64, u64) -> std::io::Result<Vec<u8>> + '_ {
        move |offset, length| {
            let available = self.data.len();
            // Clamp in u64 space first so the casts below cannot truncate.
            let start = offset.min(available as u64) as usize;
            let take = length.min((available - start) as u64) as usize;
            Ok(self.data[start..start + take].to_vec())
        }
    }

    /// Returns the byte offset of vector `id` for `dim`-component vectors.
    ///
    /// Computed in `u64` and saturating at `u64::MAX` on overflow.
    pub fn offset(&self, id: u32, dim: usize) -> u64 {
        (id as u64).saturating_mul(dim as u64).saturating_mul(4)
    }
}
712
#[cfg(test)]
mod tests {
    use super::*;

    /// Adjacent vectors merge into one read; a distant pair forms a second one.
    #[test]
    fn test_io_coalescing() {
        // (id, proxy score, disk offset); every vector is 3072 bytes long.
        let layout: Vec<(u32, f32, u64)> = vec![
            (0, 0.9, 0),
            (1, 0.8, 3072),
            (2, 0.7, 10000),
            (3, 0.6, 10500),
        ];
        let candidates: Vec<RerankCandidate> = layout
            .iter()
            .map(|&(id, score, offset)| RerankCandidate::new(id, score, offset, 3072))
            .collect();

        let ranges = IoCoalescer::new(1024).coalesce(&candidates);

        assert_eq!(ranges.len(), 2);
        assert_eq!(ranges[0].offset, 0);
        assert!(ranges[0].length >= 6144);
    }

    /// The entry that was not touched after filling the cache is evicted first.
    #[test]
    fn test_vector_cache() {
        let cache = VectorCache::new(3);

        cache.insert(1, vec![1.0, 2.0, 3.0]);
        cache.insert(2, vec![4.0, 5.0, 6.0]);
        cache.insert(3, vec![7.0, 8.0, 9.0]);

        for id in 1..=3u32 {
            assert!(cache.contains(id));
        }

        // Touch 1 and 2 so that 3 becomes the least recently used entry.
        let _ = cache.get(1);
        let _ = cache.get(2);

        // Inserting a fourth vector into a full cache forces an eviction.
        cache.insert(4, vec![10.0, 11.0, 12.0]);

        assert!(cache.contains(1));
        assert!(cache.contains(2));
        assert!(!cache.contains(3));
        assert!(cache.contains(4));
    }

    /// End-to-end pass over mock storage with the default budgets.
    #[test]
    fn test_rerank_executor() {
        let dim = 4;
        let storage = MockStorage::new(100, dim);

        let dot_product =
            |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() };

        // The executor needs a `'static` reader, so it gets its own copy of
        // the bytes rather than a closure borrowing `storage`.
        let bytes = storage.data.clone();
        let read_bytes = move |offset: u64, length: u64| -> std::io::Result<Vec<u8>> {
            let from = offset as usize;
            let to = (from + length as usize).min(bytes.len());
            Ok(bytes[from..to].to_vec())
        };

        let executor = RerankExecutor::new(RerankConfig::default(), dot_product, read_bytes, dim);

        let mut candidates: Vec<RerankCandidate> = Vec::new();
        for i in 0..10 {
            candidates.push(RerankCandidate::new(
                i,
                0.9 - i as f32 * 0.01,
                storage.offset(i, dim),
                (dim * 4) as u32,
            ));
        }

        let query = vec![1.0, 1.0, 1.0, 1.0];
        let (results, stats) = executor.rerank(&candidates, &query, 5);

        assert!(results.len() <= 5);
        assert!(stats.candidates_reranked > 0);
        assert!(stats.io_ops > 0);
    }

    /// Back-to-back 50-byte vectors coalesce into fewer reads than candidates.
    #[test]
    fn test_coalesce_stats() {
        let candidates: Vec<RerankCandidate> = (0..10)
            .map(|i| RerankCandidate::new(i, 0.9, i as u64 * 50, 50))
            .collect();

        let stats = IoCoalescer::new(100).coalesce_stats(&candidates);

        assert_eq!(stats.n_candidates, 10);
        assert!(stats.n_ranges < 10);
        assert!(stats.reduction_ratio >= 1.0);
    }
}