Skip to main content

trustformers_training/
online_learning.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, VecDeque};
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
/// Tunable parameters for [`OnlineLearningManager`].
///
/// The type is serializable via serde so it can be stored alongside other
/// configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnlineLearningConfig {
    /// Maximum number of data points kept in the manager's buffer; the oldest
    /// point is evicted when a new one arrives at capacity.
    pub buffer_size: usize,
    /// Upper bound on the number of buffered points used for one model update.
    pub batch_size: usize,
    /// Minimum time since the previous update before adding a data point
    /// triggers another update.
    pub update_frequency: Duration,
    /// NOTE(review): not read by any code visible in this file; presumably a
    /// decay factor for older samples — confirm against callers.
    pub forgetting_factor: f32,
    /// Initial learning rate of the manager, and the value the learning rate
    /// is restored to by `reset_after_drift`.
    pub adaptation_rate: f32,
    /// Drop in mean performance (window mean minus recent mean) above which
    /// concept drift is reported.
    pub drift_detection_threshold: f32,
    /// Capacity of the performance window used for drift detection.
    pub window_size: usize,
    /// Minimum number of buffered points before adding a data point may
    /// trigger an update.
    pub min_samples_for_update: usize,
    /// When `false`, `detect_concept_drift` returns `Ok(false)` immediately.
    pub enable_concept_drift_detection: bool,
    /// When `true`, a detected drift also adjusts the learning rate.
    pub enable_adaptive_learning_rate: bool,
}

impl Default for OnlineLearningConfig {
    /// Returns the stock configuration: a 10 000-point buffer, batches of 32,
    /// at most one automatic update per minute, and both drift detection and
    /// adaptive learning rate enabled.
    fn default() -> Self {
        Self {
            buffer_size: 10000,
            batch_size: 32,
            update_frequency: Duration::from_secs(60),
            forgetting_factor: 0.99,
            adaptation_rate: 0.01,
            drift_detection_threshold: 0.1,
            window_size: 1000,
            min_samples_for_update: 10,
            enable_concept_drift_detection: true,
            enable_adaptive_learning_rate: true,
        }
    }
}
37
/// A single training sample fed to [`OnlineLearningManager`].
#[derive(Debug, Clone)]
pub struct OnlineDataPoint {
    /// Input feature values; the simplified weight update in this file adds a
    /// scaled copy of each feature to the weight at the same index.
    pub features: Vec<f32>,
    /// Target values. Not read by the simplified update in this file.
    pub label: Vec<f32>,
    /// Time associated with the sample, supplied by the caller. Not read by
    /// the code in this file.
    pub timestamp: Instant,
    /// Multiplier applied to this sample's contribution to the weight update.
    pub importance_weight: f32,
}
45
/// Snapshot of the drift detector's state, as returned by
/// `OnlineLearningManager::get_current_drift_state`.
#[derive(Debug, Clone)]
pub struct ConceptDrift {
    /// `true` once a drift has been detected and not yet reset.
    pub detected: bool,
    /// Performance drop (window mean minus recent mean) recorded at the last
    /// detection; `0.0` initially and after a reset.
    pub drift_score: f32,
    /// Time of the last detection; initialised to the manager's creation time.
    pub detection_time: Instant,
    /// Classification of the last detected drift.
    pub drift_type: DriftType,
}
53
/// Kind of concept drift.
///
/// The detector in this file only ever assigns `Gradual` and `Sudden`; the
/// other variants are declared but not produced here.
#[derive(Debug, Clone)]
pub enum DriftType {
    /// Drop above the detection threshold but not above twice the threshold;
    /// also the detector's initial value.
    Gradual,
    /// Drop above twice the detection threshold.
    Sudden,
    /// Not produced by the code in this file.
    Incremental,
    /// Not produced by the code in this file.
    Recurring,
}
61
/// Sliding window over the most recent performance scores.
///
/// Scores and the instants at which they were recorded are kept in two
/// parallel queues; once `window_size` entries are held, recording a new score
/// first drops the oldest entry.
pub struct PerformanceWindow {
    scores: VecDeque<f32>,
    timestamps: VecDeque<Instant>,
    window_size: usize,
}

