trustformers_debug/streaming_debugger/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::types::*;
6use anyhow::Result;
7use std::sync::Arc;
8use tracing::info;
9/// Integration with main debug session
10impl crate::DebugSession {
11    /// Enable streaming for this debug session
12    pub async fn enable_streaming(
13        &mut self,
14        config: StreamingDebugConfig,
15    ) -> Result<Arc<StreamingDebugger>> {
16        let streaming_debugger = Arc::new(StreamingDebugger::new(config));
17        streaming_debugger.start().await?;
18        info!("Enabled streaming for debug session {}", self.id());
19        Ok(streaming_debugger)
20    }
21}
/// Convenience macros for streaming debugging
///
/// `stream_tensor!` snapshots a tensor and forwards it to the streamer.
/// `$tensor` must expose `shape()` (convertible via `to_vec()`) and `iter()`
/// yielding elements convertible `.into()` `f64`. The expansion ends in
/// `.await`, so this macro can only be used inside an async context; it
/// evaluates to the result of `send_tensor_data`.
#[macro_export]
macro_rules! stream_tensor {
    ($streamer:expr, $session_id:expr, $tensor:expr, $name:expr) => {{
        // A fresh UUID per invocation, so repeated snapshots of the same
        // tensor are distinguishable by id.
        let tensor_id = uuid::Uuid::new_v4();
        let shape = $tensor.shape().to_vec();
        // Widen every element to f64 for a uniform payload type.
        let values: Vec<f64> = $tensor.iter().map(|&x| x.into()).collect();
        $streamer
            .send_tensor_data($session_id, tensor_id, $name.to_string(), shape, values)
            .await
    }};
}
/// Send a layer's gradient values through the streamer.
///
/// `$gradients` must expose `iter()` yielding elements convertible `.into()`
/// `f64`. The expansion ends in `.await` (async context required) and
/// evaluates to the result of `send_gradient_flow`.
#[macro_export]
macro_rules! stream_gradients {
    ($streamer:expr, $session_id:expr, $layer_name:expr, $gradients:expr) => {{
        // Normalize all gradient elements to f64 before handing them off.
        let gradient_values: Vec<f64> = $gradients.iter().map(|&x| x.into()).collect();
        $streamer
            .send_gradient_flow($session_id, $layer_name.to_string(), &gradient_values)
            .await
    }};
}
/// Report a detected anomaly through the streamer.
///
/// The expansion ends in `.await` (async context required) and evaluates to
/// the result of `send_anomaly_detected`. Two forms are accepted:
///
/// - 5 arguments: confidence defaults to `0.95` and the affected-component
///   list is empty — identical to the historical behavior, so existing
///   call sites are unaffected.
/// - 7 arguments: the caller supplies `confidence` and
///   `affected_components` explicitly, instead of the previously
///   hard-coded `0.95` / `vec![]`.
#[macro_export]
macro_rules! stream_anomaly {
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr
    ) => {{
        // Delegate to the full form with the historical defaults.
        // `$crate::` keeps the recursion working from downstream crates.
        $crate::stream_anomaly!(
            $streamer,
            $session_id,
            $anomaly_type,
            $severity,
            $description,
            0.95,
            vec![]
        )
    }};
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr, $confidence:expr, $affected_components:expr
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                $confidence,
                $affected_components,
            )
            .await
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    use uuid::Uuid;

    /// A freshly built debugger starts out in the stopped state.
    #[tokio::test]
    async fn test_streaming_debugger_creation() {
        let dbg = StreamingDebugger::new(StreamingDebugConfig::default());
        assert!(!*dbg.is_running.read().await);
    }

    /// Start/stop round-trip flips the running flag both ways. The body is
    /// wrapped in a timeout so a wedged background task fails fast instead
    /// of hanging the suite.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_start_stop_streaming() {
        let dbg = StreamingDebugger::new(StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        });
        let outcome = tokio::time::timeout(Duration::from_secs(3), async {
            assert!(dbg.start().await.is_ok());
            assert!(*dbg.is_running.read().await);
            tokio::time::sleep(Duration::from_millis(50)).await;
            assert!(dbg.stop().await.is_ok());
            assert!(!*dbg.is_running.read().await);
            // Give spawned tasks a moment to observe the stop.
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(outcome.is_ok(), "Test timed out");
        assert!(outcome.unwrap().is_ok());
    }

    /// Subscribing registers exactly one subscriber; unsubscribing by id
    /// removes it again.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_subscription() {
        let dbg = StreamingDebugger::new(StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        });
        let outcome = tokio::time::timeout(Duration::from_secs(3), async {
            dbg.start().await.unwrap();
            let sub = dbg
                .subscribe(
                    "test_subscriber".to_string(),
                    StreamFormat::Json,
                    StreamFilter::default(),
                )
                .await
                .unwrap();
            assert_eq!(dbg.get_subscribers().await.len(), 1);
            dbg.unsubscribe(sub.subscriber_id()).await.unwrap();
            assert_eq!(dbg.get_subscribers().await.len(), 0);
            dbg.stop().await.unwrap();
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(outcome.is_ok(), "Test timed out");
        assert!(outcome.unwrap().is_ok());
    }

    /// Descriptive statistics over a small known sample.
    #[tokio::test]
    async fn test_tensor_statistics() {
        let dbg = StreamingDebugger::new(StreamingDebugConfig::default());
        let sample = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = dbg.compute_tensor_statistics(&sample);
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.zero_count, 0);
    }

    /// Gradient norms are positive and the extrema match the raw inputs.
    #[tokio::test]
    async fn test_gradient_statistics() {
        let dbg = StreamingDebugger::new(StreamingDebugConfig::default());
        let grads = vec![0.1, -0.2, 0.3, -0.1, 0.0];
        let stats = dbg.compute_gradient_statistics(&grads);
        assert!(stats.l1_norm > 0.0);
        assert!(stats.l2_norm > 0.0);
        assert_eq!(stats.max_grad, 0.3);
        assert_eq!(stats.min_grad, -0.2);
    }

    /// A session-scoped filter accepts events from the listed session and
    /// rejects otherwise-identical events from a different session.
    #[tokio::test]
    async fn test_event_filtering() {
        let session_id1 = Uuid::new_v4();
        let session_id2 = Uuid::new_v4();
        let filter = StreamFilter {
            session_ids: Some(vec![session_id1]),
            event_types: Some(vec!["TensorData".to_string()]),
            min_severity: None,
            time_range: None,
            custom_filters: HashMap::new(),
        };
        // Builder closure: the two events differ only in their session id.
        let tensor_event = |session_id| StreamEvent::TensorData {
            session_id,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        };
        let matching = tensor_event(session_id1);
        let non_matching = tensor_event(session_id2);
        assert!(StreamSubscription::matches_filter(&matching, &filter));
        assert!(!StreamSubscription::matches_filter(&non_matching, &filter));
    }
}
/// Trait for custom aggregation rules
///
/// Implementors reduce a slice of [`StreamEvent`]s to a single `f64`
/// value, identified by a stable rule name.
pub trait AggregationRule {
    /// Collapse `events` into one aggregate value. Error conditions are
    /// implementation-defined (e.g. presumably empty input or unsupported
    /// event kinds — confirm against each implementor).
    fn aggregate(&self, events: &[StreamEvent]) -> Result<f64>;
    /// Human-readable identifier for this rule.
    fn rule_name(&self) -> &str;
}
#[cfg(test)]
mod enhanced_tests {
    use super::*;
    use std::time::{Duration, Instant, SystemTime};
    use uuid::Uuid;

