trustformers_training/continual/
task_boundary.rs1use anyhow::Result;
2use serde::{Deserialize, Serialize};
4use std::collections::VecDeque;
5
/// Parameters controlling how [`TaskBoundaryDetector`] scores and reports task boundaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundaryDetectionConfig {
    /// Maximum number of recent samples kept in each sliding window. A score is
    /// only computed once a window holds at least half this many samples.
    pub window_size: usize,
    /// A boundary is reported when the score is strictly greater than this value.
    pub threshold: f32,
    /// Which signal is used to compute the boundary score.
    pub detection_method: DetectionMethod,
    /// Minimum number of samples before any detection, and also the minimum
    /// gap (in samples) between two consecutive detections.
    pub min_samples: usize,
    /// Weight given to the newest observation in the running (exponential
    /// moving) averages reported by `get_statistics`.
    pub smoothing_factor: f32,
}
20
21impl Default for BoundaryDetectionConfig {
22 fn default() -> Self {
23 Self {
24 window_size: 100,
25 threshold: 0.05,
26 detection_method: DetectionMethod::LossIncrease,
27 min_samples: 50,
28 smoothing_factor: 0.1,
29 }
30 }
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub enum DetectionMethod {
36 LossIncrease,
38 GradientMagnitude,
40 ActivationPattern,
42 ConfidenceChange,
44 Combined,
46}
47
/// Record of a transition from one task to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskTransition {
    /// Identifier of the task being left. Set to `"unknown"` by
    /// `TaskBoundaryDetector::force_boundary`.
    pub from_task: String,
    /// Identifier of the task being entered. Set to `"unknown"` by
    /// `TaskBoundaryDetector::force_boundary`.
    pub to_task: String,
    /// UTC time at which the transition was recorded.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Strength of the boundary signal. `force_boundary` always uses 1.0.
    pub boundary_score: f32,
}
60
/// Detects likely task boundaries in a training stream.
///
/// It compares the older and newer halves of sliding windows of per-step
/// loss, gradient-norm and confidence observations.
#[derive(Debug)]
pub struct TaskBoundaryDetector {
    /// Detection parameters (window size, threshold, method, ...).
    config: BoundaryDetectionConfig,
    /// Most recent losses, oldest first, holding at most `config.window_size` entries.
    loss_history: VecDeque<f32>,
    /// Most recent gradient norms, pushed and evicted together with `loss_history`.
    gradient_history: VecDeque<f32>,
    /// Most recent confidences, pushed and evicted together with `loss_history`.
    confidence_history: VecDeque<f32>,
    /// Exponential moving average of the loss.
    running_loss_avg: f32,
    /// Exponential moving average of the gradient norm.
    running_gradient_avg: f32,
    /// Exponential moving average of the confidence.
    running_confidence_avg: f32,
    /// Number of `update` calls since construction or the last `reset`.
    sample_count: usize,
    /// Value of `sample_count` when the last boundary was detected or forced.
    last_boundary_sample: usize,
}
74
75impl TaskBoundaryDetector {
76 pub fn new(config: BoundaryDetectionConfig) -> Self {
77 Self {
78 config,
79 loss_history: VecDeque::new(),
80 gradient_history: VecDeque::new(),
81 confidence_history: VecDeque::new(),
82 running_loss_avg: 0.0,
83 running_gradient_avg: 0.0,
84 running_confidence_avg: 0.0,
85 sample_count: 0,
86 last_boundary_sample: 0,
87 }
88 }
89
90 pub fn update(&mut self, loss: f32, gradient_norm: f32, confidence: f32) {
92 self.sample_count += 1;
93
94 let alpha = self.config.smoothing_factor;
96 self.running_loss_avg = alpha * loss + (1.0 - alpha) * self.running_loss_avg;
97 self.running_gradient_avg =
98 alpha * gradient_norm + (1.0 - alpha) * self.running_gradient_avg;
99 self.running_confidence_avg =
100 alpha * confidence + (1.0 - alpha) * self.running_confidence_avg;
101
102 self.loss_history.push_back(loss);
104 self.gradient_history.push_back(gradient_norm);
105 self.confidence_history.push_back(confidence);
106
107 if self.loss_history.len() > self.config.window_size {
109 self.loss_history.pop_front();
110 self.gradient_history.pop_front();
111 self.confidence_history.pop_front();
112 }
113 }
114
115 pub fn detect_boundary(&mut self) -> Result<Option<f32>> {
117 if self.sample_count < self.config.min_samples {
118 return Ok(None);
119 }
120
121 if self.sample_count - self.last_boundary_sample < self.config.min_samples {
122 return Ok(None);
123 }
124
125 let boundary_score = match self.config.detection_method {
126 DetectionMethod::LossIncrease => self.detect_loss_increase()?,
127 DetectionMethod::GradientMagnitude => self.detect_gradient_change()?,
128 DetectionMethod::ActivationPattern => self.detect_activation_change()?,
129 DetectionMethod::ConfidenceChange => self.detect_confidence_change()?,
130 DetectionMethod::Combined => self.detect_combined()?,
131 };
132
133 if boundary_score > self.config.threshold {
134 self.last_boundary_sample = self.sample_count;
135 Ok(Some(boundary_score))
136 } else {
137 Ok(None)
138 }
139 }
140
141 fn detect_loss_increase(&self) -> Result<f32> {
143 if self.loss_history.len() < self.config.window_size / 2 {
144 return Ok(0.0);
145 }
146
147 let mid_point = self.loss_history.len() / 2;
148 let recent_avg: f32 = self.loss_history.iter().skip(mid_point).sum::<f32>()
149 / (self.loss_history.len() - mid_point) as f32;
150 let old_avg: f32 = self.loss_history.iter().take(mid_point).sum::<f32>() / mid_point as f32;
151
152 let relative_increase = (recent_avg - old_avg) / old_avg.max(1e-8);
153 Ok(relative_increase.max(0.0))
154 }
155
156 fn detect_gradient_change(&self) -> Result<f32> {
158 if self.gradient_history.len() < self.config.window_size / 2 {
159 return Ok(0.0);
160 }
161
162 let mid_point = self.gradient_history.len() / 2;
163 let recent_avg: f32 = self.gradient_history.iter().skip(mid_point).sum::<f32>()
164 / (self.gradient_history.len() - mid_point) as f32;
165 let old_avg: f32 =
166 self.gradient_history.iter().take(mid_point).sum::<f32>() / mid_point as f32;
167
168 let relative_change = (recent_avg - old_avg).abs() / old_avg.max(1e-8);
169 Ok(relative_change)
170 }
171
172 fn detect_activation_change(&self) -> Result<f32> {
174 Ok(0.0)
176 }
177
178 fn detect_confidence_change(&self) -> Result<f32> {
180 if self.confidence_history.len() < self.config.window_size / 2 {
181 return Ok(0.0);
182 }
183
184 let mid_point = self.confidence_history.len() / 2;
185 let recent_avg: f32 = self.confidence_history.iter().skip(mid_point).sum::<f32>()
186 / (self.confidence_history.len() - mid_point) as f32;
187 let old_avg: f32 =
188 self.confidence_history.iter().take(mid_point).sum::<f32>() / mid_point as f32;
189
190 let confidence_drop = (old_avg - recent_avg) / old_avg.max(1e-8);
191 Ok(confidence_drop.max(0.0))
192 }
193
194 fn detect_combined(&self) -> Result<f32> {
196 let loss_score = self.detect_loss_increase()?;
197 let gradient_score = self.detect_gradient_change()?;
198 let confidence_score = self.detect_confidence_change()?;
199
200 let combined_score = 0.4 * loss_score + 0.3 * gradient_score + 0.3 * confidence_score;
202 Ok(combined_score)
203 }
204
205 pub fn get_statistics(&self) -> DetectorStats {
207 DetectorStats {
208 sample_count: self.sample_count,
209 running_loss_avg: self.running_loss_avg,
210 running_gradient_avg: self.running_gradient_avg,
211 running_confidence_avg: self.running_confidence_avg,
212 window_size: self.loss_history.len(),
213 last_boundary_sample: self.last_boundary_sample,
214 }
215 }
216
217 pub fn reset(&mut self) {
219 self.loss_history.clear();
220 self.gradient_history.clear();
221 self.confidence_history.clear();
222 self.running_loss_avg = 0.0;
223 self.running_gradient_avg = 0.0;
224 self.running_confidence_avg = 0.0;
225 self.sample_count = 0;
226 self.last_boundary_sample = 0;
227 }
228
229 pub fn force_boundary(&mut self) -> TaskTransition {
231 self.last_boundary_sample = self.sample_count;
232 TaskTransition {
233 from_task: "unknown".to_string(),
234 to_task: "unknown".to_string(),
235 timestamp: chrono::Utc::now(),
236 boundary_score: 1.0,
237 }
238 }
239}
240
/// Snapshot of a [`TaskBoundaryDetector`]'s state, as returned by `get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectorStats {
    /// Number of samples recorded since construction or the last reset.
    pub sample_count: usize,
    /// Exponential moving average of the loss.
    pub running_loss_avg: f32,
    /// Exponential moving average of the gradient norm.
    pub running_gradient_avg: f32,
    /// Exponential moving average of the confidence.
    pub running_confidence_avg: f32,
    /// Number of samples currently held in the sliding window. This can be
    /// below the configured `window_size` while the window is filling.
    pub window_size: usize,
    /// Sample count at which the last boundary was detected or forced.
    pub last_boundary_sample: usize,
}
251
/// Threshold that adapts to labelled feedback.
///
/// It is scaled up after a false positive and down after a false negative, and
/// clamped to `[0.01, 1.0]` after every update.
#[derive(Debug)]
pub struct AdaptiveThreshold {
    /// Current threshold value.
    base_threshold: f32,
    /// Fractional step per adjustment; e.g. 0.1 scales by 1.1 or 0.9.
    adaptation_rate: f32,
    /// Feedback events flagged as false positives.
    false_positive_count: usize,
    /// Feedback events flagged as false negatives (and not as false positives).
    false_negative_count: usize,
    /// Total number of `update_threshold` calls, whatever their flags.
    total_boundaries: usize,
}

