Skip to main content

ruvector_dag/healing/
drift_detector.rs

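//! Learning Drift Detection
//!
//! Compares a rolling window of recorded values for each named metric against a
//! stored baseline and reports the relative drift and its direction.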
2
3use std::collections::HashMap;
4
5#[derive(Debug, Clone)]
6pub struct DriftMetric {
7    pub name: String,
8    pub current_value: f64,
9    pub baseline_value: f64,
10    pub drift_magnitude: f64,
11    pub trend: DriftTrend,
12}
13
/// Direction in which a metric's current value has moved relative to its baseline.
///
/// The detector treats larger values as better. `Eq` and `Hash` are derived so
/// the trend can be used as a key in hash-based collections (for example to
/// count metrics per trend).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftTrend {
    /// The current value is above the baseline by more than the detector's tolerance.
    Improving,
    /// The current value is within the detector's tolerance of the baseline.
    Stable,
    /// The current value is below the baseline by more than the detector's tolerance.
    Declining,
}
20
/// Detects drift of named metrics by comparing recent observations to a baseline.
///
/// `Debug` and `Clone` are derived so the detector can be logged and
/// snapshotted; every field is itself `Debug + Clone`.
#[derive(Debug, Clone)]
pub struct LearningDriftDetector {
    /// Reference value for each metric, keyed by metric name.
    baselines: HashMap<String, f64>,
    /// Recent observations for each metric, oldest first.
    current_values: HashMap<String, Vec<f64>>,
    /// Relative drift a metric must exceed to be reported by `check_all_drifts`.
    drift_threshold: f64,
    /// Maximum number of observations kept per metric.
    window_size: usize,
}
27
28impl LearningDriftDetector {
29    pub fn new(drift_threshold: f64, window_size: usize) -> Self {
30        Self {
31            baselines: HashMap::new(),
32            current_values: HashMap::new(),
33            drift_threshold,
34            window_size,
35        }
36    }
37
38    pub fn set_baseline(&mut self, metric: &str, value: f64) {
39        self.baselines.insert(metric.to_string(), value);
40    }
41
42    pub fn record(&mut self, metric: &str, value: f64) {
43        let values = self
44            .current_values
45            .entry(metric.to_string())
46            .or_insert_with(Vec::new);
47
48        values.push(value);
49
50        // Keep only window_size values
51        if values.len() > self.window_size {
52            values.remove(0);
53        }
54    }
55
56    pub fn check_drift(&self, metric: &str) -> Option<DriftMetric> {
57        let baseline = self.baselines.get(metric)?;
58        let values = self.current_values.get(metric)?;
59
60        if values.is_empty() {
61            return None;
62        }
63
64        let current = values.iter().sum::<f64>() / values.len() as f64;
65        let drift_magnitude = (current - baseline).abs() / baseline.abs().max(1e-10);
66
67        let trend = if current > *baseline * 1.05 {
68            DriftTrend::Improving
69        } else if current < *baseline * 0.95 {
70            DriftTrend::Declining
71        } else {
72            DriftTrend::Stable
73        };
74
75        Some(DriftMetric {
76            name: metric.to_string(),
77            current_value: current,
78            baseline_value: *baseline,
79            drift_magnitude,
80            trend,
81        })
82    }
83
84    pub fn check_all_drifts(&self) -> Vec<DriftMetric> {
85        self.baselines
86            .keys()
87            .filter_map(|metric| self.check_drift(metric))
88            .filter(|d| d.drift_magnitude > self.drift_threshold)
89            .collect()
90    }
91
92    pub fn drift_threshold(&self) -> f64 {
93        self.drift_threshold
94    }
95
96    pub fn window_size(&self) -> usize {
97        self.window_size
98    }
99
100    pub fn metrics(&self) -> Vec<String> {
101        self.baselines.keys().cloned().collect()
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Detector configuration shared by every test: threshold 0.1, window of 10.
    fn new_detector() -> LearningDriftDetector {
        LearningDriftDetector::new(0.1, 10)
    }

    /// Records the same `value` for `metric` ten times, filling the window.
    fn fill_window(detector: &mut LearningDriftDetector, metric: &str, value: f64) {
        (0..10).for_each(|_| detector.record(metric, value));
    }

    #[test]
    fn test_baseline_setting() {
        let mut detector = new_detector();
        detector.set_baseline("accuracy", 0.95);

        let stored = detector.baselines.get("accuracy").copied();
        assert_eq!(stored, Some(0.95));
    }

    #[test]
    fn test_stable_metric() {
        let mut detector = new_detector();
        detector.set_baseline("accuracy", 0.95);
        fill_window(&mut detector, "accuracy", 0.95);

        // Observations equal to the baseline must show no meaningful drift.
        let report = detector.check_drift("accuracy").unwrap();
        assert_eq!(report.trend, DriftTrend::Stable);
        assert!(report.drift_magnitude < 0.01);
    }

    #[test]
    fn test_improving_trend() {
        let mut detector = new_detector();
        detector.set_baseline("accuracy", 0.80);

        // Steadily rising observations, all above the baseline.
        (0..10)
            .map(|step| 0.85 + (step as f64) * 0.01)
            .for_each(|observation| detector.record("accuracy", observation));

        let report = detector.check_drift("accuracy").unwrap();
        assert_eq!(report.trend, DriftTrend::Improving);
    }

    #[test]
    fn test_declining_trend() {
        let mut detector = new_detector();
        detector.set_baseline("accuracy", 0.95);
        fill_window(&mut detector, "accuracy", 0.85);

        let report = detector.check_drift("accuracy").unwrap();
        assert_eq!(report.trend, DriftTrend::Declining);
    }

    #[test]
    fn test_drift_threshold() {
        let mut detector = new_detector();
        for metric in ["metric1", "metric2"] {
            detector.set_baseline(metric, 1.0);
        }

        // metric1 stays under the 0.1 threshold; metric2 is far beyond it.
        fill_window(&mut detector, "metric1", 1.05);
        fill_window(&mut detector, "metric2", 1.5);

        let reported: Vec<String> = detector
            .check_all_drifts()
            .into_iter()
            .map(|drift| drift.name)
            .collect();
        assert_eq!(reported.len(), 1);
        assert_eq!(reported[0], "metric2");
    }
}