1use rand::Rng;
10use std::collections::VecDeque;
11use std::time::{SystemTime, UNIX_EPOCH};
12
/// A single stored query together with the ids supplied as its positives.
#[derive(Debug, Clone)]
pub struct ReplayEntry {
    /// Query vector, stored exactly as it was handed to [`ReplayEntry::new`].
    pub query: Vec<f32>,
    /// Ids passed alongside the query; their meaning is defined by the caller.
    pub positive_ids: Vec<usize>,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
}

impl ReplayEntry {
    /// Creates an entry stamped with the current system time.
    ///
    /// If the system clock reports a time before the Unix epoch the timestamp
    /// falls back to `0` instead of panicking.
    pub fn new(query: Vec<f32>, positive_ids: Vec<usize>) -> Self {
        let millis = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();

        // `as_millis` yields a `u128`. Clamp before narrowing so an out-of-range
        // value saturates at `u64::MAX` instead of silently wrapping.
        let timestamp = millis.min(u128::from(u64::MAX)) as u64;

        Self {
            query,
            positive_ids,
            timestamp,
        }
    }
}
39
/// Running per-dimension statistics, maintained incrementally with Welford's
/// online algorithm so that no samples need to be retained.
#[derive(Debug, Clone)]
pub struct DistributionStats {
    /// Running mean of each dimension.
    pub mean: Vec<f32>,
    /// Running sum of squared deviations from the mean (Welford's `M2`) for
    /// each dimension.
    ///
    /// Despite the field name this value is NOT divided by the sample count;
    /// call [`DistributionStats::std_dev`] for a normalized spread.
    pub variance: Vec<f32>,
    /// Number of samples folded into the statistics so far.
    pub count: usize,
}

impl DistributionStats {
    /// Creates empty statistics for vectors of length `dimension`.
    ///
    /// Passing `0` defers the choice: the dimensionality is adopted from the
    /// first non-empty sample given to [`DistributionStats::update`].
    pub fn new(dimension: usize) -> Self {
        Self {
            mean: vec![0.0; dimension],
            variance: vec![0.0; dimension],
            count: 0,
        }
    }

    /// Folds one sample into the running mean and squared-deviation sums.
    ///
    /// Two kinds of input are ignored without touching any state:
    /// * empty samples, and
    /// * samples whose length differs from the tracked dimensionality.
    pub fn update(&mut self, sample: &[f32]) {
        // An empty sample carries no information. Counting it would also
        // corrupt the mean once the dimensionality is adopted lazily below,
        // because the first real sample would be divided by a count > 1.
        if sample.is_empty() {
            return;
        }

        // Lazily adopt the dimensionality of the first real sample.
        if self.mean.is_empty() {
            self.mean = vec![0.0; sample.len()];
            self.variance = vec![0.0; sample.len()];
        }

        if self.mean.len() != sample.len() {
            return;
        }

        self.count += 1;
        let n = self.count as f32;

        for ((mean, m2), &x) in self
            .mean
            .iter_mut()
            .zip(self.variance.iter_mut())
            .zip(sample)
        {
            // Welford step: the deviation is measured once against the old
            // mean and once against the updated mean.
            let delta = x - *mean;
            *mean += delta / n;
            *m2 += delta * (x - *mean);
        }
    }

    /// Returns the sample standard deviation (Bessel-corrected, `n - 1`
    /// denominator) of each dimension.
    ///
    /// With fewer than two samples the spread is undefined and a vector of
    /// zeros is returned instead.
    pub fn std_dev(&self) -> Vec<f32> {
        if self.count <= 1 {
            return vec![0.0; self.variance.len()];
        }

        let denominator = (self.count - 1) as f32;
        self.variance
            .iter()
            .map(|&m2| (m2 / denominator).sqrt())
            .collect()
    }