    /// Enhanced streaming starts and stops cleanly within a bounded time.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_enhanced_streaming_debugger() {
        let mut dbg = EnhancedStreamingDebugger::new(
            StreamingDebugConfig {
                stream_interval_ms: 50,
                ..Default::default()
            },
            AdaptiveStreamingConfig {
                monitoring_interval_ms: 500,
                ..Default::default()
            },
            RealTimeAggregationConfig {
                window_size_seconds: 1,
                ..Default::default()
            },
            IntelligentBufferingConfig::default(),
        );
        let outcome = tokio::time::timeout(Duration::from_secs(5), async {
            assert!(dbg.start_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(dbg.stop_enhanced_streaming().await.is_ok());
            // Grace period for background tasks to drain.
            tokio::time::sleep(Duration::from_millis(200)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(outcome.is_ok(), "Test timed out");
        assert!(outcome.unwrap().is_ok());
    }

    /// Quality score stays in [0, 1] and history records at least one sample
    /// after an update.
    #[tokio::test]
    async fn test_network_condition_monitor() {
        let mut monitor = NetworkConditionMonitor::new();
        monitor.update_conditions().await;
        assert!((0.0..=1.0).contains(&monitor.quality_score));
        assert!(!monitor.history.is_empty());
    }

    /// With two history points, the size with the better throughput wins.
    #[test]
    fn test_buffer_performance_predictor() {
        // Builder closure keeps the two sample points compact.
        let sample = |buffer_size, throughput, latency, memory_usage| BufferPerformancePoint {
            buffer_size,
            throughput,
            latency,
            memory_usage,
            timestamp: Instant::now(),
        };
        let predictor = BufferPerformancePredictor {
            performance_history: vec![
                sample(500, 100.0, 50.0, 50000),
                sample(1000, 150.0, 40.0, 100000),
            ],
            model_params: vec![],
            accuracy: 0.8,
        };
        assert_eq!(predictor.predict_optimal_size().unwrap(), 1000);
    }

    /// A critical anomaly must score strictly higher than a low-severity one.
    #[tokio::test]
    async fn test_importance_scorer() {
        let scorer = ImportanceScorer::new();
        // Builder closure: events share shape, differ only in the varied fields.
        let anomaly = |anomaly_type, severity, text: &str, confidence, layer: &str| {
            StreamEvent::AnomalyDetected {
                session_id: Uuid::new_v4(),
                anomaly_type,
                severity,
                description: text.to_string(),
                confidence,
                affected_components: vec![layer.to_string()],
                timestamp: SystemTime::now(),
            }
        };
        let critical = anomaly(
            AnomalyType::GradientExplosion,
            AnomalySeverity::Critical,
            "Critical gradient explosion",
            0.95,
            "layer1",
        );
        let low = anomaly(
            AnomalyType::TrainingStagnation,
            AnomalySeverity::Low,
            "Slow convergence detected",
            0.6,
            "layer2",
        );
        let critical_score = scorer.calculate_importance(&critical).await.unwrap();
        let low_score = scorer.calculate_importance(&low).await.unwrap();
        assert!(critical_score > low_score);
    }
}