trustformers_debug/streaming_debugger/
functions.rs1use super::types::*;
6use anyhow::Result;
7use std::sync::Arc;
8use tracing::info;
9impl crate::DebugSession {
11 pub async fn enable_streaming(
13 &mut self,
14 config: StreamingDebugConfig,
15 ) -> Result<Arc<StreamingDebugger>> {
16 let streaming_debugger = Arc::new(StreamingDebugger::new(config));
17 streaming_debugger.start().await?;
18 info!("Enabled streaming for debug session {}", self.id());
19 Ok(streaming_debugger)
20 }
21}
/// Streams a tensor's shape and values to a streaming debugger.
///
/// Expands to an awaited `$streamer.send_tensor_data(...)` call, so it must be
/// used inside an async context; the macro evaluates to that call's result.
///
/// * `$streamer` – receiver exposing an async `send_tensor_data` method.
/// * `$session_id` – forwarded unchanged.
/// * `$tensor` – any value with `shape()` (whose result supports `to_vec()`)
///   and `iter()` yielding `&T` where `T: Copy + Into<f64>`. It is evaluated
///   exactly once, so expressions with side effects or expensive
///   constructors are safe to pass.
/// * `$name` – anything with `to_string()`; used as the tensor name.
///
/// A fresh `uuid::Uuid::new_v4()` is generated as the tensor id on every
/// invocation; the path is resolved at the call site, so the calling crate
/// must be able to name `uuid`.
#[macro_export]
macro_rules! stream_tensor {
    ($streamer:expr, $session_id:expr, $tensor:expr, $name:expr) => {{
        let tensor_id = uuid::Uuid::new_v4();
        // Bind once so `$tensor` is evaluated a single time even though both
        // `shape()` and `iter()` need it.
        let tensor_ref = &$tensor;
        let shape = tensor_ref.shape().to_vec();
        let values: Vec<f64> = tensor_ref.iter().map(|&x| x.into()).collect();
        $streamer
            .send_tensor_data($session_id, tensor_id, $name.to_string(), shape, values)
            .await
    }};
}
/// Streams a layer's gradients to a streaming debugger.
///
/// Every element yielded by `$gradients.iter()` is converted to `f64`, and the
/// macro then awaits `$streamer.send_gradient_flow(session_id, layer_name,
/// &values)`. It evaluates to the awaited result, so it must be used in an
/// async context.
///
/// * `$gradients` – evaluated once; its `iter()` must yield `&T` with
///   `T: Copy + Into<f64>`.
/// * `$layer_name` – anything with `to_string()`.
#[macro_export]
macro_rules! stream_gradients {
    ($streamer:expr, $session_id:expr, $layer_name:expr, $gradients:expr) => {{
        let converted = $gradients
            .iter()
            .copied()
            .map(|grad| -> f64 { grad.into() })
            .collect::<Vec<f64>>();
        $streamer
            .send_gradient_flow($session_id, $layer_name.to_string(), &converted)
            .await
    }};
}
/// Reports a detected anomaly through a streaming debugger.
///
/// Expands to an awaited `$streamer.send_anomaly_detected(...)` call, so it
/// must be used inside an async context; the macro evaluates to that call's
/// result.
///
/// Two forms are accepted:
///
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description)`
///   – uses a fixed confidence of `0.95` and an empty affected-components list.
/// * `stream_anomaly!(streamer, session_id, anomaly_type, severity, description,
///   confidence, affected_components)` – lets the caller supply the detector's
///   actual confidence and the components it implicated.
///
/// `description` may be anything with `to_string()`.
#[macro_export]
macro_rules! stream_anomaly {
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr
    ) => {
        // `$crate::` keeps the delegation working when invoked from other crates.
        $crate::stream_anomaly!(
            $streamer,
            $session_id,
            $anomaly_type,
            $severity,
            $description,
            0.95,
            vec![]
        )
    };
    (
        $streamer:expr, $session_id:expr, $anomaly_type:expr, $severity:expr,
        $description:expr, $confidence:expr, $affected_components:expr
    ) => {{
        $streamer
            .send_anomaly_detected(
                $session_id,
                $anomaly_type,
                $severity,
                $description.to_string(),
                $confidence,
                $affected_components,
            )
            .await
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::{Duration, SystemTime};

    use uuid::Uuid;

    /// Upper bound on the lifecycle tests so a hung `start`/`stop` fails the
    /// test instead of stalling the suite.
    const LIFECYCLE_TIMEOUT: Duration = Duration::from_secs(3);

    /// Configuration with a short streaming interval for the lifecycle tests.
    fn fast_config() -> StreamingDebugConfig {
        StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        }
    }

    /// Builds a `TensorData` event for `session_id` with a fixed 2x2 payload.
    /// Two events built here differ only in `session_id` (plus the fresh
    /// `tensor_id` and `timestamp`), which is what the filter test relies on.
    fn sample_tensor_event(session_id: Uuid) -> StreamEvent {
        StreamEvent::TensorData {
            session_id,
            tensor_id: Uuid::new_v4(),
            name: "test".to_string(),
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
            statistics: TensorStatistics {
                mean: 2.5,
                std: 1.29,
                min: 1.0,
                max: 4.0,
                nan_count: 0,
                inf_count: 0,
                zero_count: 0,
                sparsity: 0.0,
            },
            timestamp: SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn test_streaming_debugger_creation() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        assert!(!*debugger.is_running.read().await);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_start_stop_streaming() {
        let debugger = StreamingDebugger::new(fast_config());
        // Assertions inside the block panic on failure; `timeout` only tells
        // us whether the block finished in time, so one `expect` covers it.
        tokio::time::timeout(LIFECYCLE_TIMEOUT, async {
            debugger.start().await.expect("start() failed");
            assert!(*debugger.is_running.read().await);
            tokio::time::sleep(Duration::from_millis(50)).await;
            debugger.stop().await.expect("stop() failed");
            assert!(!*debugger.is_running.read().await);
            // NOTE(review): presumably gives tasks spawned by `start` time to
            // wind down before the runtime is dropped — confirm.
            tokio::time::sleep(Duration::from_millis(100)).await;
        })
        .await
        .expect("Test timed out");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_subscription() {
        let debugger = StreamingDebugger::new(fast_config());
        tokio::time::timeout(LIFECYCLE_TIMEOUT, async {
            debugger.start().await.unwrap();
            let subscription = debugger
                .subscribe(
                    "test_subscriber".to_string(),
                    StreamFormat::Json,
                    StreamFilter::default(),
                )
                .await
                .unwrap();
            assert_eq!(debugger.get_subscribers().await.len(), 1);
            debugger.unsubscribe(subscription.subscriber_id()).await.unwrap();
            assert_eq!(debugger.get_subscribers().await.len(), 0);
            debugger.stop().await.unwrap();
            tokio::time::sleep(Duration::from_millis(100)).await;
        })
        .await
        .expect("Test timed out");
    }

    #[tokio::test]
    async fn test_tensor_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = debugger.compute_tensor_statistics(&values);
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.zero_count, 0);
    }

    #[tokio::test]
    async fn test_gradient_statistics() {
        let debugger = StreamingDebugger::new(StreamingDebugConfig::default());
        let gradients = vec![0.1, -0.2, 0.3, -0.1, 0.0];
        let stats = debugger.compute_gradient_statistics(&gradients);
        assert!(stats.l1_norm > 0.0);
        assert!(stats.l2_norm > 0.0);
        assert_eq!(stats.max_grad, 0.3);
        assert_eq!(stats.min_grad, -0.2);
    }

    #[tokio::test]
    async fn test_event_filtering() {
        let session_id1 = Uuid::new_v4();
        let session_id2 = Uuid::new_v4();
        let filter = StreamFilter {
            session_ids: Some(vec![session_id1]),
            event_types: Some(vec!["TensorData".to_string()]),
            min_severity: None,
            time_range: None,
            custom_filters: HashMap::new(),
        };
        let matching_event = sample_tensor_event(session_id1);
        let non_matching_event = sample_tensor_event(session_id2);
        assert!(StreamSubscription::matches_filter(&matching_event, &filter));
        assert!(!StreamSubscription::matches_filter(
            &non_matching_event,
            &filter
        ));
    }
}
199pub trait AggregationRule {
201 fn aggregate(&self, events: &[StreamEvent]) -> Result<f64>;
202 fn rule_name(&self) -> &str;
203}
#[cfg(test)]
mod enhanced_tests {
    use super::*;
    use std::time::{Duration, Instant, SystemTime};
    use uuid::Uuid;

    #[tokio::test(flavor = "multi_thread")]
    async fn test_enhanced_streaming_debugger() {
        let base_config = StreamingDebugConfig {
            stream_interval_ms: 50,
            ..Default::default()
        };
        let adaptive_config = AdaptiveStreamingConfig {
            monitoring_interval_ms: 500,
            ..Default::default()
        };
        let aggregation_config = RealTimeAggregationConfig {
            window_size_seconds: 1,
            ..Default::default()
        };
        let buffering_config = IntelligentBufferingConfig::default();
        let mut debugger = EnhancedStreamingDebugger::new(
            base_config,
            adaptive_config,
            aggregation_config,
            buffering_config,
        );
        // Assertions inside the block panic on failure; the timeout only guards
        // against a hang in start/stop, so a single `expect` is sufficient.
        tokio::time::timeout(Duration::from_secs(5), async {
            assert!(debugger.start_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(100)).await;
            assert!(debugger.stop_enhanced_streaming().await.is_ok());
            tokio::time::sleep(Duration::from_millis(200)).await;
        })
        .await
        .expect("Test timed out");
    }

    #[tokio::test]
    async fn test_network_condition_monitor() {
        let mut monitor = NetworkConditionMonitor::new();
        monitor.update_conditions().await;
        assert!(monitor.quality_score >= 0.0);
        assert!(monitor.quality_score <= 1.0);
        assert!(!monitor.history.is_empty());
    }

    #[test]
    fn test_buffer_performance_predictor() {
        // Two samples where the larger buffer has higher throughput and lower
        // latency; the predictor is expected to pick it.
        let predictor = BufferPerformancePredictor {
            performance_history: vec![
                BufferPerformancePoint {
                    buffer_size: 500,
                    throughput: 100.0,
                    latency: 50.0,
                    memory_usage: 50000,
                    timestamp: Instant::now(),
                },
                BufferPerformancePoint {
                    buffer_size: 1000,
                    throughput: 150.0,
                    latency: 40.0,
                    memory_usage: 100000,
                    timestamp: Instant::now(),
                },
            ],
            model_params: vec![],
            accuracy: 0.8,
        };
        let optimal_size = predictor.predict_optimal_size().unwrap();
        assert_eq!(optimal_size, 1000);
    }

    #[tokio::test]
    async fn test_importance_scorer() {
        let scorer = ImportanceScorer::new();
        let critical_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::GradientExplosion,
            severity: AnomalySeverity::Critical,
            description: "Critical gradient explosion".to_string(),
            confidence: 0.95,
            affected_components: vec!["layer1".to_string()],
            timestamp: SystemTime::now(),
        };
        let low_event = StreamEvent::AnomalyDetected {
            session_id: Uuid::new_v4(),
            anomaly_type: AnomalyType::TrainingStagnation,
            severity: AnomalySeverity::Low,
            description: "Slow convergence detected".to_string(),
            confidence: 0.6,
            affected_components: vec!["layer2".to_string()],
            timestamp: SystemTime::now(),
        };
        let critical_score = scorer.calculate_importance(&critical_event).await.unwrap();
        let low_score = scorer.calculate_importance(&low_event).await.unwrap();
        assert!(critical_score > low_score);
    }
}