trustformers_debug/streaming_debugger/
functions.rs1use super::types::*;
6use anyhow::Result;
7use std::sync::Arc;
8use tracing::info;
9impl crate::DebugSession {
11 pub async fn enable_streaming(
13 &mut self,
14 config: StreamingDebugConfig,
15 ) -> Result<Arc<StreamingDebugger>> {
16 let streaming_debugger = Arc::new(StreamingDebugger::new(config));
17 streaming_debugger.start().await?;
18 info!("Enabled streaming for debug session {}", self.id());
19 Ok(streaming_debugger)
20 }
21}
/// Streams a tensor's shape and values through `$streamer` as tensor data.
///
/// Expands to `$streamer.send_tensor_data(..).await`, so it must be used in an
/// async context; the macro evaluates to that awaited result.
///
/// `$tensor` is evaluated exactly once and then borrowed. It must provide
/// `shape()` (whose result supports `to_vec()`) and `iter()` over `Copy`
/// elements convertible into `f64`. A fresh `uuid::Uuid::new_v4()` is used as
/// the tensor id, so the invoking crate must be able to resolve `uuid`.
#[macro_export]
macro_rules! stream_tensor {
    ($streamer:expr, $session_id:expr, $tensor:expr, $name:expr) => {{
        // Bind once: the previous expansion evaluated `$tensor` twice (once for
        // the shape, once for the values), re-running any side effects.
        let tensor = &$tensor;
        let tensor_id = uuid::Uuid::new_v4();
        let shape = tensor.shape().to_vec();
        let values: Vec<f64> = tensor.iter().map(|&x| x.into()).collect();
        $streamer
            .send_tensor_data($session_id, tensor_id, $name.to_string(), shape, values)
            .await
    }};
}
/// Streams one layer's gradients through `$streamer`.
///
/// Each element yielded by `$gradients.iter()` is copied and converted into
/// `f64`. The collected values are passed by reference to
/// `send_gradient_flow` together with `$layer_name.to_string()`. The
/// expansion ends in `.await`, so it must be used inside an async context and
/// evaluates to that call's output.
#[macro_export]
macro_rules! stream_gradients {
    ($streamer:expr, $session_id:expr, $layer_name:expr, $gradients:expr) => {{
        let grads = $gradients
            .iter()
            .map(|&g| g.into())
            .collect::<Vec<f64>>();
        $streamer
            .send_gradient_flow($session_id, $layer_name.to_string(), &grads)
            .await
    }};
}
/// Reports a detected anomaly through `$streamer`.
///
/// Two forms are accepted (a trailing comma is allowed in both):
///
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description)`
///   reports a confidence of `0.95` and no affected components, which is the
///   historical behaviour of this macro.
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description,
///   confidence, affected_components)` forwards the caller's confidence and
///   affected-component list unchanged.
///
/// `description` is converted with `to_string()`. The expansion ends in
/// `.await`, so it must be used inside an async context and evaluates to the
/// awaited result of `send_anomaly_detected`.
#[macro_export]
macro_rules! stream_anomaly {
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr $(,)?
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                0.95,
                vec![],
            )
            .await
    }};
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr, $confidence:expr, $affected_components:expr $(,)?
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                $confidence,
                $affected_components,
            )
            .await
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    use uuid::Uuid;

    /// Default config with a 50 ms stream interval, shared by the start/stop tests.
    fn fast_config() -> StreamingDebugConfig {
        StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        }
    }

    /// Builds a 2x2 `TensorData` event attributed to `session_id`, carrying
    /// fixed statistics; used by the filtering test.
    fn tensor_event(session_id: Uuid) -> StreamEvent {
        StreamEvent::TensorData {
            session_id,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn test_streaming_debugger_creation() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let running = *debugger.is_running.read().await;
        assert!(!running);
    }

    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_start_stop_streaming() {
        let debugger = StreamingDebugger::new(fast_config());
        let outcome = tokio::time::timeout(Duration::from_secs(3), async {
            assert!(debugger.start().await.is_ok());
            assert!(*debugger.is_running.read().await);
            tokio::time::sleep(Duration::from_millis(50)).await;
            assert!(debugger.stop().await.is_ok());
            assert!(!*debugger.is_running.read().await);
            // Presumably gives background tasks time to wind down before the
            // runtime is torn down.
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(outcome.is_ok(), "Test timed out");
        assert!(outcome.unwrap().is_ok());
    }

    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_subscription() {
        let debugger = StreamingDebugger::new(fast_config());
        let outcome = tokio::time::timeout(Duration::from_secs(3), async {
            debugger.start().await.unwrap();
            let subscription = debugger
                .subscribe(
                    "test_subscriber".to_string(),
                    StreamFormat::Json,
                    StreamFilter::default(),
                )
                .await
                .unwrap();
            assert_eq!(debugger.get_subscribers().await.len(), 1);
            debugger
                .unsubscribe(subscription.subscriber_id())
                .await
                .unwrap();
            assert_eq!(debugger.get_subscribers().await.len(), 0);
            debugger.stop().await.unwrap();
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(outcome.is_ok(), "Test timed out");
        assert!(outcome.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_tensor_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = debugger.compute_tensor_statistics(&samples);
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.zero_count, 0);
    }

    #[tokio::test]
    async fn test_gradient_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let grads = vec![0.1, -0.2, 0.3, -0.1, 0.0];
        let stats = debugger.compute_gradient_statistics(&grads);
        assert!(stats.l1_norm > 0.0);
        assert!(stats.l2_norm > 0.0);
        assert_eq!(stats.max_grad, 0.3);
        assert_eq!(stats.min_grad, -0.2);
    }

    #[tokio::test]
    async fn test_event_filtering() {
        let wanted_session = Uuid::new_v4();
        let other_session = Uuid::new_v4();
        let filter = StreamFilter {
            session_ids: Some(vec![wanted_session]),
            event_types: Some(vec!["TensorData".to_string()]),
            min_severity: None,
            time_range: None,
            custom_filters: HashMap::new(),
        };
        let hit = tensor_event(wanted_session);
        let miss = tensor_event(other_session);
        assert!(StreamSubscription::matches_filter(&hit, &filter));
        assert!(!StreamSubscription::matches_filter(&miss, &filter));
    }
}
201pub trait AggregationRule {
203 fn aggregate(&self, events: &[StreamEvent]) -> Result<f64>;
204 fn rule_name(&self) -> &str;
205}
#[cfg(test)]
mod enhanced_tests {
    use super::*;
    use std::time::{Duration, Instant, SystemTime};
    use uuid::Uuid;

    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_enhanced_streaming_debugger() {
        let mut debugger = EnhancedStreamingDebugger::new(
            StreamingDebugConfig {
                stream_interval_ms: 50,
                ..Default::default()
            },
            AdaptiveStreamingConfig {
                monitoring_interval_ms: 500,
                ..Default::default()
            },
            RealTimeAggregationConfig {
                window_size_seconds: 1,
                ..Default::default()
            },
            IntelligentBufferingConfig::default(),
        );
        let outcome = tokio::time::timeout(Duration::from_secs(5), async {
            assert!(debugger.start_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(debugger.stop_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(200)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await;
        assert!(outcome.is_ok(), "Test timed out");
        assert!(outcome.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_network_condition_monitor() {
        let mut monitor = NetworkConditionMonitor::new();
        monitor.update_conditions().await;
        // Quality score is expected to be a normalized value in [0, 1].
        assert!((0.0..=1.0).contains(&monitor.quality_score));
        assert!(!monitor.history.is_empty());
    }

    #[test]
    fn test_buffer_performance_predictor() {
        // Field types are inferred from `BufferPerformancePoint` itself.
        let point = |buffer_size, throughput, latency, memory_usage| BufferPerformancePoint {
            buffer_size,
            throughput,
            latency,
            memory_usage,
            timestamp: Instant::now(),
        };
        let predictor = BufferPerformancePredictor {
            performance_history: vec![
                point(500, 100.0, 50.0, 50000),
                point(1000, 150.0, 40.0, 100000),
            ],
            model_params: vec![],
            accuracy: 0.8,
        };
        assert_eq!(predictor.predict_optimal_size().unwrap(), 1000);
    }

    #[tokio::test]
    async fn test_importance_scorer() {
        let scorer = ImportanceScorer::new();
        // Builds an anomaly event for a fresh session touching one component.
        let anomaly = |anomaly_type, severity, description: &str, confidence, component: &str| {
            StreamEvent::AnomalyDetected {
                session_id: Uuid::new_v4(),
                anomaly_type,
                severity,
                description: description.to_string(),
                confidence,
                affected_components: vec![component.to_string()],
                timestamp: SystemTime::now(),
            }
        };
        let critical_event = anomaly(
            AnomalyType::GradientExplosion,
            AnomalySeverity::Critical,
            "Critical gradient explosion",
            0.95,
            "layer1",
        );
        let low_event = anomaly(
            AnomalyType::TrainingStagnation,
            AnomalySeverity::Low,
            "Slow convergence detected",
            0.6,
            "layer2",
        );
        let critical_score = scorer.calculate_importance(&critical_event).await.unwrap();
        let low_score = scorer.calculate_importance(&low_event).await.unwrap();
        assert!(critical_score > low_score);
    }
}