// trustformers_training/continual/task_boundary.rs

use anyhow::Result;
// SciRS2 Integration Policy
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

6/// Configuration for task boundary detection
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct BoundaryDetectionConfig {
9    /// Window size for change detection
10    pub window_size: usize,
11    /// Threshold for boundary detection
12    pub threshold: f32,
13    /// Method for boundary detection
14    pub detection_method: DetectionMethod,
15    /// Minimum samples before detecting boundary
16    pub min_samples: usize,
17    /// Smoothing factor for running averages
18    pub smoothing_factor: f32,
19}
20
21impl Default for BoundaryDetectionConfig {
22    fn default() -> Self {
23        Self {
24            window_size: 100,
25            threshold: 0.05,
26            detection_method: DetectionMethod::LossIncrease,
27            min_samples: 50,
28            smoothing_factor: 0.1,
29        }
30    }
31}
32
33/// Methods for detecting task boundaries
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub enum DetectionMethod {
36    /// Detect based on loss increase
37    LossIncrease,
38    /// Detect based on gradient magnitude change
39    GradientMagnitude,
40    /// Detect based on activation pattern change
41    ActivationPattern,
42    /// Detect based on prediction confidence
43    ConfidenceChange,
44    /// Combined detection using multiple signals
45    Combined,
46}
47
48/// Task transition information
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct TaskTransition {
51    /// Previous task ID
52    pub from_task: String,
53    /// New task ID
54    pub to_task: String,
55    /// Timestamp of transition
56    pub timestamp: chrono::DateTime<chrono::Utc>,
57    /// Boundary detection confidence score
58    pub boundary_score: f32,
59}
60
61/// Task boundary detector
62#[derive(Debug)]
63pub struct TaskBoundaryDetector {
64    config: BoundaryDetectionConfig,
65    loss_history: VecDeque<f32>,
66    gradient_history: VecDeque<f32>,
67    confidence_history: VecDeque<f32>,
68    running_loss_avg: f32,
69    running_gradient_avg: f32,
70    running_confidence_avg: f32,
71    sample_count: usize,
72    last_boundary_sample: usize,
73}
74
75impl TaskBoundaryDetector {
76    pub fn new(config: BoundaryDetectionConfig) -> Self {
77        Self {
78            config,
79            loss_history: VecDeque::new(),
80            gradient_history: VecDeque::new(),
81            confidence_history: VecDeque::new(),
82            running_loss_avg: 0.0,
83            running_gradient_avg: 0.0,
84            running_confidence_avg: 0.0,
85            sample_count: 0,
86            last_boundary_sample: 0,
87        }
88    }
89
90    /// Update detector with new training sample
91    pub fn update(&mut self, loss: f32, gradient_norm: f32, confidence: f32) {
92        self.sample_count += 1;
93
94        // Update running averages
95        let alpha = self.config.smoothing_factor;
96        self.running_loss_avg = alpha * loss + (1.0 - alpha) * self.running_loss_avg;
97        self.running_gradient_avg =
98            alpha * gradient_norm + (1.0 - alpha) * self.running_gradient_avg;
99        self.running_confidence_avg =
100            alpha * confidence + (1.0 - alpha) * self.running_confidence_avg;
101
102        // Update history windows
103        self.loss_history.push_back(loss);
104        self.gradient_history.push_back(gradient_norm);
105        self.confidence_history.push_back(confidence);
106
107        // Maintain window size
108        if self.loss_history.len() > self.config.window_size {
109            self.loss_history.pop_front();
110            self.gradient_history.pop_front();
111            self.confidence_history.pop_front();
112        }
113    }
114
115    /// Check if a task boundary is detected
116    pub fn detect_boundary(&mut self) -> Result<Option<f32>> {
117        if self.sample_count < self.config.min_samples {
118            return Ok(None);
119        }
120
121        if self.sample_count - self.last_boundary_sample < self.config.min_samples {
122            return Ok(None);
123        }
124
125        let boundary_score = match self.config.detection_method {
126            DetectionMethod::LossIncrease => self.detect_loss_increase()?,
127            DetectionMethod::GradientMagnitude => self.detect_gradient_change()?,
128            DetectionMethod::ActivationPattern => self.detect_activation_change()?,
129            DetectionMethod::ConfidenceChange => self.detect_confidence_change()?,
130            DetectionMethod::Combined => self.detect_combined()?,
131        };
132
133        if boundary_score > self.config.threshold {
134            self.last_boundary_sample = self.sample_count;
135            Ok(Some(boundary_score))
136        } else {
137            Ok(None)
138        }
139    }
140
141    /// Detect boundary based on loss increase
142    fn detect_loss_increase(&self) -> Result<f32> {
143        if self.loss_history.len() < self.config.window_size / 2 {
144            return Ok(0.0);
145        }
146
147        let mid_point = self.loss_history.len() / 2;
148        let recent_avg: f32 = self.loss_history.iter().skip(mid_point).sum::<f32>()
149            / (self.loss_history.len() - mid_point) as f32;
150        let old_avg: f32 = self.loss_history.iter().take(mid_point).sum::<f32>() / mid_point as f32;
151
152        let relative_increase = (recent_avg - old_avg) / old_avg.max(1e-8);
153        Ok(relative_increase.max(0.0))
154    }
155
156    /// Detect boundary based on gradient magnitude change
157    fn detect_gradient_change(&self) -> Result<f32> {
158        if self.gradient_history.len() < self.config.window_size / 2 {
159            return Ok(0.0);
160        }
161
162        let mid_point = self.gradient_history.len() / 2;
163        let recent_avg: f32 = self.gradient_history.iter().skip(mid_point).sum::<f32>()
164            / (self.gradient_history.len() - mid_point) as f32;
165        let old_avg: f32 =
166            self.gradient_history.iter().take(mid_point).sum::<f32>() / mid_point as f32;
167
168        let relative_change = (recent_avg - old_avg).abs() / old_avg.max(1e-8);
169        Ok(relative_change)
170    }
171
172    /// Detect boundary based on activation pattern change
173    fn detect_activation_change(&self) -> Result<f32> {
174        // Placeholder - in practice would analyze activation patterns
175        Ok(0.0)
176    }
177
178    /// Detect boundary based on confidence change
179    fn detect_confidence_change(&self) -> Result<f32> {
180        if self.confidence_history.len() < self.config.window_size / 2 {
181            return Ok(0.0);
182        }
183
184        let mid_point = self.confidence_history.len() / 2;
185        let recent_avg: f32 = self.confidence_history.iter().skip(mid_point).sum::<f32>()
186            / (self.confidence_history.len() - mid_point) as f32;
187        let old_avg: f32 =
188            self.confidence_history.iter().take(mid_point).sum::<f32>() / mid_point as f32;
189
190        let confidence_drop = (old_avg - recent_avg) / old_avg.max(1e-8);
191        Ok(confidence_drop.max(0.0))
192    }
193
194    /// Combined detection using multiple signals
195    fn detect_combined(&self) -> Result<f32> {
196        let loss_score = self.detect_loss_increase()?;
197        let gradient_score = self.detect_gradient_change()?;
198        let confidence_score = self.detect_confidence_change()?;
199
200        // Weighted combination
201        let combined_score = 0.4 * loss_score + 0.3 * gradient_score + 0.3 * confidence_score;
202        Ok(combined_score)
203    }
204
205    /// Get detector statistics
206    pub fn get_statistics(&self) -> DetectorStats {
207        DetectorStats {
208            sample_count: self.sample_count,
209            running_loss_avg: self.running_loss_avg,
210            running_gradient_avg: self.running_gradient_avg,
211            running_confidence_avg: self.running_confidence_avg,
212            window_size: self.loss_history.len(),
213            last_boundary_sample: self.last_boundary_sample,
214        }
215    }
216
217    /// Reset detector state
218    pub fn reset(&mut self) {
219        self.loss_history.clear();
220        self.gradient_history.clear();
221        self.confidence_history.clear();
222        self.running_loss_avg = 0.0;
223        self.running_gradient_avg = 0.0;
224        self.running_confidence_avg = 0.0;
225        self.sample_count = 0;
226        self.last_boundary_sample = 0;
227    }
228
229    /// Force boundary detection
230    pub fn force_boundary(&mut self) -> TaskTransition {
231        self.last_boundary_sample = self.sample_count;
232        TaskTransition {
233            from_task: "unknown".to_string(),
234            to_task: "unknown".to_string(),
235            timestamp: chrono::Utc::now(),
236            boundary_score: 1.0,
237        }
238    }
239}
240
241/// Detector statistics
242#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct DetectorStats {
244    pub sample_count: usize,
245    pub running_loss_avg: f32,
246    pub running_gradient_avg: f32,
247    pub running_confidence_avg: f32,
248    pub window_size: usize,
249    pub last_boundary_sample: usize,
250}
251
/// Detection threshold that adapts to feedback about detection mistakes.
///
/// A false positive multiplies the threshold by `1 + adaptation_rate`, and a false
/// negative multiplies it by `1 - adaptation_rate`. Every update clamps the result to
/// `[0.01, 1.0]`.
#[derive(Debug)]
pub struct AdaptiveThreshold {
    /// Current threshold value.
    base_threshold: f32,
    /// Relative step applied per reported mistake.
    adaptation_rate: f32,
    /// Number of updates flagged as false positives.
    false_positive_count: usize,
    /// Number of updates flagged as false negatives (and not also as false positives).
    false_negative_count: usize,
    /// Number of feedback updates received, including ones that reported no mistake.
    total_boundaries: usize,
}

