1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, VecDeque};
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
/// Tuning parameters for [`OnlineLearningManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnlineLearningConfig {
    /// Capacity of the data buffer; once reached, the oldest point is evicted on insert.
    pub buffer_size: usize,
    /// Maximum number of most-recent buffered points used by one model update.
    pub batch_size: usize,
    /// Minimum time between automatic model updates triggered by `add_data_point`.
    pub update_frequency: Duration,
    /// NOTE(review): not read anywhere in this file — confirm whether it is used elsewhere.
    pub forgetting_factor: f32,
    /// Initial learning rate; `reset_after_drift` restores the learning rate to this value.
    pub adaptation_rate: f32,
    /// Drift is flagged when the recent mean score falls below the window mean by more
    /// than this; a drop above twice this value is classified as `DriftType::Sudden`.
    pub drift_detection_threshold: f32,
    /// Number of performance scores retained for drift detection.
    pub window_size: usize,
    /// Minimum number of buffered points before an automatic update may run.
    pub min_samples_for_update: usize,
    /// When `false`, `detect_concept_drift` returns `Ok(false)` without recording the score.
    pub enable_concept_drift_detection: bool,
    /// When `true`, the learning rate is boosted each time drift is detected.
    pub enable_adaptive_learning_rate: bool,
}
20
21impl Default for OnlineLearningConfig {
22 fn default() -> Self {
23 Self {
24 buffer_size: 10000,
25 batch_size: 32,
26 update_frequency: Duration::from_secs(60),
27 forgetting_factor: 0.99,
28 adaptation_rate: 0.01,
29 drift_detection_threshold: 0.1,
30 window_size: 1000,
31 min_samples_for_update: 10,
32 enable_concept_drift_detection: true,
33 enable_adaptive_learning_rate: true,
34 }
35 }
36}
37
/// A single streaming observation passed to `OnlineLearningManager::add_data_point`.
#[derive(Debug, Clone)]
pub struct OnlineDataPoint {
    /// Input feature vector; its entries are added into the model weights during updates.
    pub features: Vec<f32>,
    /// Target values. NOTE(review): not read by the weight update in this file — confirm.
    pub label: Vec<f32>,
    /// When the observation was produced. NOTE(review): not read in this file.
    pub timestamp: Instant,
    /// Per-sample multiplier applied to this point's contribution in weight updates.
    pub importance_weight: f32,
}
45
/// State of the most recent concept-drift detection, as maintained by
/// `OnlineLearningManager`.
#[derive(Debug, Clone)]
pub struct ConceptDrift {
    /// Whether drift has been flagged since the manager was created or last reset.
    pub detected: bool,
    /// Size of the performance drop (window mean minus recent mean) that flagged drift.
    pub drift_score: f32,
    /// When drift was last flagged (the manager initialises it to its construction time).
    pub detection_time: Instant,
    /// Classification of the flagged drift.
    pub drift_type: DriftType,
}
53
/// Category of a detected concept drift.
///
/// Derives `Copy`, `PartialEq` and `Eq` so callers (and tests) can compare drift
/// categories directly, e.g. `assert_eq!(state.drift_type, DriftType::Sudden)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DriftType {
    /// Drop above the detection threshold but at most twice the threshold.
    Gradual,
    /// Drop above twice the detection threshold.
    Sudden,
    /// NOTE(review): not produced by `detect_concept_drift`; reserved for other detectors.
    Incremental,
    /// NOTE(review): not produced by `detect_concept_drift`; reserved for other detectors.
    Recurring,
}
61
/// Bounded FIFO of performance scores used for drift detection; inserting into a
/// full window evicts the oldest score.
pub struct PerformanceWindow {
    /// Scores in insertion order, oldest at the front.
    scores: VecDeque<f32>,
    /// Insertion instant of each score, parallel to `scores`.
    /// NOTE(review): never read in this file.
    timestamps: VecDeque<Instant>,
    /// Length at which the window counts as full and starts evicting.
    window_size: usize,
}
67
68impl PerformanceWindow {
69 pub fn new(window_size: usize) -> Self {
70 Self {
71 scores: VecDeque::with_capacity(window_size),
72 timestamps: VecDeque::with_capacity(window_size),
73 window_size,
74 }
75 }
76
77 pub fn add_score(&mut self, score: f32) {
78 if self.scores.len() >= self.window_size {
79 self.scores.pop_front();
80 self.timestamps.pop_front();
81 }
82 self.scores.push_back(score);
83 self.timestamps.push_back(Instant::now());
84 }
85
86 pub fn mean(&self) -> f32 {
87 if self.scores.is_empty() {
88 0.0
89 } else {
90 self.scores.iter().sum::<f32>() / self.scores.len() as f32
91 }
92 }
93
94 pub fn variance(&self) -> f32 {
95 if self.scores.len() < 2 {
96 0.0
97 } else {
98 let mean = self.mean();
99 let variance: f32 = self.scores.iter().map(|score| (score - mean).powi(2)).sum::<f32>()
100 / (self.scores.len() - 1) as f32;
101 variance
102 }
103 }
104
105 pub fn is_full(&self) -> bool {
106 self.scores.len() >= self.window_size
107 }
108}
109
/// Buffers streaming data points, periodically applies weight updates to an
/// in-memory model, and watches a window of performance scores for concept drift.
///
/// All mutable state sits behind `RwLock`s, so every method takes `&self`.
/// NOTE(review): the `Arc` wrappers are never cloned in this file; plain `RwLock`
/// fields would suffice unless the state is meant to be shared later.
pub struct OnlineLearningManager {
    // Configuration supplied at construction; never modified afterwards.
    config: OnlineLearningConfig,
    // FIFO of recent data points; the oldest is evicted once `config.buffer_size` is reached.
    data_buffer: Arc<RwLock<VecDeque<OnlineDataPoint>>>,
    // Rolling window of scores fed to `detect_concept_drift`.
    performance_window: Arc<RwLock<PerformanceWindow>>,
    // Named parameter vectors; this file only writes the "layer_weights" entry.
    model_state: Arc<RwLock<HashMap<String, Vec<f32>>>>,
    // Time of the last completed model update (construction time initially).
    last_update: Arc<RwLock<Instant>>,
    // Current learning rate; starts at `config.adaptation_rate`.
    learning_rate: Arc<RwLock<f32>>,
    // Latest drift-detection state, returned by `get_current_drift_state`.
    drift_detector: Arc<RwLock<ConceptDrift>>,
    // Counters returned (cloned) by `get_statistics`.
    statistics: Arc<RwLock<OnlineStatistics>>,
}
120
/// Activity counters for an `OnlineLearningManager`; a snapshot is returned by
/// `OnlineLearningManager::get_statistics`.
#[derive(Debug, Default, Clone)]
pub struct OnlineStatistics {
    /// Number of data points successfully buffered by `add_data_point`.
    pub total_samples_processed: usize,
    /// Number of model updates that ran on a non-empty batch.
    pub total_updates: usize,
    /// Number of `detect_concept_drift` calls that reported drift.
    pub concept_drifts_detected: usize,
    /// Smoothed latency of model updates, maintained by `trigger_model_update`.
    pub average_latency: Duration,
    /// NOTE(review): presumably samples per second — verify where (or whether) it is written.
    pub throughput: f32,
    /// Presumably the most recent performance score reported to the manager — verify
    /// where it is written.
    pub last_performance_score: f32,
}
130
131impl OnlineLearningManager {
132 pub fn new(config: OnlineLearningConfig) -> Self {
133 Self {
134 performance_window: Arc::new(RwLock::new(PerformanceWindow::new(config.window_size))),
135 data_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(config.buffer_size))),
136 model_state: Arc::new(RwLock::new(HashMap::new())),
137 last_update: Arc::new(RwLock::new(Instant::now())),
138 learning_rate: Arc::new(RwLock::new(config.adaptation_rate)),
139 drift_detector: Arc::new(RwLock::new(ConceptDrift {
140 detected: false,
141 drift_score: 0.0,
142 detection_time: Instant::now(),
143 drift_type: DriftType::Gradual,
144 })),
145 statistics: Arc::new(RwLock::new(OnlineStatistics::default())),
146 config,
147 }
148 }
149
150 pub fn add_data_point(&self, data_point: OnlineDataPoint) -> Result<()> {
151 let mut buffer = self
152 .data_buffer
153 .write()
154 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on data buffer"))?;
155
156 if buffer.len() >= self.config.buffer_size {
157 buffer.pop_front();
158 }
159 buffer.push_back(data_point);
160
161 {
163 let mut stats = self
164 .statistics
165 .write()
166 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
167 stats.total_samples_processed += 1;
168 }
169
170 let should_update = {
172 let last_update = self
173 .last_update
174 .read()
175 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on last_update"))?;
176 last_update.elapsed() >= self.config.update_frequency
177 && buffer.len() >= self.config.min_samples_for_update
178 };
179
180 if should_update {
181 self.trigger_model_update()?;
182 }
183
184 Ok(())
185 }
186
187 pub fn trigger_model_update(&self) -> Result<()> {
188 let start_time = Instant::now();
189
190 let batch = self.get_training_batch()?;
192
193 if batch.is_empty() {
194 return Ok(());
195 }
196
197 self.update_model_weights(&batch)?;
199
200 {
202 let mut last_update = self
203 .last_update
204 .write()
205 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on last_update"))?;
206 *last_update = Instant::now();
207 }
208
209 {
211 let mut stats = self
212 .statistics
213 .write()
214 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
215 stats.total_updates += 1;
216 stats.average_latency = (stats.average_latency + start_time.elapsed()) / 2;
217 }
218
219 Ok(())
220 }
221
222 fn get_training_batch(&self) -> Result<Vec<OnlineDataPoint>> {
223 let buffer = self
224 .data_buffer
225 .read()
226 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on data buffer"))?;
227
228 let batch_size = std::cmp::min(self.config.batch_size, buffer.len());
229 let batch: Vec<_> = buffer.iter().rev().take(batch_size).cloned().collect();
230
231 Ok(batch)
232 }
233
234 fn update_model_weights(&self, batch: &[OnlineDataPoint]) -> Result<()> {
235 let learning_rate = {
236 let lr = self
237 .learning_rate
238 .read()
239 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on learning_rate"))?;
240 *lr
241 };
242
243 let mut model_state = self
245 .model_state
246 .write()
247 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on model_state"))?;
248
249 for data_point in batch {
251 let weight_key = "layer_weights".to_string();
252 let weights = model_state.entry(weight_key).or_insert_with(|| vec![0.0; 128]);
253
254 for (i, feature) in data_point.features.iter().enumerate() {
256 if i < weights.len() {
257 weights[i] += learning_rate * feature * data_point.importance_weight;
258 }
259 }
260 }
261
262 Ok(())
263 }
264
265 pub fn detect_concept_drift(&self, current_performance: f32) -> Result<bool> {
266 if !self.config.enable_concept_drift_detection {
267 return Ok(false);
268 }
269
270 {
272 let mut window = self.performance_window.write().map_err(|_| {
273 anyhow::anyhow!("Failed to acquire write lock on performance window")
274 })?;
275 window.add_score(current_performance);
276
277 if !window.is_full() {
278 return Ok(false);
279 }
280 }
281
282 let window = self
283 .performance_window
284 .read()
285 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on performance window"))?;
286
287 let recent_mean = {
289 let recent_scores: Vec<_> =
290 window.scores.iter().rev().take(window.window_size / 4).cloned().collect();
291 if recent_scores.is_empty() {
292 return Ok(false);
293 }
294 recent_scores.iter().sum::<f32>() / recent_scores.len() as f32
295 };
296
297 let historical_mean = window.mean();
298 let performance_drop = historical_mean - recent_mean;
299
300 let drift_detected = performance_drop > self.config.drift_detection_threshold;
301
302 if drift_detected {
303 let mut drift_detector = self
304 .drift_detector
305 .write()
306 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on drift detector"))?;
307
308 drift_detector.detected = true;
309 drift_detector.drift_score = performance_drop;
310 drift_detector.detection_time = Instant::now();
311 drift_detector.drift_type =
312 if performance_drop > self.config.drift_detection_threshold * 2.0 {
313 DriftType::Sudden
314 } else {
315 DriftType::Gradual
316 };
317
318 let mut stats = self
320 .statistics
321 .write()
322 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
323 stats.concept_drifts_detected += 1;
324
325 if self.config.enable_adaptive_learning_rate {
327 self.adapt_learning_rate(performance_drop)?;
328 }
329 }
330
331 Ok(drift_detected)
332 }
333
334 fn adapt_learning_rate(&self, drift_score: f32) -> Result<()> {
335 let mut learning_rate = self
336 .learning_rate
337 .write()
338 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on learning_rate"))?;
339
340 let adaptation_factor = 1.0 + drift_score;
342 *learning_rate = (*learning_rate * adaptation_factor).min(0.1); Ok(())
345 }
346
347 pub fn reset_after_drift(&self) -> Result<()> {
348 {
350 let mut buffer = self
351 .data_buffer
352 .write()
353 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on data buffer"))?;
354 buffer.clear();
355 }
356
357 {
359 let mut drift_detector = self
360 .drift_detector
361 .write()
362 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on drift detector"))?;
363 drift_detector.detected = false;
364 drift_detector.drift_score = 0.0;
365 }
366
367 {
369 let mut learning_rate = self
370 .learning_rate
371 .write()
372 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on learning_rate"))?;
373 *learning_rate = self.config.adaptation_rate;
374 }
375
376 Ok(())
377 }
378
379 pub fn get_statistics(&self) -> Result<OnlineStatistics> {
380 let stats = self
381 .statistics
382 .read()
383 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on statistics"))?;
384 Ok((*stats).clone())
385 }
386
387 pub fn get_current_drift_state(&self) -> Result<ConceptDrift> {
388 let drift = self
389 .drift_detector
390 .read()
391 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on drift detector"))?;
392 Ok(drift.clone())
393 }
394
395 pub fn get_buffer_size(&self) -> Result<usize> {
396 let buffer = self
397 .data_buffer
398 .read()
399 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on data buffer"))?;
400 Ok(buffer.len())
401 }
402}
403
/// Error categories for online learning.
///
/// NOTE(review): `OnlineLearningManager` in this file reports failures through
/// `anyhow` and never constructs these variants — confirm whether callers use them.
#[derive(Debug)]
pub enum OnlineLearningError {
    /// Displayed as "Data buffer is full".
    BufferFull,
    /// Displayed as "Model update failed".
    ModelUpdateFailed,
    /// Displayed as "Drift detection failed".
    DriftDetectionFailed,
    /// Displayed as "Configuration error".
    ConfigurationError,
}
411
412impl std::fmt::Display for OnlineLearningError {
413 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414 match self {
415 OnlineLearningError::BufferFull => write!(f, "Data buffer is full"),
416 OnlineLearningError::ModelUpdateFailed => write!(f, "Model update failed"),
417 OnlineLearningError::DriftDetectionFailed => write!(f, "Drift detection failed"),
418 OnlineLearningError::ConfigurationError => write!(f, "Configuration error"),
419 }
420 }
421}
422
// Lets `OnlineLearningError` be used as a standard error (e.g. boxed or wrapped by
// `anyhow`); it relies on the `Debug` derive and `Display` impl and reports no `source`.
impl std::error::Error for OnlineLearningError {}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-weight data point stamped with the current instant.
    fn point(features: Vec<f32>, label: Vec<f32>) -> OnlineDataPoint {
        OnlineDataPoint {
            features,
            label,
            timestamp: Instant::now(),
            importance_weight: 1.0,
        }
    }

    /// Number of buffered points in `manager`, panicking if the lock fails.
    fn buffered(manager: &OnlineLearningManager) -> usize {
        manager
            .get_buffer_size()
            .expect("operation failed in test")
    }

    #[test]
    fn test_online_learning_manager_creation() {
        let manager = OnlineLearningManager::new(OnlineLearningConfig::default());
        assert_eq!(buffered(&manager), 0);
    }

    #[test]
    fn test_add_data_point() {
        let manager = OnlineLearningManager::new(OnlineLearningConfig::default());

        manager
            .add_data_point(point(vec![1.0, 2.0, 3.0], vec![1.0]))
            .expect("add operation failed");

        assert_eq!(buffered(&manager), 1);
    }

    #[test]
    fn test_performance_window() {
        let mut window = PerformanceWindow::new(3);
        for &score in &[0.8, 0.7, 0.9] {
            window.add_score(score);
        }

        assert_eq!(window.mean(), (0.8 + 0.7 + 0.9) / 3.0);
        assert!(window.is_full());

        // A fourth score evicts the oldest (0.8).
        window.add_score(0.6);
        assert_eq!(window.scores.len(), 3);
        assert_eq!(window.scores[0], 0.7);
    }

    #[test]
    fn test_concept_drift_detection() {
        let manager = OnlineLearningManager::new(OnlineLearningConfig {
            window_size: 4,
            drift_detection_threshold: 0.1,
            ..Default::default()
        });
        let detect = |score: f32| {
            manager
                .detect_concept_drift(score)
                .expect("operation failed in test")
        };

        // Stable scores while the window fills: no drift.
        for &score in &[0.9, 0.85, 0.88, 0.87] {
            assert!(!detect(score));
        }

        // A sharp drop in the most recent score flags drift.
        assert!(detect(0.6));

        let drift_state = manager
            .get_current_drift_state()
            .expect("operation failed in test");
        assert!(drift_state.detected);
    }

    #[test]
    fn test_buffer_capacity() {
        let manager = OnlineLearningManager::new(OnlineLearningConfig {
            buffer_size: 2,
            ..Default::default()
        });

        for i in 0..5 {
            let value = i as f32;
            manager
                .add_data_point(point(vec![value], vec![value]))
                .expect("add operation failed");
        }

        assert_eq!(buffered(&manager), 2);
    }

    #[test]
    fn test_statistics_tracking() {
        let manager = OnlineLearningManager::new(OnlineLearningConfig::default());

        manager
            .add_data_point(point(vec![1.0, 2.0], vec![1.0]))
            .expect("add operation failed");

        let stats = manager.get_statistics().expect("operation failed in test");
        assert_eq!(stats.total_samples_processed, 1);
    }
}