impl PerformanceWindow {
    /// Creates an empty window that holds up to `window_size` scores.
    pub fn new(window_size: usize) -> Self {
        let scores = VecDeque::with_capacity(window_size);
        let timestamps = VecDeque::with_capacity(window_size);
        Self {
            scores,
            timestamps,
            window_size,
        }
    }

    /// Records `score` together with the current time, evicting the oldest
    /// entry first when the window is already full.
    pub fn add_score(&mut self, score: f32) {
        if self.is_full() {
            // Both queues are popped together so they stay index-aligned.
            self.scores.pop_front();
            self.timestamps.pop_front();
        }
        self.timestamps.push_back(Instant::now());
        self.scores.push_back(score);
    }

    /// Arithmetic mean of the stored scores, or `0.0` when the window is empty.
    pub fn mean(&self) -> f32 {
        match self.scores.len() {
            0 => 0.0,
            count => self.scores.iter().sum::<f32>() / count as f32,
        }
    }

    /// Sample variance (divisor `n - 1`) of the stored scores, or `0.0` when
    /// fewer than two scores are stored.
    pub fn variance(&self) -> f32 {
        let count = self.scores.len();
        if count < 2 {
            return 0.0;
        }

        let mean = self.mean();
        let squared_deviations = self
            .scores
            .iter()
            .map(|score| (score - mean).powi(2))
            .sum::<f32>();
        squared_deviations / (count - 1) as f32
    }

