// ruvector_gnn/replay.rs

//! Experience Replay Buffer for GNN Training
//!
//! This module implements an experience replay buffer to mitigate catastrophic forgetting
//! during continual learning. The buffer stores past training samples and supports:
//! - Reservoir sampling, so the retained samples are a uniform draw over everything seen
//! - Batch sampling for training
//! - Distribution shift detection

9use rand::Rng;
10use std::collections::VecDeque;
11use std::time::{SystemTime, UNIX_EPOCH};
12
/// One stored training example: a query vector plus the node IDs labelled
/// positive for it.
#[derive(Debug, Clone)]
pub struct ReplayEntry {
    /// Query vector used for training.
    pub query: Vec<f32>,
    /// IDs of the positive nodes associated with `query`.
    pub positive_ids: Vec<usize>,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
}

impl ReplayEntry {
    /// Builds an entry stamped with the current wall-clock time.
    ///
    /// If the system clock reports a time before the Unix epoch, the
    /// timestamp falls back to 0.
    pub fn new(query: Vec<f32>, positive_ids: Vec<usize>) -> Self {
        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        let timestamp = since_epoch.as_millis() as u64;

        ReplayEntry {
            query,
            positive_ids,
            timestamp,
        }
    }
}

/// Running per-dimension statistics of query vectors, maintained with
/// Welford's online algorithm.
///
/// The dimension is either fixed up front via [`DistributionStats::new`] or,
/// when created with dimension 0, adopted from the first non-empty sample.
#[derive(Debug, Clone)]
pub struct DistributionStats {
    /// Running mean of query vectors, one entry per dimension.
    pub mean: Vec<f32>,
    /// Running sum of squared deviations from the mean (Welford's `M2`), one
    /// entry per dimension.
    ///
    /// Despite the field name this is NOT normalised by the sample count:
    /// divide by `count - 1` for the sample variance, or use
    /// [`DistributionStats::std_dev`].
    pub variance: Vec<f32>,
    /// Number of samples folded into `mean` and `variance`.
    pub count: usize,
}

impl DistributionStats {
    /// Creates empty statistics for vectors of length `dimension`.
    ///
    /// Pass 0 to let the first non-empty sample determine the dimension.
    pub fn new(dimension: usize) -> Self {
        Self {
            mean: vec![0.0; dimension],
            variance: vec![0.0; dimension],
            count: 0,
        }
    }

    /// Folds one sample into the statistics using Welford's online algorithm.
    ///
    /// If no dimension has been established yet (empty `mean`), the first
    /// non-empty sample sets it. Afterwards, samples whose length differs from
    /// the established dimension are ignored and not counted.
    pub fn update(&mut self, sample: &[f32]) {
        if self.mean.is_empty() && !sample.is_empty() {
            // Adopt the dimension of the first real sample. Anything counted so
            // far was zero-length and contributed no data, so the count must
            // restart too; otherwise this sample's mean would be divided by a
            // stale count.
            *self = Self::new(sample.len());
        }

        if self.mean.len() != sample.len() {
            return; // Dimension mismatch, skip update
        }

        self.count += 1;
        let samples_seen = self.count as f32;

        let accumulators = self.mean.iter_mut().zip(self.variance.iter_mut());
        for (&x, (mean, m2)) in sample.iter().zip(accumulators) {
            let delta = x - *mean;
            *mean += delta / samples_seen;
            // Welford's M2 increment uses the deviation from the *updated* mean.
            *m2 += delta * (x - *mean);
        }
    }

    /// Returns the per-dimension sample standard deviation, i.e.
    /// `sqrt(variance / (count - 1))` (Bessel-corrected).
    ///
    /// With fewer than two samples every entry is 0.0.
    pub fn std_dev(&self) -> Vec<f32> {
        if self.count <= 1 {
            return vec![0.0; self.variance.len()];
        }

        let denominator = (self.count - 1) as f32;
        self.variance
            .iter()
            .map(|&m2| (m2 / denominator).sqrt())
            .collect()
    }

