1use serde::{Deserialize, Serialize};
55use std::collections::HashMap;
56
/// One calibration observation: a proxy score, the matching true score, and
/// the signed difference between them.
#[derive(Debug, Clone, Copy)]
pub struct ErrorSample {
    /// Proxy score for the observation.
    pub proxy: f32,
    /// True score for the same observation.
    pub true_score: f32,
    /// Signed error; [`ErrorSample::new`] sets it to `proxy - true_score`.
    pub error: f32,
}

impl ErrorSample {
    /// Builds a sample from the two scores and derives `error` as
    /// `proxy - true_score`, so a positive error means the proxy was higher.
    pub fn new(proxy: f32, true_score: f32) -> Self {
        let error = proxy - true_score;
        Self {
            proxy,
            true_score,
            error,
        }
    }
}
82
/// Summary statistics over the signed errors recorded for one list (or for all
/// lists combined).
///
/// NOTE(review): this type is serialized through serde; field order is kept
/// as-is because it presumably affects the encoded layout — confirm before
/// reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorEnvelope {
    /// Index of the list this envelope describes. The calibrator writes
    /// `u32::MAX` here for the envelope computed over all lists.
    pub list_idx: u32,

    /// Tabulated error quantiles, keyed by `round(quantile * 10000)`
    /// (for example `9500` for the 0.95 quantile).
    pub quantiles: HashMap<u32, f32>,
    /// Arithmetic mean of the errors.
    pub mean_error: f32,

    /// Standard deviation of the errors (the calibrator divides by `n`, not `n - 1`).
    pub std_error: f32,

    /// Largest error observed.
    pub max_error: f32,

    /// Smallest error observed.
    pub min_error: f32,

    /// Number of samples the statistics were computed from; `0` means empty.
    pub sample_count: u32,
}
113
114impl ErrorEnvelope {
115 pub fn error_at_quantile(&self, quantile: f32) -> f32 {
119 let key = (quantile * 10000.0).round() as u32;
120
121 if let Some(&error) = self.quantiles.get(&key) {
123 return error;
124 }
125
126 let mut below_key = 0u32;
128 let mut above_key = 10000u32;
129 let mut below_val = self.min_error;
130 let mut above_val = self.max_error;
131
132 for (&k, &v) in &self.quantiles {
133 if k < key && k > below_key {
134 below_key = k;
135 below_val = v;
136 }
137 if k > key && k < above_key {
138 above_key = k;
139 above_val = v;
140 }
141 }
142
143 if above_key > below_key {
145 let t = (key - below_key) as f32 / (above_key - below_key) as f32;
146 below_val + t * (above_val - below_val)
147 } else {
148 self.max_error
149 }
150 }
151
152 pub fn safe_true_threshold(&self, proxy: f32, confidence: f32) -> f32 {
159 let error_bound = self.error_at_quantile(confidence);
160 proxy - error_bound
161 }
162
163 pub fn safe_proxy_threshold(&self, true_threshold: f32, confidence: f32) -> f32 {
168 let error_bound = self.error_at_quantile(confidence);
169 true_threshold + error_bound
170 }
171
172 pub fn definitely_beats(&self, proxy: f32, true_threshold: f32) -> bool {
174 proxy - self.max_error > true_threshold
176 }
177
178 pub fn might_beat(&self, proxy: f32, true_threshold: f32, confidence: f32) -> bool {
180 let error_bound = self.error_at_quantile(confidence);
181 proxy - error_bound > true_threshold
182 }
183}
184
185impl Default for ErrorEnvelope {
186 fn default() -> Self {
187 Self {
188 list_idx: 0,
189 quantiles: HashMap::new(),
190 mean_error: 0.0,
191 std_error: 0.0,
192 max_error: 0.0,
193 min_error: 0.0,
194 sample_count: 0,
195 }
196 }
197}
198
/// Accumulates [`ErrorSample`]s per list and turns them into error envelopes.
pub struct ErrorCalibrator {
    /// Recorded samples, one vector per list, indexed by list index.
    samples: Vec<Vec<ErrorSample>>,
    /// Number of lists; recordings for indices at or above this are ignored.
    n_lists: usize,
    /// Quantile levels that are tabulated in each envelope.
    quantiles: Vec<f32>,
}
212
213impl ErrorCalibrator {
214 pub fn new(n_lists: usize) -> Self {
216 Self {
217 samples: vec![Vec::new(); n_lists],
218 n_lists,
219 quantiles: vec![0.50, 0.75, 0.90, 0.95, 0.99, 0.999],
220 }
221 }
222
223 pub fn with_quantiles(n_lists: usize, quantiles: Vec<f32>) -> Self {
225 Self {
226 samples: vec![Vec::new(); n_lists],
227 n_lists,
228 quantiles,
229 }
230 }
231
232 pub fn record_error(&mut self, list_idx: usize, proxy: f32, true_score: f32) {
234 if list_idx < self.n_lists {
235 self.samples[list_idx].push(ErrorSample::new(proxy, true_score));
236 }
237 }
238
239 pub fn record_errors(&mut self, list_idx: usize, samples: &[(f32, f32)]) {
241 if list_idx < self.n_lists {
242 for &(proxy, true_score) in samples {
243 self.samples[list_idx].push(ErrorSample::new(proxy, true_score));
244 }
245 }
246 }
247
248 pub fn finalize(&self) -> ErrorEnvelopeSet {
250 let envelopes: Vec<ErrorEnvelope> = (0..self.n_lists)
251 .map(|i| self.compute_envelope(i))
252 .collect();
253
254 let global = self.compute_global_envelope();
256
257 ErrorEnvelopeSet { envelopes, global }
258 }
259
260 fn compute_envelope(&self, list_idx: usize) -> ErrorEnvelope {
262 let samples = &self.samples[list_idx];
263
264 if samples.is_empty() {
265 return ErrorEnvelope {
266 list_idx: list_idx as u32,
267 ..Default::default()
268 };
269 }
270
271 let mut errors: Vec<f32> = samples.iter().map(|s| s.error).collect();
273 errors.sort_by(|a, b| a.partial_cmp(b).unwrap());
274
275 let n = errors.len();
276
277 let sum: f32 = errors.iter().sum();
279 let mean = sum / n as f32;
280 let variance: f32 = errors.iter().map(|&e| (e - mean).powi(2)).sum::<f32>() / n as f32;
281 let std = variance.sqrt();
282
283 let mut quantiles = HashMap::new();
285 for &q in &self.quantiles {
286 let idx = ((n as f32 * q) as usize).min(n - 1);
287 let key = (q * 10000.0).round() as u32;
288 quantiles.insert(key, errors[idx]);
289 }
290
291 ErrorEnvelope {
292 list_idx: list_idx as u32,
293 quantiles,
294 mean_error: mean,
295 std_error: std,
296 max_error: errors[n - 1],
297 min_error: errors[0],
298 sample_count: n as u32,
299 }
300 }
301
302 fn compute_global_envelope(&self) -> ErrorEnvelope {
304 let mut all_errors: Vec<f32> = self
305 .samples
306 .iter()
307 .flat_map(|s| s.iter().map(|e| e.error))
308 .collect();
309
310 if all_errors.is_empty() {
311 return ErrorEnvelope::default();
312 }
313
314 all_errors.sort_by(|a, b| a.partial_cmp(b).unwrap());
315 let n = all_errors.len();
316
317 let sum: f32 = all_errors.iter().sum();
318 let mean = sum / n as f32;
319 let variance: f32 = all_errors.iter().map(|&e| (e - mean).powi(2)).sum::<f32>() / n as f32;
320 let std = variance.sqrt();
321
322 let mut quantiles = HashMap::new();
323 for &q in &self.quantiles {
324 let idx = ((n as f32 * q) as usize).min(n - 1);
325 let key = (q * 10000.0).round() as u32;
326 quantiles.insert(key, all_errors[idx]);
327 }
328
329 ErrorEnvelope {
330 list_idx: u32::MAX, quantiles,
332 mean_error: mean,
333 std_error: std,
334 max_error: all_errors[n - 1],
335 min_error: all_errors[0],
336 sample_count: n as u32,
337 }
338 }
339}
340
/// Per-list error envelopes together with one envelope over all samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorEnvelopeSet {
    /// One envelope per list, indexed by list index.
    pub envelopes: Vec<ErrorEnvelope>,
    /// Envelope computed over the samples of every list; used as the fallback
    /// when a list has no envelope or no samples.
    pub global: ErrorEnvelope,
}
353
354impl ErrorEnvelopeSet {
355 pub fn get(&self, list_idx: usize) -> &ErrorEnvelope {
357 if list_idx < self.envelopes.len() && self.envelopes[list_idx].sample_count > 0 {
358 &self.envelopes[list_idx]
359 } else {
360 &self.global
361 }
362 }
363
364 pub fn safe_true_threshold(&self, list_idx: usize, proxy: f32, confidence: f32) -> f32 {
366 self.get(list_idx).safe_true_threshold(proxy, confidence)
367 }
368
369 pub fn can_terminate(
371 &self,
372 kth_proxy: f32,
373 remaining_list_bounds: &[(usize, f32)],
374 confidence: f32,
375 ) -> bool {
376 let kth_true_lower = self.global.safe_true_threshold(kth_proxy, confidence);
378
379 remaining_list_bounds.iter().all(|(list_idx, bound)| {
381 let envelope = self.get(*list_idx);
383 let true_upper = *bound + envelope.max_error.abs();
386 true_upper < kth_true_lower
387 })
388 }
389
390 pub fn to_bytes(&self) -> Vec<u8> {
392 bincode::serialize(self).unwrap_or_default()
393 }
394
395 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
397 bincode::deserialize(bytes).ok()
398 }
399}
400
/// Drives calibration over query/list data using optional, pluggable callbacks.
pub struct CalibrationRunner {
    /// Number of lists the produced calibrator is sized for.
    n_lists: usize,
    /// Optional callback mapping a float vector to a byte code.
    quantize_fn: Option<Box<dyn Fn(&[f32]) -> Vec<u8> + Send + Sync>>,
    /// Optional callback scoring a float query against a byte code.
    proxy_distance_fn: Option<Box<dyn Fn(&[f32], &[u8]) -> f32 + Send + Sync>>,
    /// Optional callback scoring a float query against a float vector.
    true_distance_fn: Option<Box<dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync>>,
}
416
417impl CalibrationRunner {
418 pub fn new(n_lists: usize) -> Self {
420 Self {
421 n_lists,
422 quantize_fn: None,
423 proxy_distance_fn: None,
424 true_distance_fn: None,
425 }
426 }
427
428 pub fn calibrate(
433 &self,
434 queries: &[Vec<f32>],
435 lists: &[Vec<Vec<f32>>],
436 quantized_lists: &[Vec<Vec<u8>>],
437 ) -> ErrorEnvelopeSet {
438 let mut calibrator = ErrorCalibrator::new(self.n_lists);
439
440 for query in queries {
441 for (list_idx, (vectors, codes)) in lists.iter().zip(quantized_lists.iter()).enumerate()
442 {
443 for (vec, code) in vectors.iter().zip(codes.iter()) {
444 let true_score = dot_product(query, vec);
446 let proxy_score = if let Some(ref f) = self.proxy_distance_fn {
447 f(query, code)
448 } else {
449 true_score };
451
452 calibrator.record_error(list_idx, proxy_score, true_score);
453 }
454 }
455 }
456
457 calibrator.finalize()
458 }
459
460 pub fn calibrate_synthetic(
464 n_lists: usize,
465 mean_error: f32,
466 std_error: f32,
467 samples_per_list: usize,
468 ) -> ErrorEnvelopeSet {
469 let mut calibrator = ErrorCalibrator::new(n_lists);
470
471 let mut rng_state: u64 = 12345;
473 let mut rand = || {
474 rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
475 (rng_state >> 33) as f32 / (1u64 << 31) as f32
476 };
477
478 for list_idx in 0..n_lists {
479 for _ in 0..samples_per_list {
480 let u1 = rand();
482 let u2 = rand();
483 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
484 let error = mean_error + std_error * z;
485
486 let true_score = 0.5 + rand() * 0.5; let proxy_score = true_score + error;
488
489 calibrator.record_error(list_idx, proxy_score, true_score);
490 }
491 }
492
493 calibrator.finalize()
494 }
495}
496
/// Sums the element-wise products of `a` and `b`.
///
/// Elements are paired by position; if the slices differ in length the extra
/// elements of the longer one are ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// The stored error is `proxy - true_score`.
    #[test]
    fn test_error_sample() {
        let ErrorSample { error, .. } = ErrorSample::new(0.92, 0.90);
        let deviation = (error - 0.02).abs();
        assert!(deviation < 1e-6);
    }

    /// Five positive-error samples on list 0 produce a populated envelope.
    #[test]
    fn test_calibrator() {
        let observations: [(f32, f32); 5] = [
            (0.90, 0.88),
            (0.85, 0.82),
            (0.92, 0.91),
            (0.88, 0.85),
            (0.95, 0.90),
        ];

        let mut calibrator = ErrorCalibrator::new(3);
        for &(proxy, true_score) in &observations {
            calibrator.record_error(0, proxy, true_score);
        }

        let set = calibrator.finalize();
        let first = &set.envelopes[0];

        assert!(first.sample_count == 5);
        assert!(first.mean_error > 0.0);
        assert!(first.max_error > first.mean_error);
    }

    /// Tabulated quantiles are subtracted from the proxy to get the threshold.
    #[test]
    fn test_envelope_threshold() {
        let mut table = HashMap::new();
        table.insert(9500, 0.05);
        table.insert(9900, 0.08);

        let envelope = ErrorEnvelope {
            list_idx: 0,
            quantiles: table,
            mean_error: 0.03,
            std_error: 0.02,
            max_error: 0.10,
            min_error: 0.00,
            sample_count: 100,
        };

        // 0.90 - 0.05
        let at_95 = envelope.safe_true_threshold(0.90, 0.95);
        assert!((at_95 - 0.85).abs() < 0.01);

        // 0.90 - 0.08
        let at_99 = envelope.safe_true_threshold(0.90, 0.99);
        assert!((at_99 - 0.82).abs() < 0.01);
    }

    /// Termination is allowed with distant bounds and refused with close ones.
    #[test]
    fn test_can_terminate() {
        let envelopes = CalibrationRunner::calibrate_synthetic(5, 0.03, 0.01, 100);
        let kth_proxy = 0.95;

        let far_bounds = vec![(1, 0.70), (2, 0.65)];
        let terminates = envelopes.can_terminate(kth_proxy, &far_bounds, 0.99);
        assert!(
            terminates,
            "Should be able to terminate with high kth and low bounds"
        );

        let close_bounds = vec![(1, 0.94), (2, 0.93)];
        let terminates_close = envelopes.can_terminate(kth_proxy, &close_bounds, 0.99);
        assert!(!terminates_close, "Should not terminate with close bounds");
    }

    /// Synthetic calibration yields one envelope per list and a global mean
    /// close to the requested mean error.
    #[test]
    fn test_synthetic_calibration() {
        let envelopes = CalibrationRunner::calibrate_synthetic(10, 0.02, 0.01, 500);

        assert_eq!(envelopes.envelopes.len(), 10);
        assert!(envelopes.global.sample_count > 0);

        let mean_gap = (envelopes.global.mean_error - 0.02).abs();
        assert!(mean_gap < 0.01);
    }
}