// trustformers_debug/streaming_debugger/functions.rs
//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use super::types::*;
use anyhow::Result;
use std::sync::Arc;
use tracing::info;
9/// Integration with main debug session
10impl crate::DebugSession {
11    /// Enable streaming for this debug session
12    pub async fn enable_streaming(
13        &mut self,
14        config: StreamingDebugConfig,
15    ) -> Result<Arc<StreamingDebugger>> {
16        let streaming_debugger = Arc::new(StreamingDebugger::new(config));
17        streaming_debugger.start().await?;
18        info!("Enabled streaming for debug session {}", self.id());
19        Ok(streaming_debugger)
20    }
21}
/// Convenience macros for streaming debugging
///
/// Streams a tensor snapshot through `$streamer` for `$session_id`.
/// Generates a fresh UUID for the tensor, copies its shape, converts
/// every element to `f64` via `Into`, and awaits the send.
#[macro_export]
macro_rules! stream_tensor {
    ($streamer:expr, $session_id:expr, $tensor:expr, $name:expr) => {{
        let id = uuid::Uuid::new_v4();
        let dims = $tensor.shape().to_vec();
        let data: Vec<f64> = $tensor.iter().map(|&v| v.into()).collect();
        $streamer
            .send_tensor_data($session_id, id, $name.to_string(), dims, data)
            .await
    }};
}
/// Streams a layer's gradients through `$streamer` for `$session_id`.
/// Converts each gradient to `f64` via `Into` before awaiting the send.
#[macro_export]
macro_rules! stream_gradients {
    ($streamer:expr, $session_id:expr, $layer_name:expr, $gradients:expr) => {{
        let grads: Vec<f64> = $gradients.iter().map(|&g| g.into()).collect();
        $streamer
            .send_gradient_flow($session_id, $layer_name.to_string(), &grads)
            .await
    }};
}
/// Streams an anomaly event through `$streamer` for `$session_id`.
///
/// Short form keeps the original defaults (confidence `0.95`, no affected
/// components); the long form lets callers supply both explicitly.
#[macro_export]
macro_rules! stream_anomaly {
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr
    ) => {
        // Backward-compatible short form: delegates to the long form with
        // the defaults the original implementation hard-coded.
        $crate::stream_anomaly!(
            $streamer,
            $session_id,
            $anomaly_type,
            $severity,
            $description,
            0.95,
            vec![]
        )
    };
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr, $confidence:expr, $components:expr
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                $confidence,
                $components,
            )
            .await
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    use uuid::Uuid;
    /// A freshly constructed debugger must report itself as not running.
    #[tokio::test]
    async fn test_streaming_debugger_creation() {
        let config = StreamingDebugConfig::default();
        let debugger = StreamingDebugger::new(config);
        assert!(!*debugger.is_running.read().await);
    }
    /// Starting and then stopping the debugger toggles `is_running`
    /// accordingly; the whole sequence is bounded by a 3s timeout so a
    /// hung start/stop fails the test instead of wedging the runner.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore] // TODO: Fix timeout issue in streaming debugger
    async fn test_start_stop_streaming() {
        let config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let debugger = StreamingDebugger::new(config);
        let test_result = tokio::time::timeout(Duration::from_secs(3), async {
            assert!(debugger.start().await.is_ok());
            assert!(*debugger.is_running.read().await);
            // Let the debugger run briefly before stopping it.
            tokio::time::sleep(Duration::from_millis(50)).await;
            assert!(debugger.stop().await.is_ok());
            assert!(!*debugger.is_running.read().await);
            // Grace period after stop — presumably lets background work
            // wind down; TODO confirm why this sleep is needed.
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.unwrap().is_ok());
    }
    /// Subscribing adds exactly one subscriber and unsubscribing removes
    /// it again; also guarded by a 3s timeout.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore] // TODO: Fix timeout issue in streaming debugger
    async fn test_subscription() {
        let config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let debugger = StreamingDebugger::new(config);
        let test_result = tokio::time::timeout(Duration::from_secs(3), async {
            debugger.start().await.unwrap();
            let subscription = debugger
                .subscribe(
                    "test_subscriber".to_string(),
                    StreamFormat::Json,
                    StreamFilter::default(),
                )
                .await
                .unwrap();
            assert_eq!(debugger.get_subscribers().await.len(), 1);
            debugger.unsubscribe(subscription.subscriber_id()).await.unwrap();
            assert_eq!(debugger.get_subscribers().await.len(), 0);
            debugger.stop().await.unwrap();
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.unwrap().is_ok());
    }
    /// Sanity-checks the tensor statistics for a small known vector:
    /// mean of 1..=5 is 3.0, min/max are the endpoints, no zeros.
    #[tokio::test]
    async fn test_tensor_statistics() {
        let config = StreamingDebugConfig::default();
        let debugger = StreamingDebugger::new(config);
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = debugger.compute_tensor_statistics(&values);
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.zero_count, 0);
    }
    /// Gradient statistics on a mixed-sign vector: norms are positive,
    /// and max/min pick out the extreme gradients.
    #[tokio::test]
    async fn test_gradient_statistics() {
        let config = StreamingDebugConfig::default();
        let debugger = StreamingDebugger::new(config);
        let gradients = vec![0.1, -0.2, 0.3, -0.1, 0.0];
        let stats = debugger.compute_gradient_statistics(&gradients);
        assert!(stats.l1_norm > 0.0);
        assert!(stats.l2_norm > 0.0);
        assert_eq!(stats.max_grad, 0.3);
        assert_eq!(stats.min_grad, -0.2);
    }
    /// A filter restricted to one session id and the "TensorData" event
    /// type must accept an event from that session and reject an
    /// otherwise-identical event from a different session.
    #[tokio::test]
    async fn test_event_filtering() {
        let session_id1 = Uuid::new_v4();
        let session_id2 = Uuid::new_v4();
        let filter = StreamFilter {
            session_ids: Some(vec![session_id1]),
            event_types: Some(vec!["TensorData".to_string()]),
            min_severity: None,
            time_range: None,
            custom_filters: HashMap::new(),
        };
        // Event from the allowed session — should pass the filter.
        let matching_event = StreamEvent::TensorData {
            session_id: session_id1,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        };
        // Same payload but a different session id — should be rejected.
        let non_matching_event = StreamEvent::TensorData {
            session_id: session_id2,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        };
        assert!(StreamSubscription::matches_filter(&matching_event, &filter));
        assert!(!StreamSubscription::matches_filter(
            &non_matching_event,
            &filter
        ));
    }
}
/// Trait for custom aggregation rules
pub trait AggregationRule {
    /// Reduces a batch of stream events to a single aggregate value.
    ///
    /// # Errors
    ///
    /// Implementations may fail; the error contract is implementation-defined.
    fn aggregate(&self, events: &[StreamEvent]) -> Result<f64>;
    /// Returns the identifier for this rule.
    fn rule_name(&self) -> &str;
}
#[cfg(test)]
mod enhanced_tests {
    use super::*;
    use std::time::{Duration, Instant, SystemTime};
    use uuid::Uuid;
    /// End-to-end start/stop smoke test for the enhanced debugger built
    /// from all four config pieces; bounded by a 5s timeout.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore] // TODO: Fix timeout issue in enhanced streaming debugger
    async fn test_enhanced_streaming_debugger() {
        let base_config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let adaptive_config = AdaptiveStreamingConfig {
            monitoring_interval_ms: 500,
            ..Default::default()
        };
        let aggregation_config = RealTimeAggregationConfig {
            window_size_seconds: 1,
            ..Default::default()
        };
        let buffering_config = IntelligentBufferingConfig::default();
        let mut debugger = EnhancedStreamingDebugger::new(
            base_config,
            adaptive_config,
            aggregation_config,
            buffering_config,
        );
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            assert!(debugger.start_enhanced_streaming().await.is_ok());
            // Let it run briefly, then stop and allow time to wind down.
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(debugger.stop_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(200)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.unwrap().is_ok());
    }
    /// After one update, the quality score is a valid value in [0, 1]
    /// and at least one sample has been recorded in the history.
    #[tokio::test]
    async fn test_network_condition_monitor() {
        let mut monitor = NetworkConditionMonitor::new();
        monitor.update_conditions().await;
        assert!(monitor.quality_score >= 0.0);
        assert!(monitor.quality_score <= 1.0);
        assert!(!monitor.history.is_empty());
    }
    /// With two history points where the larger buffer has higher
    /// throughput and lower latency, the predictor should pick the
    /// larger buffer size (1000).
    #[test]
    fn test_buffer_performance_predictor() {
        let predictor = BufferPerformancePredictor {
            performance_history: vec![
                BufferPerformancePoint {
                    buffer_size: 500,
                    throughput: 100.0,
                    latency: 50.0,
                    memory_usage: 50000,
                    timestamp: Instant::now(),
                },
                BufferPerformancePoint {
                    buffer_size: 1000,
                    throughput: 150.0,
                    latency: 40.0,
                    memory_usage: 100000,
                    timestamp: Instant::now(),
                },
            ],
            model_params: vec![],
            accuracy: 0.8,
        };
        let optimal_size = predictor.predict_optimal_size().unwrap();
        assert_eq!(optimal_size, 1000);
    }
    /// A critical, high-confidence anomaly must score strictly higher
    /// than a low-severity, low-confidence one.
    #[tokio::test]
    async fn test_importance_scorer() {
        let scorer = ImportanceScorer::new();
        let critical_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::GradientExplosion,
            severity: AnomalySeverity::Critical,
            description: "Critical gradient explosion".to_string(),
            confidence: 0.95,
            affected_components: vec!["layer1".to_string()],
            timestamp: SystemTime::now(),
        };
        let low_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::TrainingStagnation,
            severity: AnomalySeverity::Low,
            description: "Slow convergence detected".to_string(),
            confidence: 0.6,
            affected_components: vec!["layer2".to_string()],
            timestamp: SystemTime::now(),
        };
        let critical_score = scorer.calculate_importance(&critical_event).await.unwrap();
        let low_score = scorer.calculate_importance(&low_event).await.unwrap();
        assert!(critical_score > low_score);
    }
}