impl AdaptiveThreshold {
    /// Creates a threshold starting at `base_threshold` with no recorded feedback.
    ///
    /// The starting value is not clamped until the first `update_threshold` call.
    pub fn new(base_threshold: f32, adaptation_rate: f32) -> Self {
        let (false_positive_count, false_negative_count, total_boundaries) = (0, 0, 0);
        Self {
            base_threshold,
            adaptation_rate,
            false_positive_count,
            false_negative_count,
            total_boundaries,
        }
    }

    /// Applies one piece of feedback and returns the updated threshold.
    ///
    /// - A false positive scales the threshold by `1 + adaptation_rate`.
    /// - Otherwise, a false negative scales it by `1 - adaptation_rate`.
    /// - If both flags are set, only the false positive is applied.
    ///
    /// Every call counts towards the total, and the result is clamped to
    /// `[0.01, 1.0]`.
    pub fn update_threshold(&mut self, is_false_positive: bool, is_false_negative: bool) -> f32 {
        let scale = match (is_false_positive, is_false_negative) {
            (true, _) => {
                self.false_positive_count += 1;
                Some(1.0 + self.adaptation_rate)
            }
            (false, true) => {
                self.false_negative_count += 1;
                Some(1.0 - self.adaptation_rate)
            }
            (false, false) => None,
        };

        if let Some(scale) = scale {
            self.base_threshold *= scale;
        }

        self.total_boundaries += 1;
        self.base_threshold = self.base_threshold.clamp(0.01, 1.0);
        self.base_threshold
    }