impl AdaptiveThreshold {
    /// Creates a threshold that starts at `base_threshold` and moves by the relative step
    /// `adaptation_rate` per reported mistake.
    ///
    /// The starting value is not clamped; clamping happens on the first update.
    pub fn new(base_threshold: f32, adaptation_rate: f32) -> Self {
        Self {
            false_positive_count: 0,
            false_negative_count: 0,
            total_boundaries: 0,
            base_threshold,
            adaptation_rate,
        }
    }

    /// Applies one round of feedback and returns the new, clamped threshold.
    ///
    /// If both flags are set, only the false positive is counted and applied. An update
    /// with neither flag set leaves the value unchanged apart from clamping, but still
    /// counts toward the total.
    pub fn update_threshold(&mut self, is_false_positive: bool, is_false_negative: bool) -> f32 {
        const MIN_THRESHOLD: f32 = 0.01;
        const MAX_THRESHOLD: f32 = 1.0;

        let scale = match (is_false_positive, is_false_negative) {
            (true, _) => {
                self.false_positive_count += 1;
                1.0 + self.adaptation_rate
            }
            (false, true) => {
                self.false_negative_count += 1;
                1.0 - self.adaptation_rate
            }
            (false, false) => 1.0,
        };

        self.total_boundaries += 1;
        self.base_threshold = (self.base_threshold * scale).clamp(MIN_THRESHOLD, MAX_THRESHOLD);
        self.base_threshold
    }