    /// Returns `true` once at least `window_size` scores are stored.
    pub fn is_full(&self) -> bool {
        self.window_size <= self.scores.len()
    }
}
109
/// Coordinates buffering of incoming samples, periodic model updates and
/// performance-based concept-drift detection.
///
/// Every piece of mutable state sits behind its own `Arc<RwLock<..>>`, so all
/// methods take `&self`.
pub struct OnlineLearningManager {
    /// Settings fixed at construction time.
    config: OnlineLearningConfig,
    /// Bounded FIFO of the most recently added data points.
    data_buffer: Arc<RwLock<VecDeque<OnlineDataPoint>>>,
    /// Sliding window of performance scores used for drift detection.
    performance_window: Arc<RwLock<PerformanceWindow>>,
    /// Named weight vectors; the simplified update in this file only uses the
    /// `"layer_weights"` entry.
    model_state: Arc<RwLock<HashMap<String, Vec<f32>>>>,
    /// Time of the last completed model update (creation time initially).
    last_update: Arc<RwLock<Instant>>,
    /// Current learning rate; starts at `config.adaptation_rate`.
    learning_rate: Arc<RwLock<f32>>,
    /// Result of the most recent drift detection.
    drift_detector: Arc<RwLock<ConceptDrift>>,
    /// Running counters exposed through `get_statistics`.
    statistics: Arc<RwLock<OnlineStatistics>>,
}
120
/// Counters and timings maintained by [`OnlineLearningManager`].
#[derive(Debug, Default, Clone)]
pub struct OnlineStatistics {
    /// Number of data points passed to `add_data_point`.
    pub total_samples_processed: usize,
    /// Number of completed model updates.
    pub total_updates: usize,
    /// Number of `detect_concept_drift` calls that reported drift.
    pub concept_drifts_detected: usize,
    /// Smoothed duration of model updates, maintained by
    /// `trigger_model_update`.
    pub average_latency: Duration,
    /// NOTE(review): never written by the code visible in this file.
    pub throughput: f32,
    /// NOTE(review): never written by the code visible in this file.
    pub last_performance_score: f32,
}
130
131impl OnlineLearningManager {
132    pub fn new(config: OnlineLearningConfig) -> Self {
133        Self {
134            performance_window: Arc::new(RwLock::new(PerformanceWindow::new(config.window_size))),
135            data_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(config.buffer_size))),
136            model_state: Arc::new(RwLock::new(HashMap::new())),
137            last_update: Arc::new(RwLock::new(Instant::now())),
138            learning_rate: Arc::new(RwLock::new(config.adaptation_rate)),
139            drift_detector: Arc::new(RwLock::new(ConceptDrift {
140                detected: false,
141                drift_score: 0.0,
142                detection_time: Instant::now(),
143                drift_type: DriftType::Gradual,
144            })),
145            statistics: Arc::new(RwLock::new(OnlineStatistics::default())),
146            config,
147        }
148    }
149
150    pub fn add_data_point(&self, data_point: OnlineDataPoint) -> Result<()> {
151        let mut buffer = self
152            .data_buffer
153            .write()
154            .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on data buffer"))?;
155
156        if buffer.len() >= self.config.buffer_size {
157            buffer.pop_front();
158        }
159        buffer.push_back(data_point);
160
161        // Update statistics
162        {
163            let mut stats = self
164                .statistics
165                .write()
166                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
167            stats.total_samples_processed += 1;
168        }
169
170        // Check if we should trigger an update
171        let should_update = {
172            let last_update = self
173                .last_update
174                .read()
175                .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on last_update"))?;
176            last_update.elapsed() >= self.config.update_frequency
177                && buffer.len() >= self.config.min_samples_for_update
178        };
179
180        if should_update {
181            self.trigger_model_update()?;
182        }
183
184        Ok(())
185    }
186
187    pub fn trigger_model_update(&self) -> Result<()> {
188        let start_time = Instant::now();
189
190        // Get batch of data
191        let batch = self.get_training_batch()?;
192
193        if batch.is_empty() {
194            return Ok(());
195        }
196
197        // Perform model update (simplified - in practice would call actual model training)
198        self.update_model_weights(&batch)?;
199
200        // Update last update time
201        {
202            let mut last_update = self
203                .last_update
204                .write()
205                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on last_update"))?;
206            *last_update = Instant::now();
207        }
208
209        // Update statistics
210        {
211            let mut stats = self
212                .statistics
213                .write()
214                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
215            stats.total_updates += 1;
216            stats.average_latency = (stats.average_latency + start_time.elapsed()) / 2;
217        }
218
219        Ok(())
220    }
221
222    fn get_training_batch(&self) -> Result<Vec<OnlineDataPoint>> {
223        let buffer = self
224            .data_buffer
225            .read()
226            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on data buffer"))?;
227
228        let batch_size = std::cmp::min(self.config.batch_size, buffer.len());
229        let batch: Vec<_> = buffer.iter().rev().take(batch_size).cloned().collect();
230
231        Ok(batch)
232    }
233
234    fn update_model_weights(&self, batch: &[OnlineDataPoint]) -> Result<()> {
235        let learning_rate = {
236            let lr = self
237                .learning_rate
238                .read()
239                .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on learning_rate"))?;
240            *lr
241        };
242
243        // Simplified weight update - in practice would call actual model training
244        let mut model_state = self
245            .model_state
246            .write()
247            .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on model_state"))?;
248
249        // Calculate gradients and update weights (simplified)
250        for data_point in batch {
251            let weight_key = "layer_weights".to_string();
252            let weights = model_state.entry(weight_key).or_insert_with(|| vec![0.0; 128]);
253
254            // Simplified gradient computation and weight update
255            for (i, feature) in data_point.features.iter().enumerate() {
256                if i < weights.len() {
257                    weights[i] += learning_rate * feature * data_point.importance_weight;
258                }
259            }
260        }
261
262        Ok(())
263    }
264
265    pub fn detect_concept_drift(&self, current_performance: f32) -> Result<bool> {
266        if !self.config.enable_concept_drift_detection {
267            return Ok(false);
268        }
269
270        // Add current performance to window
271        {
272            let mut window = self.performance_window.write().map_err(|_| {
273                anyhow::anyhow!("Failed to acquire write lock on performance window")
274            })?;
275            window.add_score(current_performance);
276
277            if !window.is_full() {
278                return Ok(false);
279            }
280        }
281
282        let window = self
283            .performance_window
284            .read()
285            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on performance window"))?;
286
287        // Simple drift detection using performance degradation
288        let recent_mean = {
289            let recent_scores: Vec<_> =
290                window.scores.iter().rev().take(window.window_size / 4).cloned().collect();
291            if recent_scores.is_empty() {
292                return Ok(false);
293            }
294            recent_scores.iter().sum::<f32>() / recent_scores.len() as f32
295        };
296
297        let historical_mean = window.mean();
298        let performance_drop = historical_mean - recent_mean;
299
300        let drift_detected = performance_drop > self.config.drift_detection_threshold;
301
302        if drift_detected {
303            let mut drift_detector = self
304                .drift_detector
305                .write()
306                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on drift detector"))?;
307
308            drift_detector.detected = true;
309            drift_detector.drift_score = performance_drop;
310            drift_detector.detection_time = Instant::now();
311            drift_detector.drift_type =
312                if performance_drop > self.config.drift_detection_threshold * 2.0 {
313                    DriftType::Sudden
314                } else {
315                    DriftType::Gradual
316                };
317
318            // Update statistics
319            let mut stats = self
320                .statistics
321                .write()
322                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
323            stats.concept_drifts_detected += 1;
324
325            // Adapt learning rate if enabled
326            if self.config.enable_adaptive_learning_rate {
327                self.adapt_learning_rate(performance_drop)?;
328            }
329        }
330
331        Ok(drift_detected)
332    }
333
334    fn adapt_learning_rate(&self, drift_score: f32) -> Result<()> {
335        let mut learning_rate = self
336            .learning_rate
337            .write()
338            .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on learning_rate"))?;
339
340        // Increase learning rate proportionally to drift severity
341        let adaptation_factor = 1.0 + drift_score;
342        *learning_rate = (*learning_rate * adaptation_factor).min(0.1); // Cap at 0.1
343
344        Ok(())
345    }
346
347    pub fn reset_after_drift(&self) -> Result<()> {
348        // Clear data buffer to start fresh
349        {
350            let mut buffer = self
351                .data_buffer
352                .write()
353                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on data buffer"))?;
354            buffer.clear();
355        }
356
357        // Reset drift detector
358        {
359            let mut drift_detector = self
360                .drift_detector
361                .write()
362                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on drift detector"))?;
363            drift_detector.detected = false;
364            drift_detector.drift_score = 0.0;
365        }
366
367        // Reset learning rate to initial value
368        {
369            let mut learning_rate = self
370                .learning_rate
371                .write()
372                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on learning_rate"))?;
373            *learning_rate = self.config.adaptation_rate;
374        }
375
376        Ok(())
377    }
378
379    pub fn get_statistics(&self) -> Result<OnlineStatistics> {
380        let stats = self
381            .statistics
382            .read()
383            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on statistics"))?;
384        Ok((*stats).clone())
385    }
386
387    pub fn get_current_drift_state(&self) -> Result<ConceptDrift> {
388        let drift = self
389            .drift_detector
390            .read()
391            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on drift detector"))?;
392        Ok(drift.clone())
393    }
394
395    pub fn get_buffer_size(&self) -> Result<usize> {
396        let buffer = self
397            .data_buffer
398            .read()
399            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on data buffer"))?;
400        Ok(buffer.len())
401    }
402}
403
/// Failure categories for online learning.
#[derive(Debug)]
pub enum OnlineLearningError {
    /// Displayed as "Data buffer is full".
    BufferFull,
    /// Displayed as "Model update failed".
    ModelUpdateFailed,
    /// Displayed as "Drift detection failed".
    DriftDetectionFailed,
    /// Displayed as "Configuration error".
    ConfigurationError,
}