    /// Clears all accumulated statistics while keeping the dimensionality.
    pub fn reset(&mut self) {
        let dim = self.mean.len();
        self.mean = vec![0.0; dim];
        self.variance = vec![0.0; dim];
        self.count = 0;
    }
}
103
104pub struct ReplayBuffer {
106 queries: VecDeque<ReplayEntry>,
108 capacity: usize,
110 total_seen: usize,
112 distribution_stats: DistributionStats,
114}
115
116impl ReplayBuffer {
117 pub fn new(capacity: usize) -> Self {
122 Self {
123 queries: VecDeque::with_capacity(capacity),
124 capacity,
125 total_seen: 0,
126 distribution_stats: DistributionStats::new(0),
127 }
128 }
129
130 pub fn add(&mut self, query: &[f32], positive_ids: &[usize]) {
139 let entry = ReplayEntry::new(query.to_vec(), positive_ids.to_vec());
140
141 self.total_seen += 1;
142
143 self.distribution_stats.update(query);
145
146 if self.queries.len() < self.capacity {
148 self.queries.push_back(entry);
149 return;
150 }
151
152 let mut rng = rand::thread_rng();
154 let random_index = rng.gen_range(0..self.total_seen);
155
156 if random_index < self.capacity {
157 self.queries[random_index] = entry;
158 }
159 }
160
161 pub fn sample(&self, batch_size: usize) -> Vec<&ReplayEntry> {
169 if self.queries.is_empty() {
170 return Vec::new();
171 }
172
173 let actual_batch_size = batch_size.min(self.queries.len());
174 let mut rng = rand::thread_rng();
175 let mut indices: Vec<usize> = (0..self.queries.len()).collect();
176
177 for i in 0..actual_batch_size {
179 let j = rng.gen_range(i..indices.len());
180 indices.swap(i, j);
181 }
182
183 indices[..actual_batch_size]
184 .iter()
185 .map(|&idx| &self.queries[idx])
186 .collect()
187 }
188
189 pub fn detect_distribution_shift(&self, recent_window: usize) -> f32 {
200 if self.queries.len() < recent_window || recent_window == 0 {
201 return 0.0;
202 }
203
204 let mut recent_stats = DistributionStats::new(self.distribution_stats.mean.len());
206
207 let start_idx = self.queries.len().saturating_sub(recent_window);
208 for entry in self.queries.iter().skip(start_idx) {
209 recent_stats.update(&entry.query);
210 }
211
212 let overall_mean = &self.distribution_stats.mean;
214 let recent_mean = &recent_stats.mean;
215
216 if overall_mean.is_empty() || recent_mean.is_empty() {
217 return 0.0;
218 }
219
220 let overall_std = self.distribution_stats.std_dev();
221 let mut shift_sum = 0.0;
222 let mut count = 0;
223
224 for i in 0..overall_mean.len() {
225 if overall_std[i] > 1e-8 {
226 let diff = (recent_mean[i] - overall_mean[i]).abs();
227 shift_sum += diff / overall_std[i];
228 count += 1;
229 }
230 }
231
232 if count > 0 {
233 shift_sum / count as f32
234 } else {
235 0.0
236 }
237 }
238
239 pub fn len(&self) -> usize {
241 self.queries.len()
242 }
243
244 pub fn is_empty(&self) -> bool {
246 self.queries.is_empty()
247 }
248
249 pub fn capacity(&self) -> usize {
251 self.capacity
252 }
253
254 pub fn total_seen(&self) -> usize {
256 self.total_seen
257 }
258
259 pub fn distribution_stats(&self) -> &DistributionStats {
261 &self.distribution_stats
262 }
263
264 pub fn clear(&mut self) {
266 self.queries.clear();
267 self.total_seen = 0;
268 self.distribution_stats.reset();
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh buffer is empty and grows as entries are added.
    #[test]
    fn test_replay_buffer_basic() {
        let mut buffer = ReplayBuffer::new(10);
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.capacity(), 10);

        buffer.add(&[1.0, 2.0, 3.0], &[0, 1]);
        assert_eq!(buffer.len(), 1);
        assert!(!buffer.is_empty());

        buffer.add(&[4.0, 5.0, 6.0], &[2, 3]);
        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.total_seen(), 2);
    }

    /// The buffer never grows past its capacity but keeps counting what it saw.
    #[test]
    fn test_replay_buffer_capacity() {
        let mut buffer = ReplayBuffer::new(3);

        // Fill to capacity.
        for i in 0..3 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 3);

        // Further additions go through reservoir replacement.
        for i in 3..10 {
            buffer.add(&[i as f32], &[i]);
        }
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.total_seen(), 10);
    }

    /// Sampling from an empty buffer yields nothing.
    #[test]
    fn test_sample_empty_buffer() {
        let buffer = ReplayBuffer::new(10);
        let samples = buffer.sample(5);
        assert!(samples.is_empty());
    }

    /// Sampling returns the requested number of stored entries.
    #[test]
    fn test_sample_basic() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }

        let samples = buffer.sample(3);
        assert_eq!(samples.len(), 3);

        // Every sampled entry must be one of the stored ones.
        for sample in samples {
            assert!(sample.query[0] >= 0.0 && sample.query[0] < 5.0);
        }
    }

    /// Asking for more than is stored returns everything that is stored.
    #[test]
    fn test_sample_larger_than_buffer() {
        let mut buffer = ReplayBuffer::new(10);

        buffer.add(&[1.0], &[0]);
        buffer.add(&[2.0], &[1]);

        let samples = buffer.sample(5);
        assert_eq!(samples.len(), 2);
    }

    /// The running mean follows each update.
    #[test]
    fn test_distribution_stats_update() {
        let mut stats = DistributionStats::new(2);

        stats.update(&[1.0, 2.0]);
        assert_eq!(stats.count, 1);
        assert_eq!(stats.mean, vec![1.0, 2.0]);

        stats.update(&[3.0, 4.0]);
        assert_eq!(stats.count, 2);
        assert_eq!(stats.mean, vec![2.0, 3.0]);

        stats.update(&[2.0, 3.0]);
        assert_eq!(stats.count, 3);
        assert_eq!(stats.mean, vec![2.0, 3.0]);
    }

    /// The sample standard deviation of 1, 3, 5 is 2.
    #[test]
    fn test_distribution_stats_std_dev() {
        let mut stats = DistributionStats::new(2);

        stats.update(&[1.0, 1.0]);
        stats.update(&[3.0, 3.0]);
        stats.update(&[5.0, 5.0]);

        let std_dev = stats.std_dev();
        assert!((std_dev[0] - 2.0).abs() < 0.01);
        assert!((std_dev[1] - 2.0).abs() < 0.01);
    }

    /// Identical queries produce no measurable shift.
    #[test]
    fn test_detect_distribution_shift_no_shift() {
        let mut buffer = ReplayBuffer::new(100);

        for _ in 0..50 {
            buffer.add(&[1.0, 2.0, 3.0], &[0]);
        }

        let shift = buffer.detect_distribution_shift(10);
        assert!(shift < 0.1);
    }

    /// A block of different recent queries is reported as a shift.
    #[test]
    fn test_detect_distribution_shift_with_shift() {
        let mut buffer = ReplayBuffer::new(100);

        // Baseline distribution.
        for _ in 0..40 {
            buffer.add(&[1.0, 2.0, 3.0], &[0]);
        }

        // Recent queries drawn from somewhere else.
        for _ in 0..10 {
            buffer.add(&[5.0, 6.0, 7.0], &[1]);
        }

        let shift = buffer.detect_distribution_shift(10);
        assert!(shift > 0.5);
    }

    /// Fewer stored entries than the window size yields a zero score.
    #[test]
    fn test_detect_distribution_shift_insufficient_data() {
        let mut buffer = ReplayBuffer::new(100);

        buffer.add(&[1.0, 2.0], &[0]);

        let shift = buffer.detect_distribution_shift(10);
        assert_eq!(shift, 0.0);
    }

    /// `clear` empties the buffer and resets counters and statistics.
    #[test]
    fn test_clear() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }

        assert_eq!(buffer.len(), 5);
        assert_eq!(buffer.total_seen(), 5);

        buffer.clear();
        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.total_seen(), 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.distribution_stats().count, 0);
    }

    /// A new entry keeps its inputs and receives a timestamp.
    #[test]
    fn test_replay_entry_creation() {
        let entry = ReplayEntry::new(vec![1.0, 2.0, 3.0], vec![0, 1, 2]);

        assert_eq!(entry.query, vec![1.0, 2.0, 3.0]);
        assert_eq!(entry.positive_ids, vec![0, 1, 2]);
        assert!(entry.timestamp > 0);
    }

    /// After overflowing the buffer it still holds `capacity` distinct entries.
    #[test]
    fn test_reservoir_sampling_distribution() {
        let mut buffer = ReplayBuffer::new(10);

        for i in 0..100 {
            buffer.add(&[i as f32], &[i]);
        }

        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer.total_seen(), 100);

        let samples1 = buffer.sample(5);
        let samples2 = buffer.sample(5);

        assert_eq!(samples1.len(), 5);
        assert_eq!(samples2.len(), 5);

        // Sampling the whole buffer must surface more than one distinct value.
        let sample_batch = buffer.sample(10);
        let values: Vec<f32> = sample_batch.iter().map(|e| e.query[0]).collect();

        let unique_values: std::collections::HashSet<_> =
            values.iter().map(|&v| v as i32).collect();
        assert!(unique_values.len() > 1);
    }

    /// A query of a different dimensionality is stored but does not disturb
    /// the running statistics.
    #[test]
    fn test_dimension_mismatch_handling() {
        let mut buffer = ReplayBuffer::new(10);

        buffer.add(&[1.0, 2.0], &[0]);
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer.distribution_stats().mean.len(), 2);

        // Mismatched dimensionality: the entry is kept, the statistics are not updated.
        buffer.add(&[1.0, 2.0, 3.0], &[1]);
        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.total_seen(), 2);
        assert_eq!(buffer.distribution_stats().count, 1);
        assert_eq!(buffer.distribution_stats().mean, vec![1.0, 2.0]);
    }

    /// Sampling the full buffer returns every entry exactly once.
    #[test]
    fn test_sample_uniqueness() {
        let mut buffer = ReplayBuffer::new(5);

        for i in 0..5 {
            buffer.add(&[i as f32], &[i]);
        }

        let samples = buffer.sample(5);
        let values: Vec<f32> = samples.iter().map(|e| e.query[0]).collect();

        let unique_values: std::collections::HashSet<_> =
            values.iter().map(|&v| v as i32).collect();
        assert_eq!(unique_values.len(), 5);
    }
}