    /// Current threshold value.
    pub fn get_threshold(&self) -> f32 {
        self.base_threshold
    }

    /// Returns `(false_positive_rate, false_negative_rate, total_updates)`.
    ///
    /// Both rates are 0.0 before the first update.
    pub fn get_stats(&self) -> (f32, f32, usize) {
        let denominator = self.total_boundaries.max(1) as f32;
        let rate = |count: usize| count as f32 / denominator;
        (
            rate(self.false_positive_count),
            rate(self.false_negative_count),
            self.total_boundaries,
        )
    }
}
306
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// A freshly built detector has seen nothing and holds an empty window.
    #[test]
    fn test_boundary_detector_creation() {
        let stats = TaskBoundaryDetector::new(BoundaryDetectionConfig::default()).get_statistics();

        assert_eq!(stats.sample_count, 0);
        assert_eq!(stats.window_size, 0);
    }

    /// Each update is counted, fills the window and moves the running loss.
    #[test]
    fn test_boundary_detector_update() {
        let mut detector = TaskBoundaryDetector::new(BoundaryDetectionConfig::default());

        (0..10)
            .map(|step| step as f32)
            .for_each(|value| detector.update(value, value * 0.1, 0.9));

        let stats = detector.get_statistics();
        assert_eq!(stats.sample_count, 10);
        assert_eq!(stats.window_size, 10);
        assert!(stats.running_loss_avg > 0.0);
    }

    /// A step change in loss from 1.0 to 2.0 is reported as a boundary.
    #[test]
    fn test_loss_increase_detection() {
        let mut detector = TaskBoundaryDetector::new(BoundaryDetectionConfig {
            window_size: 20,
            threshold: 0.1,
            detection_method: DetectionMethod::LossIncrease,
            min_samples: 10,
            ..Default::default()
        });

        let losses = std::iter::repeat(1.0_f32)
            .take(15)
            .chain(std::iter::repeat(2.0_f32).take(15));
        for loss in losses {
            detector.update(loss, 0.1, 0.9);
        }

        let score = detector
            .detect_boundary()
            .expect("operation failed in test")
            .expect("operation failed in test");
        assert!(score > 0.1);
    }

    /// False positives raise the threshold; false negatives lower it.
    #[test]
    fn test_adaptive_threshold() {
        let mut threshold = AdaptiveThreshold::new(0.5, 0.1);

        let initial = threshold.get_threshold();
        assert_abs_diff_eq!(initial, 0.5, epsilon = 1e-6);

        let raised = threshold.update_threshold(true, false);
        assert!(raised > initial);

        let lowered = threshold.update_threshold(false, true);
        assert!(lowered < raised);
    }

    /// Simultaneous shifts in all three signals trip the combined detector.
    #[test]
    fn test_combined_detection() {
        let mut detector = TaskBoundaryDetector::new(BoundaryDetectionConfig {
            window_size: 10,
            threshold: 0.1,
            detection_method: DetectionMethod::Combined,
            min_samples: 5,
            ..Default::default()
        });

        for step in 0..15 {
            let (loss, gradient, confidence) = if step < 10 {
                (1.0, 0.1, 0.9)
            } else {
                (1.5, 0.15, 0.7)
            };
            detector.update(loss, gradient, confidence);
        }

        let boundary = detector.detect_boundary().expect("operation failed in test");
        assert!(boundary.is_some());
    }

    /// `reset` discards all recorded samples.
    #[test]
    fn test_detector_reset() {
        let mut detector = TaskBoundaryDetector::new(BoundaryDetectionConfig::default());

        for loss in (0..10).map(|step| step as f32) {
            detector.update(loss, 0.1, 0.9);
        }
        assert_eq!(detector.get_statistics().sample_count, 10);

        detector.reset();

        let stats = detector.get_statistics();
        assert_eq!(stats.sample_count, 0);
        assert_eq!(stats.window_size, 0);
    }
}