impl std::fmt::Display for OnlineLearningError {
    /// Writes the fixed human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::BufferFull => "Data buffer is full",
            Self::ModelUpdateFailed => "Model update failed",
            Self::DriftDetectionFailed => "Drift detection failed",
            Self::ConfigurationError => "Configuration error",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OnlineLearningError {}
424
// Unit tests for the online-learning manager and its performance window.
// All manager tests use the default 60-second update frequency, so none of
// them triggers an automatic model update.
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager has an empty data buffer.
    #[test]
    fn test_online_learning_manager_creation() {
        let config = OnlineLearningConfig::default();
        let manager = OnlineLearningManager::new(config);
        assert_eq!(
            manager.get_buffer_size().expect("operation failed in test"),
            0
        );
    }

    /// Adding one data point grows the buffer to one entry.
    #[test]
    fn test_add_data_point() {
        let config = OnlineLearningConfig::default();
        let manager = OnlineLearningManager::new(config);

        let data_point = OnlineDataPoint {
            features: vec![1.0, 2.0, 3.0],
            label: vec![1.0],
            timestamp: Instant::now(),
            importance_weight: 1.0,
        };

        manager.add_data_point(data_point).expect("add operation failed");
        assert_eq!(
            manager.get_buffer_size().expect("operation failed in test"),
            1
        );
    }

