// trustformers_debug/streaming_debugger/functions.rs
//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
use super::types::*;

use anyhow::Result;
use std::sync::Arc;
use tracing::info;
9/// Integration with main debug session
10impl crate::DebugSession {
11    /// Enable streaming for this debug session
12    pub async fn enable_streaming(
13        &mut self,
14        config: StreamingDebugConfig,
15    ) -> Result<Arc<StreamingDebugger>> {
16        let streaming_debugger = Arc::new(StreamingDebugger::new(config));
17        streaming_debugger.start().await?;
18        info!("Enabled streaming for debug session {}", self.id());
19        Ok(streaming_debugger)
20    }
21}
/// Convenience macros for streaming debugging
///
/// Snapshot a tensor and send it to the streamer. Generates a fresh UUID
/// for the tensor, copies its shape, converts every element to `f64`,
/// and forwards everything via `send_tensor_data`. Expands to an `.await`
/// expression, so it must be used inside an async context.
#[macro_export]
macro_rules! stream_tensor {
    ($streamer:expr, $session_id:expr, $tensor:expr, $name:expr) => {{
        let id = uuid::Uuid::new_v4();
        let dims = $tensor.shape().to_vec();
        let data: Vec<f64> = $tensor.iter().map(|&v| v.into()).collect();
        $streamer
            .send_tensor_data($session_id, id, $name.to_string(), dims, data)
            .await
    }};
}
/// Send a layer's gradients to the streamer.
///
/// Converts every gradient to `f64` and forwards the slice via
/// `send_gradient_flow`. Expands to an `.await` expression, so it must be
/// used inside an async context.
#[macro_export]
macro_rules! stream_gradients {
    ($streamer:expr, $session_id:expr, $layer_name:expr, $gradients:expr) => {{
        let grads: Vec<f64> = $gradients.iter().map(|&g| g.into()).collect();
        $streamer
            .send_gradient_flow($session_id, $layer_name.to_string(), &grads)
            .await
    }};
}
/// Emit an anomaly event through the streamer.
///
/// Short form (backward compatible) uses a default confidence of `0.95`
/// and an empty list of affected components; the extended form lets the
/// caller supply both explicitly. Expands to an `.await` expression, so it
/// must be used inside an async context.
#[macro_export]
macro_rules! stream_anomaly {
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr
    ) => {{
        // Original behavior: fixed 0.95 confidence, no affected components.
        $crate::stream_anomaly!(
            $streamer,
            $session_id,
            $anomaly_type,
            $severity,
            $description,
            0.95,
            vec![]
        )
    }};
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr, $confidence:expr, $affected:expr
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                $confidence,
                $affected,
            )
            .await
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    use uuid::Uuid;
    // A freshly constructed debugger must report itself as not running.
    #[tokio::test]
    async fn test_streaming_debugger_creation() {
        let config = StreamingDebugConfig::default();
        let debugger = StreamingDebugger::new(config);
        assert!(!*debugger.is_running.read().await);
    }
    // Start/stop lifecycle: `is_running` flips true on start, false on stop.
    // The whole sequence is wrapped in a 3s timeout so a hung background
    // task fails the test instead of stalling the suite.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore] // TODO: Fix timeout issue in streaming debugger
    async fn test_start_stop_streaming() {
        let config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let debugger = StreamingDebugger::new(config);
        let test_result = tokio::time::timeout(Duration::from_secs(3), async {
            assert!(debugger.start().await.is_ok());
            assert!(*debugger.is_running.read().await);
            tokio::time::sleep(Duration::from_millis(50)).await;
            assert!(debugger.stop().await.is_ok());
            assert!(!*debugger.is_running.read().await);
            // Give background tasks a moment to wind down after stop().
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.expect("test should not time out").is_ok());
    }
    // Subscribe/unsubscribe round-trip: subscriber count goes 0 -> 1 -> 0.
    // Also guarded by a 3s timeout for the same reason as above.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore] // TODO: Fix timeout issue in streaming debugger
    async fn test_subscription() {
        let config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let debugger = StreamingDebugger::new(config);
        let test_result = tokio::time::timeout(Duration::from_secs(3), async {
            debugger.start().await.expect("start should succeed");
            let subscription = debugger
                .subscribe(
                    "test_subscriber".to_string(),
                    StreamFormat::Json,
                    StreamFilter::default(),
                )
                .await
                .expect("subscribe should succeed");
            assert_eq!(debugger.get_subscribers().await.len(), 1);
            debugger
                .unsubscribe(subscription.subscriber_id())
                .await
                .expect("unsubscribe should succeed");
            assert_eq!(debugger.get_subscribers().await.len(), 0);
            debugger.stop().await.expect("stop should succeed");
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.expect("test should not time out").is_ok());
    }
    // Sanity-checks summary statistics for a small, known value vector:
    // mean of 1..=5 is exactly 3.0, extremes are 1.0/5.0, no zeros.
    #[tokio::test]
    async fn test_tensor_statistics() {
        let config = StreamingDebugConfig::default();
        let debugger = StreamingDebugger::new(config);
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = debugger.compute_tensor_statistics(&values);
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.zero_count, 0);
    }
    // Norms of a non-zero gradient vector must be positive, and the
    // max/min gradients must match the extreme entries.
    #[tokio::test]
    async fn test_gradient_statistics() {
        let config = StreamingDebugConfig::default();
        let debugger = StreamingDebugger::new(config);
        let gradients = vec![0.1, -0.2, 0.3, -0.1, 0.0];
        let stats = debugger.compute_gradient_statistics(&gradients);
        assert!(stats.l1_norm > 0.0);
        assert!(stats.l2_norm > 0.0);
        assert_eq!(stats.max_grad, 0.3);
        assert_eq!(stats.min_grad, -0.2);
    }
    // A filter restricted to one session id must accept a TensorData event
    // from that session and reject an otherwise-identical event from a
    // different session.
    #[tokio::test]
    async fn test_event_filtering() {
        let session_id1 = Uuid::new_v4();
        let session_id2 = Uuid::new_v4();
        let filter = StreamFilter {
            session_ids: Some(vec![session_id1]),
            event_types: Some(vec!["TensorData".to_string()]),
            min_severity: None,
            time_range: None,
            custom_filters: HashMap::new(),
        };
        // Event from the allowed session: must pass the filter.
        let matching_event = StreamEvent::TensorData {
            session_id: session_id1,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        };
        // Same payload but from session_id2: must be rejected.
        let non_matching_event = StreamEvent::TensorData {
            session_id: session_id2,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        };
        assert!(StreamSubscription::matches_filter(&matching_event, &filter));
        assert!(!StreamSubscription::matches_filter(
            &non_matching_event,
            &filter
        ));
    }
}
/// Trait for custom aggregation rules
pub trait AggregationRule {
    /// Reduce a batch of stream events to a single aggregate value.
    ///
    /// # Errors
    /// Implementations may fail (behavior is implementation-defined, e.g.
    /// when `events` cannot be aggregated under this rule's semantics).
    fn aggregate(&self, events: &[StreamEvent]) -> Result<f64>;
    /// Human-readable identifier for this rule.
    fn rule_name(&self) -> &str;
}
#[cfg(test)]
mod enhanced_tests {
    use super::*;
    use std::time::{Duration, Instant, SystemTime};
    use uuid::Uuid;
    // End-to-end start/stop of the enhanced debugger built from all four
    // config pieces; guarded by a 5s timeout so a hung background task
    // fails the test instead of stalling the suite.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore] // TODO: Fix timeout issue in enhanced streaming debugger
    async fn test_enhanced_streaming_debugger() {
        let base_config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let adaptive_config = AdaptiveStreamingConfig {
            monitoring_interval_ms: 500,
            ..Default::default()
        };
        let aggregation_config = RealTimeAggregationConfig {
            window_size_seconds: 1,
            ..Default::default()
        };
        let buffering_config = IntelligentBufferingConfig::default();
        let mut debugger = EnhancedStreamingDebugger::new(
            base_config,
            adaptive_config,
            aggregation_config,
            buffering_config,
        );
        let test_result = tokio::time::timeout(Duration::from_secs(5), async {
            assert!(debugger.start_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(debugger.stop_enhanced_streaming().await.is_ok());
            // Allow background tasks to drain after stop.
            tokio::time::sleep(Duration::from_millis(200)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.expect("test should not time out").is_ok());
    }
    // After one update the quality score must be a valid value in [0, 1]
    // and at least one sample must have been recorded in the history.
    #[tokio::test]
    async fn test_network_condition_monitor() {
        let mut monitor = NetworkConditionMonitor::new();
        monitor.update_conditions().await;
        assert!(monitor.quality_score >= 0.0);
        assert!(monitor.quality_score <= 1.0);
        assert!(!monitor.history.is_empty());
    }
    // With two history points, the predictor is expected to choose the
    // 1000-element buffer (the one with the higher throughput here).
    // NOTE(review): assumes throughput dominates the scoring — confirm
    // against predict_optimal_size's implementation.
    #[test]
    fn test_buffer_performance_predictor() {
        let predictor = BufferPerformancePredictor {
            performance_history: vec![
                BufferPerformancePoint {
                    buffer_size: 500,
                    throughput: 100.0,
                    latency: 50.0,
                    memory_usage: 50000,
                    timestamp: Instant::now(),
                },
                BufferPerformancePoint {
                    buffer_size: 1000,
                    throughput: 150.0,
                    latency: 40.0,
                    memory_usage: 100000,
                    timestamp: Instant::now(),
                },
            ],
            model_params: vec![],
            accuracy: 0.8,
        };
        let optimal_size =
            predictor.predict_optimal_size().expect("predict_optimal_size should succeed");
        assert_eq!(optimal_size, 1000);
    }
    // A Critical-severity, high-confidence anomaly must score strictly
    // higher importance than a Low-severity, low-confidence one.
    #[tokio::test]
    async fn test_importance_scorer() {
        let scorer = ImportanceScorer::new();
        let critical_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::GradientExplosion,
            severity: AnomalySeverity::Critical,
            description: "Critical gradient explosion".to_string(),
            confidence: 0.95,
            affected_components: vec!["layer1".to_string()],
            timestamp: SystemTime::now(),
        };
        let low_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::TrainingStagnation,
            severity: AnomalySeverity::Low,
            description: "Slow convergence detected".to_string(),
            confidence: 0.6,
            affected_components: vec!["layer2".to_string()],
            timestamp: SystemTime::now(),
        };
        let critical_score = scorer
            .calculate_importance(&critical_event)
            .await
            .expect("calculate_importance should succeed for critical event");
        let low_score = scorer
            .calculate_importance(&low_event)
            .await
            .expect("calculate_importance should succeed for low event");
        assert!(critical_score > low_score);
    }
}