    /// Returns the current threshold.
    pub fn get_threshold(&self) -> f32 {
        self.base_threshold
    }

    /// Returns `(false_positive_rate, false_negative_rate, total_updates)`.
    ///
    /// Both rates are fractions of all updates received, and both are 0.0 before the
    /// first update.
    pub fn get_stats(&self) -> (f32, f32, usize) {
        let updates = self.total_boundaries.max(1) as f32;
        let fp_rate = self.false_positive_count as f32 / updates;
        let fn_rate = self.false_negative_count as f32 / updates;
        (fp_rate, fn_rate, self.total_boundaries)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a detector from `config` and feeds it every `(loss, gradient, confidence)`
    /// triple from `samples`, in order.
    fn feed(
        config: BoundaryDetectionConfig,
        samples: impl IntoIterator<Item = (f32, f32, f32)>,
    ) -> TaskBoundaryDetector {
        let mut detector = TaskBoundaryDetector::new(config);
        for (loss, gradient, confidence) in samples {
            detector.update(loss, gradient, confidence);
        }
        detector
    }

    #[test]
    fn test_boundary_detector_creation() {
        let stats = TaskBoundaryDetector::new(BoundaryDetectionConfig::default()).get_statistics();

        assert_eq!(stats.sample_count, 0);
        assert_eq!(stats.window_size, 0);
    }

    #[test]
    fn test_boundary_detector_update() {
        let samples = (0..10).map(|i| {
            let x = i as f32;
            (x, x * 0.1, 0.9)
        });
        let stats = feed(BoundaryDetectionConfig::default(), samples).get_statistics();

        assert_eq!(stats.sample_count, 10);
        assert_eq!(stats.window_size, 10);
        assert!(stats.running_loss_avg > 0.0);
    }

    #[test]
    fn test_loss_increase_detection() {
        let config = BoundaryDetectionConfig {
            window_size: 20,
            threshold: 0.1,
            detection_method: DetectionMethod::LossIncrease,
            min_samples: 10,
            ..Default::default()
        };
        // Fifteen steps at a stable loss, then fifteen at double that loss.
        let stable = std::iter::repeat((1.0_f32, 0.1_f32, 0.9_f32)).take(15);
        let shifted = std::iter::repeat((2.0_f32, 0.1_f32, 0.9_f32)).take(15);
        let mut detector = feed(config, stable.chain(shifted));

        let score = detector.detect_boundary().expect("operation failed in test");
        assert!(score.is_some());
        assert!(score.expect("operation failed in test") > 0.1);
    }

    #[test]
    fn test_adaptive_threshold() {
        let mut adaptive = AdaptiveThreshold::new(0.5, 0.1);

        let start = adaptive.get_threshold();
        assert_abs_diff_eq!(start, 0.5, epsilon = 1e-6);

        // A false positive should raise the threshold...
        let raised = adaptive.update_threshold(true, false);
        assert!(raised > start);

        // ...and a false negative should lower it again.
        let lowered = adaptive.update_threshold(false, true);
        assert!(lowered < raised);
    }

    #[test]
    fn test_combined_detection() {
        let config = BoundaryDetectionConfig {
            window_size: 10,
            threshold: 0.1,
            detection_method: DetectionMethod::Combined,
            min_samples: 5,
            ..Default::default()
        };
        // All three signals shift at step 10, so the window straddles the change.
        let samples = (0..15).map(|i| -> (f32, f32, f32) {
            if i < 10 {
                (1.0, 0.1, 0.9)
            } else {
                (1.5, 0.15, 0.7)
            }
        });
        let mut detector = feed(config, samples);

        let boundary = detector.detect_boundary().expect("operation failed in test");
        assert!(boundary.is_some());
    }

    #[test]
    fn test_detector_reset() {
        let samples = (0..10).map(|i| (i as f32, 0.1, 0.9));
        let mut detector = feed(BoundaryDetectionConfig::default(), samples);
        assert_eq!(detector.get_statistics().sample_count, 10);

        detector.reset();
        let stats = detector.get_statistics();
        assert_eq!(stats.sample_count, 0);
        assert_eq!(stats.window_size, 0);
    }
}