    /// The window reports its mean, fills up at capacity, and evicts the
    /// oldest score when a new one is added to a full window.
    #[test]
    fn test_performance_window() {
        let mut window = PerformanceWindow::new(3);

        window.add_score(0.8);
        window.add_score(0.7);
        window.add_score(0.9);

        assert_eq!(window.mean(), (0.8 + 0.7 + 0.9) / 3.0);
        assert!(window.is_full());

        window.add_score(0.6);
        assert_eq!(window.scores.len(), 3);
        assert_eq!(window.scores[0], 0.7); // 0.8 should be removed
    }

    /// With a four-entry window, stable scores do not trigger drift, while a
    /// sharp drop in the newest score does and is reflected in the drift state.
    #[test]
    fn test_concept_drift_detection() {
        let config = OnlineLearningConfig {
            window_size: 4,
            drift_detection_threshold: 0.1,
            ..Default::default()
        };
        let manager = OnlineLearningManager::new(config);

        // Add some high performance scores
        assert!(!manager.detect_concept_drift(0.9).expect("operation failed in test"));
        assert!(!manager.detect_concept_drift(0.85).expect("operation failed in test"));
        assert!(!manager.detect_concept_drift(0.88).expect("operation failed in test"));
        assert!(!manager.detect_concept_drift(0.87).expect("operation failed in test"));

        // Add a significantly lower score that should trigger drift detection
        assert!(manager.detect_concept_drift(0.6).expect("operation failed in test"));

        let drift_state = manager.get_current_drift_state().expect("operation failed in test");
        assert!(drift_state.detected);
    }

    /// The buffer never holds more than `buffer_size` points.
    #[test]
    fn test_buffer_capacity() {
        let config = OnlineLearningConfig {
            buffer_size: 2,
            ..Default::default()
        };
        let manager = OnlineLearningManager::new(config);

        for i in 0..5 {
            let data_point = OnlineDataPoint {
                features: vec![i as f32],
                label: vec![i as f32],
                timestamp: Instant::now(),
                importance_weight: 1.0,
            };
            manager.add_data_point(data_point).expect("add operation failed");
        }

        // Buffer should not exceed capacity
        assert_eq!(
            manager.get_buffer_size().expect("operation failed in test"),
            2
        );
    }

    /// Each added data point is counted in `total_samples_processed`.
    #[test]
    fn test_statistics_tracking() {
        let config = OnlineLearningConfig::default();
        let manager = OnlineLearningManager::new(config);

        let data_point = OnlineDataPoint {
            features: vec![1.0, 2.0],
            label: vec![1.0],
            timestamp: Instant::now(),
            importance_weight: 1.0,
        };

        manager.add_data_point(data_point).expect("add operation failed");

        let stats = manager.get_statistics().expect("operation failed in test");
        assert_eq!(stats.total_samples_processed, 1);
    }
}