// ruvector_dag/healing/drift_detector.rs
use std::collections::HashMap;

/// Snapshot of how far a single metric has moved away from its baseline.
#[derive(Debug, Clone, PartialEq)]
pub struct DriftMetric {
    /// Name of the metric this snapshot describes.
    pub name: String,
    /// Mean of the values currently held in the metric's sliding window.
    pub current_value: f64,
    /// Baseline registered via [`LearningDriftDetector::set_baseline`].
    pub baseline_value: f64,
    /// Relative deviation: `|current - baseline| / max(|baseline|, 1e-10)`.
    pub drift_magnitude: f64,
    /// Direction of the deviation, using a ±5% band (of `|baseline|`) around the baseline.
    pub trend: DriftTrend,
}

/// Direction in which a metric has moved relative to its baseline.
///
/// The detector treats *higher* values as better: a metric above its
/// baseline is `Improving`, one below it is `Declining`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftTrend {
    /// More than 5% of `|baseline|` above the baseline.
    Improving,
    /// Within 5% of `|baseline|` of the baseline.
    Stable,
    /// More than 5% of `|baseline|` below the baseline.
    Declining,
}

/// Tracks metrics against fixed baselines and reports when their recent
/// average drifts too far away.
///
/// Each metric keeps a sliding window of at most `window_size` of its most
/// recent observations; drift is measured on the mean of that window.
#[derive(Debug, Clone)]
pub struct LearningDriftDetector {
    /// Reference value per metric name.
    baselines: HashMap<String, f64>,
    /// Most recent observations per metric name, oldest first.
    current_values: HashMap<String, Vec<f64>>,
    /// Relative drift magnitude above which `check_all_drifts` reports a metric.
    drift_threshold: f64,
    /// Maximum number of observations retained per metric.
    window_size: usize,
}

impl LearningDriftDetector {
    /// Half-width of the band, relative to `|baseline|`, inside which a
    /// metric is classified as [`DriftTrend::Stable`].
    const STABLE_BAND: f64 = 0.05;

    /// Creates a detector with no baselines and no recorded values.
    ///
    /// * `drift_threshold` — relative drift magnitude above which
    ///   [`check_all_drifts`](Self::check_all_drifts) reports a metric.
    /// * `window_size` — number of most recent observations kept per metric.
    pub fn new(drift_threshold: f64, window_size: usize) -> Self {
        Self {
            baselines: HashMap::new(),
            current_values: HashMap::new(),
            drift_threshold,
            window_size,
        }
    }

    /// Sets (or replaces) the baseline value for `metric`.
    pub fn set_baseline(&mut self, metric: &str, value: f64) {
        self.baselines.insert(metric.to_string(), value);
    }

    /// Appends an observation for `metric`, discarding the oldest ones so the
    /// window never holds more than `window_size` values.
    ///
    /// With `window_size == 0` no observations are retained.
    pub fn record(&mut self, metric: &str, value: f64) {
        let window_size = self.window_size;
        let values = self.current_values.entry(metric.to_string()).or_default();
        values.push(value);

        // Drop whatever overflows the window from the front (oldest first).
        let excess = values.len().saturating_sub(window_size);
        values.drain(..excess);
    }

    /// Compares the mean of `metric`'s current window against its baseline.
    ///
    /// Returns `None` when the metric has no baseline, has never been
    /// recorded, or its window is empty.
    pub fn check_drift(&self, metric: &str) -> Option<DriftMetric> {
        let baseline = *self.baselines.get(metric)?;
        let values = self.current_values.get(metric)?;

        if values.is_empty() {
            return None;
        }

        let current = values.iter().sum::<f64>() / values.len() as f64;
        let delta = current - baseline;
        // Guard against division by zero for a zero baseline.
        let drift_magnitude = delta.abs() / baseline.abs().max(1e-10);

        // Scale the band by |baseline|. Multiplying a negative baseline by
        // 1.05 / 0.95 would invert the band and report an unchanged metric
        // as Improving.
        let tolerance = Self::STABLE_BAND * baseline.abs();
        let trend = if delta > tolerance {
            DriftTrend::Improving
        } else if delta < -tolerance {
            DriftTrend::Declining
        } else {
            DriftTrend::Stable
        };

        Some(DriftMetric {
            name: metric.to_string(),
            current_value: current,
            baseline_value: baseline,
            drift_magnitude,
            trend,
        })
    }