    /// Clears all accumulated statistics while keeping the current dimension
    /// (the length of `mean`).
    pub fn reset(&mut self) {
        *self = Self::new(self.mean.len());
    }
}

104/// Experience Replay Buffer for storing and sampling past training examples
105pub struct ReplayBuffer {
106    /// Circular buffer of replay entries
107    queries: VecDeque<ReplayEntry>,
108    /// Maximum capacity of the buffer
109    capacity: usize,
110    /// Total number of samples seen (including evicted ones)
111    total_seen: usize,
112    /// Statistics of the overall distribution
113    distribution_stats: DistributionStats,
114}
115
116impl ReplayBuffer {
117    /// Create a new replay buffer with specified capacity
118    ///
119    /// # Arguments
120    /// * `capacity` - Maximum number of entries to store
121    pub fn new(capacity: usize) -> Self {
122        Self {
123            queries: VecDeque::with_capacity(capacity),
124            capacity,
125            total_seen: 0,
126            distribution_stats: DistributionStats::new(0),
127        }
128    }
129
130    /// Add a new entry to the buffer using reservoir sampling
131    ///
132    /// Reservoir sampling ensures uniform distribution over all samples seen,
133    /// even as old samples are evicted due to capacity constraints.
134    ///
135    /// # Arguments
136    /// * `query` - Query vector
137    /// * `positive_ids` - IDs of positive nodes for this query
138    pub fn add(&mut self, query: &[f32], positive_ids: &[usize]) {
139        let entry = ReplayEntry::new(query.to_vec(), positive_ids.to_vec());
140
141        self.total_seen += 1;
142
143        // Update distribution statistics
144        self.distribution_stats.update(query);
145
146        // If buffer is not full, just add the entry
147        if self.queries.len() < self.capacity {
148            self.queries.push_back(entry);
149            return;
150        }
151
152        // Reservoir sampling: replace a random entry with probability capacity/total_seen
153        let mut rng = rand::thread_rng();
154        let random_index = rng.gen_range(0..self.total_seen);
155
156        if random_index < self.capacity {
157            self.queries[random_index] = entry;
158        }
159    }
160
161    /// Sample a batch of entries uniformly at random
162    ///
163    /// # Arguments
164    /// * `batch_size` - Number of entries to sample
165    ///
166    /// # Returns
167    /// Vector of references to sampled entries (may be smaller than batch_size if buffer is small)
168    pub fn sample(&self, batch_size: usize) -> Vec<&ReplayEntry> {
169        if self.queries.is_empty() {
170            return Vec::new();
171        }
172
173        let actual_batch_size = batch_size.min(self.queries.len());
174        let mut rng = rand::thread_rng();
175        let mut indices: Vec<usize> = (0..self.queries.len()).collect();
176
177        // Fisher-Yates shuffle for first batch_size elements
178        for i in 0..actual_batch_size {
179            let j = rng.gen_range(i..indices.len());
180            indices.swap(i, j);
181        }
182
183        indices[..actual_batch_size]
184            .iter()
185            .map(|&idx| &self.queries[idx])
186            .collect()
187    }
188
189    /// Detect distribution shift between recent samples and overall distribution
190    ///
191    /// Uses Kullback-Leibler divergence approximation based on mean and variance changes.
192    ///
193    /// # Arguments
194    /// * `recent_window` - Number of most recent samples to compare
195    ///
196    /// # Returns
197    /// Shift score (higher values indicate more significant distribution shift)
198    /// Returns 0.0 if insufficient data
199    pub fn detect_distribution_shift(&self, recent_window: usize) -> f32 {
200        if self.queries.len() < recent_window || recent_window == 0 {
201            return 0.0;
202        }
203
204        // Compute statistics for recent window
205        let mut recent_stats = DistributionStats::new(self.distribution_stats.mean.len());
206
207        let start_idx = self.queries.len().saturating_sub(recent_window);
208        for entry in self.queries.iter().skip(start_idx) {
209            recent_stats.update(&entry.query);
210        }
211
212        // Compute shift using normalized mean difference
213        let overall_mean = &self.distribution_stats.mean;
214        let recent_mean = &recent_stats.mean;
215
216        if overall_mean.is_empty() || recent_mean.is_empty() {
217            return 0.0;
218        }
219
220        let overall_std = self.distribution_stats.std_dev();
221        let mut shift_sum = 0.0;
222        let mut count = 0;
223
224        for i in 0..overall_mean.len() {
225            if overall_std[i] > 1e-8 {
226                let diff = (recent_mean[i] - overall_mean[i]).abs();
227                shift_sum += diff / overall_std[i];
228                count += 1;
229            }
230        }
231
232        if count > 0 {
233            shift_sum / count as f32
234        } else {
235            0.0
236        }
237    }
238
239    /// Get the number of entries currently in the buffer
240    pub fn len(&self) -> usize {
241        self.queries.len()
242    }
243
244    /// Check if the buffer is empty
245    pub fn is_empty(&self) -> bool {
246        self.queries.is_empty()
247    }
248
249    /// Get the total capacity of the buffer
250    pub fn capacity(&self) -> usize {
251        self.capacity
252    }
253
254    /// Get the total number of samples seen (including evicted ones)
255    pub fn total_seen(&self) -> usize {
256        self.total_seen
257    }
258
259    /// Get a reference to the distribution statistics
260    pub fn distribution_stats(&self) -> &DistributionStats {
261        &self.distribution_stats
262    }
263
264    /// Clear all entries from the buffer
265    pub fn clear(&mut self) {
266        self.queries.clear();
267        self.total_seen = 0;
268        self.distribution_stats.reset();
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a buffer of the given capacity and feeds it one-dimensional
    /// queries `0..count`, using each index as its own positive ID.
    fn filled_with_indices(capacity: usize, count: usize) -> ReplayBuffer {
        let mut buf = ReplayBuffer::new(capacity);
        (0..count).for_each(|k| buf.add(&[k as f32], &[k]));
        buf
    }

    /// Counts distinct integer-truncated first components across `entries`.
    fn distinct_heads(entries: &[&ReplayEntry]) -> usize {
        entries
            .iter()
            .map(|e| e.query[0] as i32)
            .collect::<HashSet<_>>()
            .len()
    }

    #[test]
    fn test_replay_buffer_basic() {
        let mut buf = ReplayBuffer::new(10);
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
        assert_eq!(buf.capacity(), 10);

        buf.add(&[1.0, 2.0, 3.0], &[0, 1]);
        assert!(!buf.is_empty());
        assert_eq!(buf.len(), 1);

        buf.add(&[4.0, 5.0, 6.0], &[2, 3]);
        assert_eq!(buf.total_seen(), 2);
        assert_eq!(buf.len(), 2);
    }

    #[test]
    fn test_replay_buffer_capacity() {
        let mut buf = filled_with_indices(3, 3);
        assert_eq!(buf.len(), 3);

        // Overflowing the capacity keeps the size fixed via reservoir sampling.
        (3..10).for_each(|k| buf.add(&[k as f32], &[k]));
        assert_eq!(buf.len(), 3);
        assert_eq!(buf.total_seen(), 10);
    }

    #[test]
    fn test_sample_empty_buffer() {
        assert!(ReplayBuffer::new(10).sample(5).is_empty());
    }

    #[test]
    fn test_sample_basic() {
        let buf = filled_with_indices(10, 5);

        let picked = buf.sample(3);
        assert_eq!(picked.len(), 3);

        // Every sampled entry must be one of the five that were added.
        assert!(picked.iter().all(|e| (0.0..5.0).contains(&e.query[0])));
    }

    #[test]
    fn test_sample_larger_than_buffer() {
        let mut buf = ReplayBuffer::new(10);
        buf.add(&[1.0], &[0]);
        buf.add(&[2.0], &[1]);

        // Asking for more than is stored returns everything available.
        assert_eq!(buf.sample(5).len(), 2);
    }

    #[test]
    fn test_distribution_stats_update() {
        let mut stats = DistributionStats::new(2);
        let steps: [([f32; 2], usize, [f32; 2]); 3] = [
            ([1.0, 2.0], 1, [1.0, 2.0]),
            ([3.0, 4.0], 2, [2.0, 3.0]),
            ([2.0, 3.0], 3, [2.0, 3.0]),
        ];

        for (sample, expected_count, expected_mean) in steps.iter() {
            stats.update(sample);
            assert_eq!(stats.count, *expected_count);
            assert_eq!(stats.mean, expected_mean.to_vec());
        }
    }

    #[test]
    fn test_distribution_stats_std_dev() {
        let mut stats = DistributionStats::new(2);
        for v in [1.0f32, 3.0, 5.0].iter() {
            stats.update(&[*v, *v]);
        }

        // Sample standard deviation of {1, 3, 5} is 2.
        let deviations = stats.std_dev();
        assert_eq!(deviations.len(), 2);
        for sd in deviations {
            assert!((sd - 2.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_detect_distribution_shift_no_shift() {
        let mut buf = ReplayBuffer::new(100);
        for _ in 0..50 {
            buf.add(&[1.0, 2.0, 3.0], &[0]);
        }

        assert!(
            buf.detect_distribution_shift(10) < 0.1,
            "identical samples should show no shift"
        );
    }

    #[test]
    fn test_detect_distribution_shift_with_shift() {
        let mut buf = ReplayBuffer::new(100);
        let phases: [(&[f32], usize, usize); 2] =
            [(&[1.0, 2.0, 3.0], 0, 40), (&[5.0, 6.0, 7.0], 1, 10)];

        for &(query, label, repeats) in phases.iter() {
            for _ in 0..repeats {
                buf.add(query, &[label]);
            }
        }

        assert!(
            buf.detect_distribution_shift(10) > 0.5,
            "a block of shifted samples should be detected"
        );
    }

    #[test]
    fn test_detect_distribution_shift_insufficient_data() {
        let mut buf = ReplayBuffer::new(100);
        buf.add(&[1.0, 2.0], &[0]);

        // One stored entry cannot fill a window of ten.
        assert_eq!(buf.detect_distribution_shift(10), 0.0);
    }

    #[test]
    fn test_clear() {
        let mut buf = filled_with_indices(10, 5);
        assert_eq!(buf.len(), 5);
        assert_eq!(buf.total_seen(), 5);

        buf.clear();
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
        assert_eq!(buf.total_seen(), 0);
        assert_eq!(buf.distribution_stats().count, 0);
    }

    #[test]
    fn test_replay_entry_creation() {
        let ReplayEntry {
            query,
            positive_ids,
            timestamp,
        } = ReplayEntry::new(vec![1.0, 2.0, 3.0], vec![0, 1, 2]);

        assert_eq!(query, vec![1.0, 2.0, 3.0]);
        assert_eq!(positive_ids, vec![0, 1, 2]);
        assert!(timestamp > 0);
    }

    #[test]
    fn test_reservoir_sampling_distribution() {
        // Far more entries than capacity.
        let buf = filled_with_indices(10, 100);
        assert_eq!(buf.len(), 10);
        assert_eq!(buf.total_seen(), 100);

        // Repeated draws each return a full batch.
        for _ in 0..2 {
            assert_eq!(buf.sample(5).len(), 5);
        }

        // A draw of the whole reservoir should contain more than one value.
        assert!(distinct_heads(&buf.sample(10)) > 1);
    }

    #[test]
    fn test_dimension_mismatch_handling() {
        let mut buf = ReplayBuffer::new(10);
        buf.add(&[1.0, 2.0], &[0]);

        // Statistics adopt the dimension of the first query without panicking.
        assert_eq!(buf.len(), 1);
        assert_eq!(buf.distribution_stats().mean.len(), 2);
    }

    #[test]
    fn test_sample_uniqueness() {
        let buf = filled_with_indices(5, 5);

        // Drawing the whole buffer must not repeat an entry within one batch.
        assert_eq!(distinct_heads(&buf.sample(5)), 5);
    }
}