trustformers_debug/streaming_debugger/
functions.rs1use super::types::*;
6use anyhow::Result;
7use std::sync::Arc;
8use tracing::info;
9impl crate::DebugSession {
11 pub async fn enable_streaming(
13 &mut self,
14 config: StreamingDebugConfig,
15 ) -> Result<Arc<StreamingDebugger>> {
16 let streaming_debugger = Arc::new(StreamingDebugger::new(config));
17 streaming_debugger.start().await?;
18 info!("Enabled streaming for debug session {}", self.id());
19 Ok(streaming_debugger)
20 }
21}
/// Streams a tensor's shape and element values through `$streamer`.
///
/// The macro does the following:
/// * generates a fresh `uuid::Uuid::new_v4()` as the tensor id;
/// * copies the shape via `shape().to_vec()`;
/// * converts every element yielded by `iter()` to `f64` with `Into`;
/// * `.await`s `$streamer.send_tensor_data(...)`.
///
/// The expansion must therefore be used inside an `async` context. It
/// evaluates to whatever `send_tensor_data` resolves to.
///
/// `$tensor` is evaluated exactly once, so it may be an expression with side
/// effects or an expensive computation. It only needs `shape()` and `iter()`
/// methods callable through a shared reference. `$name` is converted with
/// `to_string()`.
#[macro_export]
macro_rules! stream_tensor {
    ($streamer:expr, $session_id:expr, $tensor:expr, $name:expr) => {{
        let tensor_id = uuid::Uuid::new_v4();
        // Evaluate `$tensor` a single time and borrow it for both reads. The
        // borrow lives only inside this inner block, so it is released before
        // the `.await` below.
        let (shape, values) = {
            let tensor = &$tensor;
            let shape = tensor.shape().to_vec();
            let values: Vec<f64> = tensor.iter().map(|&x| x.into()).collect();
            (shape, values)
        };
        $streamer
            .send_tensor_data($session_id, tensor_id, $name.to_string(), shape, values)
            .await
    }};
}
/// Streams one layer's gradients through `$streamer`.
///
/// Every element yielded by `$gradients.iter()` is copied out and converted
/// to `f64`. The resulting vector is lent to
/// `$streamer.send_gradient_flow($session_id, <layer name>, &values)`, where
/// the layer name comes from `$layer_name.to_string()`. The call is
/// `.await`ed, so the macro must be used inside an `async` context.
#[macro_export]
macro_rules! stream_gradients {
    ($streamer:expr, $session_id:expr, $layer_name:expr, $gradients:expr) => {{
        let as_f64: Vec<f64> = $gradients
            .iter()
            .copied()
            .map(::core::convert::Into::into)
            .collect();
        $streamer
            .send_gradient_flow($session_id, $layer_name.to_string(), &as_f64)
            .await
    }};
}
/// Reports a detected anomaly through `$streamer`.
///
/// Two forms are accepted:
///
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description)`
///   reports with a fixed confidence of `0.95` and an empty list of affected
///   components.
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description,
///   confidence, affected_components)` forwards the caller-supplied confidence
///   and affected components instead of those defaults.
///
/// `$description` is converted with `to_string()`. The expansion `.await`s
/// `$streamer.send_anomaly_detected(...)`, so it must be used inside an
/// `async` context, and it evaluates to whatever that call resolves to.
#[macro_export]
macro_rules! stream_anomaly {
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr
    ) => {
        // Short form: keep the fixed defaults by delegating to the full form.
        $crate::stream_anomaly!(
            $streamer,
            $session_id,
            $anomaly_type,
            $severity,
            $description,
            0.95,
            vec![]
        )
    };
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr, $confidence:expr, $affected_components:expr
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                $confidence,
                $affected_components,
            )
            .await
    }};
}
#[cfg(test)]
mod tests {
    // Unit tests for `StreamingDebugger`: construction, start/stop lifecycle,
    // subscriptions, the statistics helpers and `StreamFilter` matching.
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    use uuid::Uuid;