    /// Returns every metric whose drift magnitude strictly exceeds the
    /// configured threshold, sorted by metric name.
    pub fn check_all_drifts(&self) -> Vec<DriftMetric> {
        let mut drifts: Vec<DriftMetric> = self
            .baselines
            .keys()
            .filter_map(|metric| self.check_drift(metric))
            .filter(|d| d.drift_magnitude > self.drift_threshold)
            .collect();
        // HashMap iteration order is unspecified; sort for reproducible output.
        drifts.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        drifts
    }

    /// Relative drift magnitude above which a metric is reported.
    pub fn drift_threshold(&self) -> f64 {
        self.drift_threshold
    }

    /// Maximum number of observations retained per metric.
    pub fn window_size(&self) -> usize {
        self.window_size
    }

    /// Names of all metrics that have a baseline, in sorted order.
    pub fn metrics(&self) -> Vec<String> {
        let mut names: Vec<String> = self.baselines.keys().cloned().collect();
        names.sort_unstable();
        names
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a detector (threshold 0.1, window 10) with one baseline already set.
    fn detector_with_baseline(metric: &str, baseline: f64) -> LearningDriftDetector {
        let mut detector = LearningDriftDetector::new(0.1, 10);
        detector.set_baseline(metric, baseline);
        detector
    }

    #[test]
    fn test_baseline_setting() {
        let detector = detector_with_baseline("accuracy", 0.95);

        assert_eq!(detector.baselines.get("accuracy"), Some(&0.95));
    }

    #[test]
    fn test_stable_metric() {
        let mut detector = detector_with_baseline("accuracy", 0.95);
        // Every sample equals the baseline, so there is no drift at all.
        std::iter::repeat(0.95)
            .take(10)
            .for_each(|sample| detector.record("accuracy", sample));

        let report = detector.check_drift("accuracy").unwrap();
        assert_eq!(report.trend, DriftTrend::Stable);
        assert!(report.drift_magnitude < 0.01);
    }

    #[test]
    fn test_improving_trend() {
        let mut detector = detector_with_baseline("accuracy", 0.80);
        // Samples ramp from 0.85 to 0.94, all well above the baseline.
        for step in 0..10_u8 {
            detector.record("accuracy", 0.85 + f64::from(step) * 0.01);
        }

        let report = detector.check_drift("accuracy").unwrap();
        assert_eq!(report.trend, DriftTrend::Improving);
    }

    #[test]
    fn test_declining_trend() {
        let mut detector = detector_with_baseline("accuracy", 0.95);
        // A constant 0.85 sits more than 5% below the baseline.
        std::iter::repeat(0.85)
            .take(10)
            .for_each(|sample| detector.record("accuracy", sample));

        let report = detector.check_drift("accuracy").unwrap();
        assert_eq!(report.trend, DriftTrend::Declining);
    }

    #[test]
    fn test_drift_threshold() {
        let mut detector = LearningDriftDetector::new(0.1, 10);
        for name in ["metric1", "metric2"] {
            detector.set_baseline(name, 1.0);
        }

        for _ in 0..10 {
            // 5% drift: below the 10% threshold, so it must not be reported.
            detector.record("metric1", 1.05);
            // 50% drift: above the threshold, so it must be reported.
            detector.record("metric2", 1.5);
        }

        let reported = detector.check_all_drifts();
        assert_eq!(reported.len(), 1);
        assert_eq!(reported[0].name, "metric2");
    }
}