    /// Builds a `TensorData` event for `session_id` with a fixed 2x2 payload
    /// and hand-written statistics; each call draws a fresh tensor id and
    /// timestamp, so only the session id is under the caller's control.
    fn sample_tensor_event(session_id: Uuid) -> StreamEvent {
        StreamEvent::TensorData {
            session_id,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn test_streaming_debugger_creation() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        // A freshly constructed debugger must not report itself as running.
        let running = *debugger.is_running.read().await;
        assert!(!running);
    }

    // Ignored by default: drives the real start/stop lifecycle with sleeps
    // under a 3 second timeout.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_start_stop_streaming() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        });
        let lifecycle = async {
            assert!(debugger.start().await.is_ok());
            assert!(*debugger.is_running.read().await);
            tokio::time::sleep(Duration::from_millis(50)).await;
            assert!(debugger.stop().await.is_ok());
            assert!(!*debugger.is_running.read().await);
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        };
        let test_result = tokio::time::timeout(Duration::from_secs(3), lifecycle).await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.expect("test should not time out").is_ok());
    }

    // Ignored by default: subscribes and unsubscribes against a running
    // debugger under a 3 second timeout.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_subscription() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        });
        let scenario = async {
            debugger.start().await.expect("start should succeed");

            let subscription = debugger
                .subscribe(
                    "test_subscriber".to_string(),
                    StreamFormat::Json,
                    StreamFilter::default(),
                )
                .await
                .expect("subscribe should succeed");
            let after_subscribe = debugger.get_subscribers().await.len();
            assert_eq!(after_subscribe, 1);

            debugger
                .unsubscribe(subscription.subscriber_id())
                .await
                .expect("unsubscribe should succeed");
            let after_unsubscribe = debugger.get_subscribers().await.len();
            assert_eq!(after_unsubscribe, 0);

            debugger.stop().await.expect("stop should succeed");
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok::<(), anyhow::Error>(())
        };
        let test_result = tokio::time::timeout(Duration::from_secs(3), scenario).await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.expect("test should not time out").is_ok());
    }

    #[tokio::test]
    async fn test_tensor_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = debugger.compute_tensor_statistics(&samples);
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.zero_count, 0);
    }

    #[tokio::test]
    async fn test_gradient_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let grads = vec![0.1, -0.2, 0.3, -0.1, 0.0];
        let stats = debugger.compute_gradient_statistics(&grads);
        assert!(stats.l1_norm > 0.0);
        assert!(stats.l2_norm > 0.0);
        assert_eq!(stats.max_grad, 0.3);
        assert_eq!(stats.min_grad, -0.2);
    }

    #[tokio::test]
    async fn test_event_filtering() {
        let wanted_session = Uuid::new_v4();
        let other_session = Uuid::new_v4();
        let filter = StreamFilter {
            session_ids: Some(vec![wanted_session]),
            event_types: Some(vec!["TensorData".to_string()]),
            min_severity: None,
            time_range: None,
            custom_filters: HashMap::new(),
        };
        let matching_event = sample_tensor_event(wanted_session);
        let non_matching_event = sample_tensor_event(other_session);
        // Same event type, different session: only the session filter decides.
        assert!(StreamSubscription::matches_filter(&matching_event, &filter));
        assert!(!StreamSubscription::matches_filter(
            &non_matching_event,
            &filter
        ));
    }
}
204pub trait AggregationRule {
206 fn aggregate(&self, events: &[StreamEvent]) -> Result<f64>;
207 fn rule_name(&self) -> &str;
208}
#[cfg(test)]
mod enhanced_tests {
    // Tests for the enhanced streaming components: the combined debugger,
    // network condition monitoring, buffer-size prediction and importance
    // scoring of events.
    use super::*;
    use std::time::{Duration, Instant, SystemTime};
    use uuid::Uuid;

    // Ignored by default: runs the enhanced start/stop lifecycle with sleeps
    // under a 5 second timeout.
    #[tokio::test(flavor = "multi_thread")]
    #[ignore]
    async fn test_enhanced_streaming_debugger() {
        let mut debugger = EnhancedStreamingDebugger::new(
            StreamingDebugConfig {
                stream_interval_ms: 50,
                ..Default::default()
            },
            AdaptiveStreamingConfig {
                monitoring_interval_ms: 500,
                ..Default::default()
            },
            RealTimeAggregationConfig {
                window_size_seconds: 1,
                ..Default::default()
            },
            IntelligentBufferingConfig::default(),
        );
        let lifecycle = async {
            assert!(debugger.start_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(debugger.stop_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(200)).await;
            Ok::<(), anyhow::Error>(())
        };
        let test_result = tokio::time::timeout(Duration::from_secs(5), lifecycle).await;
        assert!(test_result.is_ok(), "Test timed out");
        assert!(test_result.expect("test should not time out").is_ok());
    }

    #[tokio::test]
    async fn test_network_condition_monitor() {
        let mut monitor = NetworkConditionMonitor::new();
        monitor.update_conditions().await;
        // After one update the score must lie in [0, 1] and history must be
        // non-empty.
        assert!(monitor.quality_score >= 0.0);
        assert!(monitor.quality_score <= 1.0);
        assert!(!monitor.history.is_empty());
    }

    #[test]
    fn test_buffer_performance_predictor() {
        // The larger buffer has both higher throughput and lower latency here.
        let history = vec![
            BufferPerformancePoint {
                buffer_size: 500,
                throughput: 100.0,
                latency: 50.0,
                memory_usage: 50_000,
                timestamp: Instant::now(),
            },
            BufferPerformancePoint {
                buffer_size: 1000,
                throughput: 150.0,
                latency: 40.0,
                memory_usage: 100_000,
                timestamp: Instant::now(),
            },
        ];
        let predictor = BufferPerformancePredictor {
            performance_history: history,
            model_params: vec![],
            accuracy: 0.8,
        };
        let optimal_size = predictor
            .predict_optimal_size()
            .expect("predict_optimal_size should succeed");
        assert_eq!(optimal_size, 1000);
    }

    #[tokio::test]
    async fn test_importance_scorer() {
        let scorer = ImportanceScorer::new();
        let severe = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::GradientExplosion,
            severity: AnomalySeverity::Critical,
            description: "Critical gradient explosion".to_string(),
            confidence: 0.95,
            affected_components: vec!["layer1".to_string()],
            timestamp: SystemTime::now(),
        };
        let mild = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::TrainingStagnation,
            severity: AnomalySeverity::Low,
            description: "Slow convergence detected".to_string(),
            confidence: 0.6,
            affected_components: vec!["layer2".to_string()],
            timestamp: SystemTime::now(),
        };
        let severe_score = scorer
            .calculate_importance(&severe)
            .await
            .expect("calculate_importance should succeed for critical event");
        let mild_score = scorer
            .calculate_importance(&mild)
            .await
            .expect("calculate_importance should succeed for low event");
        // A critical, high-confidence anomaly must outrank a low-severity one.
        assert!(severe_score > mild_score);
    }
}