1use std::collections::HashMap;
4use std::time::{Duration, Instant};
5
6use crate::capabilities::DeviceType;
7use crate::profiling::ProfileData;
8use crate::strategy::ExecutionStrategy;
9
/// Lifecycle phase of an execution run, as tracked by [`ExecutionState`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionPhase {
    /// Initial phase set by [`ExecutionState::new`]; the run has not started.
    Preparing,
    /// Set by [`ExecutionState::start`]; the run is in progress.
    Executing,
    /// Non-terminal waiting phase. No method in this module enters it —
    /// presumably set directly by callers; TODO confirm intended use.
    Waiting,
    /// Terminal: set by [`ExecutionState::complete`].
    Completed,
    /// Terminal: set by [`ExecutionState::fail`].
    Failed,
    /// Terminal: set by [`ExecutionState::cancel`].
    Cancelled,
}

impl ExecutionPhase {
    /// Returns the phase name, spelled exactly like the variant.
    ///
    /// The string is a literal, so it is returned with a `'static` lifetime
    /// and may outlive the `ExecutionPhase` value it came from.
    pub fn as_str(&self) -> &'static str {
        match self {
            ExecutionPhase::Preparing => "Preparing",
            ExecutionPhase::Executing => "Executing",
            ExecutionPhase::Waiting => "Waiting",
            ExecutionPhase::Completed => "Completed",
            ExecutionPhase::Failed => "Failed",
            ExecutionPhase::Cancelled => "Cancelled",
        }
    }

    /// Returns `true` for phases that end a run: `Completed`, `Failed`, `Cancelled`.
    pub fn is_terminal(&self) -> bool {
        matches!(
            self,
            ExecutionPhase::Completed | ExecutionPhase::Failed | ExecutionPhase::Cancelled
        )
    }
}

/// Progress, timing and outcome of a single execution run.
#[derive(Debug, Clone)]
pub struct ExecutionState {
    /// Current lifecycle phase.
    pub phase: ExecutionPhase,
    /// Fraction of nodes reached; kept within `0.0..=1.0` by this type's methods.
    pub progress: f64,
    /// Index most recently passed to [`ExecutionState::update_progress`].
    pub current_node: Option<usize>,
    /// Set to `current_node + 1` by [`ExecutionState::update_progress`].
    pub nodes_completed: usize,
    /// Number of nodes in the run; the denominator for `progress`.
    pub total_nodes: usize,
    /// When [`ExecutionState::start`] was last called.
    pub start_time: Option<Instant>,
    /// When the run reached a terminal phase, if it has.
    pub end_time: Option<Instant>,
    /// Message recorded by [`ExecutionState::fail`].
    pub error_message: Option<String>,
}

impl ExecutionState {
    /// Creates a not-yet-started state (`Preparing`, zero progress) for
    /// a run of `total_nodes` nodes.
    pub fn new(total_nodes: usize) -> Self {
        ExecutionState {
            phase: ExecutionPhase::Preparing,
            progress: 0.0,
            current_node: None,
            nodes_completed: 0,
            total_nodes,
            start_time: None,
            end_time: None,
            error_message: None,
        }
    }

    /// Enters `Executing` and stamps the start time.
    ///
    /// Clears any `end_time` and `error_message` left over from a previous
    /// run. Without this, restarting a finished state would report a stale
    /// elapsed time, and a running state would still carry an old error.
    pub fn start(&mut self) {
        self.phase = ExecutionPhase::Executing;
        self.start_time = Some(Instant::now());
        self.end_time = None;
        self.error_message = None;
    }

    /// Enters `Completed`, stamps the end time and sets progress to 100%.
    pub fn complete(&mut self) {
        self.phase = ExecutionPhase::Completed;
        self.end_time = Some(Instant::now());
        self.progress = 1.0;
    }

    /// Enters `Failed`, stamps the end time and records `error`.
    pub fn fail(&mut self, error: impl Into<String>) {
        self.phase = ExecutionPhase::Failed;
        self.end_time = Some(Instant::now());
        self.error_message = Some(error.into());
    }

    /// Enters `Cancelled` and stamps the end time.
    pub fn cancel(&mut self) {
        self.phase = ExecutionPhase::Cancelled;
        self.end_time = Some(Instant::now());
    }

    /// Records that node `node_idx` has been reached.
    ///
    /// Sets `nodes_completed` to `node_idx + 1` and recomputes `progress`.
    /// Progress is clamped to `1.0` so an out-of-range index cannot report
    /// more than 100%. It stays `0.0` when `total_nodes` is zero.
    pub fn update_progress(&mut self, node_idx: usize) {
        self.current_node = Some(node_idx);
        // Saturate so `usize::MAX` cannot overflow and panic in debug builds.
        self.nodes_completed = node_idx.saturating_add(1);
        self.progress = if self.total_nodes > 0 {
            (self.nodes_completed as f64 / self.total_nodes as f64).min(1.0)
        } else {
            0.0
        };
    }

    /// Returns the time since `start`, or `None` if the run never started.
    ///
    /// The end of the interval is the stored end time for a finished run and
    /// "now" for a running one.
    pub fn elapsed(&self) -> Option<Duration> {
        let start = self.start_time?;
        let end = self.end_time.unwrap_or_else(Instant::now);
        Some(end.saturating_duration_since(start))
    }

    /// Returns `true` while the phase is `Executing`.
    pub fn is_running(&self) -> bool {
        self.phase == ExecutionPhase::Executing
    }

    /// Returns `true` once the phase is terminal (completed, failed or cancelled).
    pub fn is_complete(&self) -> bool {
        self.phase.is_terminal()
    }
}
122
123pub trait ExecutionHook: Send {
125 fn on_phase_change(&mut self, phase: ExecutionPhase, state: &ExecutionState);
127
128 fn on_node_start(&mut self, node_idx: usize, state: &ExecutionState);
130
131 fn on_node_complete(&mut self, node_idx: usize, duration: Duration, state: &ExecutionState);
133
134 fn on_error(&mut self, error: &str, state: &ExecutionState);
136
137 fn on_complete(&mut self, state: &ExecutionState);
139}
140
/// Execution hook that writes lifecycle events to stderr.
///
/// Errors are always printed. Phase changes and per-node events are
/// controlled by the two flags below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoggingHook {
    /// Log phase transitions and the final completion time.
    log_phase_changes: bool,
    /// Log the start and completion of every node.
    log_node_execution: bool,
}

impl LoggingHook {
    /// Quiet configuration: phase changes are logged, per-node events are not.
    pub fn new() -> Self {
        LoggingHook {
            log_phase_changes: true,
            log_node_execution: false,
        }
    }

    /// Verbose configuration: phase changes and every node event are logged.
    pub fn verbose() -> Self {
        LoggingHook {
            log_phase_changes: true,
            log_node_execution: true,
        }
    }
}

impl Default for LoggingHook {
    /// Same as [`LoggingHook::new`].
    fn default() -> Self {
        Self::new()
    }
}
168
169impl ExecutionHook for LoggingHook {
170 fn on_phase_change(&mut self, phase: ExecutionPhase, _state: &ExecutionState) {
171 if self.log_phase_changes {
172 eprintln!("[ExecutionHook] Phase changed to: {}", phase.as_str());
173 }
174 }
175
176 fn on_node_start(&mut self, node_idx: usize, _state: &ExecutionState) {
177 if self.log_node_execution {
178 eprintln!("[ExecutionHook] Starting node {}", node_idx);
179 }
180 }
181
182 fn on_node_complete(&mut self, node_idx: usize, duration: Duration, _state: &ExecutionState) {
183 if self.log_node_execution {
184 eprintln!(
185 "[ExecutionHook] Completed node {} in {:.3}ms",
186 node_idx,
187 duration.as_secs_f64() * 1000.0
188 );
189 }
190 }
191
192 fn on_error(&mut self, error: &str, _state: &ExecutionState) {
193 eprintln!("[ExecutionHook] Error: {}", error);
194 }
195
196 fn on_complete(&mut self, state: &ExecutionState) {
197 if self.log_phase_changes {
198 if let Some(elapsed) = state.elapsed() {
199 eprintln!(
200 "[ExecutionHook] Execution completed in {:.3}s",
201 elapsed.as_secs_f64()
202 );
203 }
204 }
205 }
206}
207
208pub struct ExecutionContext {
210 pub state: ExecutionState,
211 pub strategy: ExecutionStrategy,
212 pub device: DeviceType,
213 pub profile_data: Option<ProfileData>,
214 pub metadata: HashMap<String, String>,
215 hooks: Vec<Box<dyn ExecutionHook>>,
216}
217
218impl ExecutionContext {
219 pub fn new(total_nodes: usize, strategy: ExecutionStrategy) -> Self {
220 ExecutionContext {
221 state: ExecutionState::new(total_nodes),
222 strategy,
223 device: DeviceType::CPU,
224 profile_data: None,
225 metadata: HashMap::new(),
226 hooks: Vec::new(),
227 }
228 }
229
230 pub fn with_device(mut self, device: DeviceType) -> Self {
231 self.device = device;
232 self
233 }
234
235 pub fn with_profiling(mut self, enable: bool) -> Self {
236 if enable {
237 self.profile_data = Some(ProfileData::new());
238 }
239 self
240 }
241
242 pub fn add_hook(&mut self, hook: Box<dyn ExecutionHook>) {
243 self.hooks.push(hook);
244 }
245
246 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
247 self.metadata.insert(key.into(), value.into());
248 }
249
250 pub fn get_metadata(&self, key: &str) -> Option<&str> {
251 self.metadata.get(key).map(|s| s.as_str())
252 }
253
254 pub fn start(&mut self) {
256 self.state.start();
257 self.notify_phase_change(ExecutionPhase::Executing);
258 }
259
260 pub fn complete(&mut self) {
261 self.state.complete();
262 self.notify_complete();
263 self.notify_phase_change(ExecutionPhase::Completed);
264 }
265
266 pub fn fail(&mut self, error: impl Into<String>) {
267 let error_msg = error.into();
268 self.notify_error(&error_msg);
269 self.state.fail(error_msg);
270 self.notify_phase_change(ExecutionPhase::Failed);
271 }
272
273 pub fn cancel(&mut self) {
274 self.state.cancel();
275 self.notify_phase_change(ExecutionPhase::Cancelled);
276 }
277
278 pub fn begin_node(&mut self, node_idx: usize) {
279 self.state.update_progress(node_idx);
280 self.notify_node_start(node_idx);
281 }
282
283 pub fn end_node(&mut self, node_idx: usize, duration: Duration) {
284 self.notify_node_complete(node_idx, duration);
285 }
286
287 fn notify_phase_change(&mut self, phase: ExecutionPhase) {
289 for hook in &mut self.hooks {
290 hook.on_phase_change(phase, &self.state);
291 }
292 }
293
294 fn notify_node_start(&mut self, node_idx: usize) {
295 for hook in &mut self.hooks {
296 hook.on_node_start(node_idx, &self.state);
297 }
298 }
299
300 fn notify_node_complete(&mut self, node_idx: usize, duration: Duration) {
301 for hook in &mut self.hooks {
302 hook.on_node_complete(node_idx, duration, &self.state);
303 }
304 }
305
306 fn notify_error(&mut self, error: &str) {
307 for hook in &mut self.hooks {
308 hook.on_error(error, &self.state);
309 }
310 }
311
312 fn notify_complete(&mut self) {
313 for hook in &mut self.hooks {
314 hook.on_complete(&self.state);
315 }
316 }
317
318 pub fn summary(&self) -> String {
319 let mut summary = String::new();
320 summary.push_str("Execution Context Summary\n");
321 summary.push_str("=========================\n\n");
322 summary.push_str(&format!("Phase: {}\n", self.state.phase.as_str()));
323 summary.push_str(&format!("Progress: {:.1}%\n", self.state.progress * 100.0));
324 summary.push_str(&format!(
325 "Nodes: {}/{}\n",
326 self.state.nodes_completed, self.state.total_nodes
327 ));
328
329 if let Some(elapsed) = self.state.elapsed() {
330 summary.push_str(&format!("Elapsed: {:.3}s\n", elapsed.as_secs_f64()));
331 }
332
333 summary.push_str(&format!("Device: {}\n", self.device.as_str()));
334 summary.push_str(&format!("Strategy: {:?}\n", self.strategy.mode));
335
336 if let Some(error) = &self.state.error_message {
337 summary.push_str(&format!("\nError: {}\n", error));
338 }
339
340 if !self.metadata.is_empty() {
341 summary.push_str("\nMetadata:\n");
342 for (key, value) in &self.metadata {
343 summary.push_str(&format!(" {}: {}\n", key, value));
344 }
345 }
346
347 summary
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared log of hook callbacks, readable after the hook is boxed into a context.
    type EventLog = Arc<Mutex<Vec<String>>>;

    /// Test hook that records each callback as a short tag. Tests use it to
    /// check that `ExecutionContext` really invokes hooks, and in which order.
    struct RecordingHook {
        events: EventLog,
    }

    impl RecordingHook {
        /// Returns the hook together with a handle to the log it writes into.
        fn new() -> (Self, EventLog) {
            let events: EventLog = Arc::new(Mutex::new(Vec::new()));
            let hook = RecordingHook {
                events: Arc::clone(&events),
            };
            (hook, events)
        }

        fn record(&self, event: String) {
            self.events.lock().unwrap().push(event);
        }
    }

    impl ExecutionHook for RecordingHook {
        fn on_phase_change(&mut self, phase: ExecutionPhase, _state: &ExecutionState) {
            self.record(format!("phase:{}", phase.as_str()));
        }

        fn on_node_start(&mut self, node_idx: usize, _state: &ExecutionState) {
            self.record(format!("node_start:{}", node_idx));
        }

        fn on_node_complete(
            &mut self,
            node_idx: usize,
            _duration: Duration,
            _state: &ExecutionState,
        ) {
            self.record(format!("node_complete:{}", node_idx));
        }

        fn on_error(&mut self, error: &str, _state: &ExecutionState) {
            self.record(format!("error:{}", error));
        }

        fn on_complete(&mut self, _state: &ExecutionState) {
            self.record("complete".to_string());
        }
    }

    /// Snapshot of every event recorded so far.
    fn recorded(events: &EventLog) -> Vec<String> {
        events.lock().unwrap().clone()
    }

    #[test]
    fn test_execution_phase() {
        assert_eq!(ExecutionPhase::Preparing.as_str(), "Preparing");
        assert!(!ExecutionPhase::Executing.is_terminal());
        assert!(ExecutionPhase::Completed.is_terminal());
        assert!(ExecutionPhase::Failed.is_terminal());
    }

    #[test]
    fn test_execution_state_lifecycle() {
        let mut state = ExecutionState::new(10);

        assert_eq!(state.phase, ExecutionPhase::Preparing);
        assert_eq!(state.progress, 0.0);

        state.start();
        assert_eq!(state.phase, ExecutionPhase::Executing);
        assert!(state.is_running());

        state.update_progress(5);
        assert_eq!(state.current_node, Some(5));
        assert_eq!(state.progress, 0.6);

        state.complete();
        assert_eq!(state.phase, ExecutionPhase::Completed);
        assert!(state.is_complete());
        assert_eq!(state.progress, 1.0);
    }

    #[test]
    fn test_execution_state_failure() {
        let mut state = ExecutionState::new(10);
        state.start();
        state.fail("Test error");

        assert_eq!(state.phase, ExecutionPhase::Failed);
        assert_eq!(state.error_message, Some("Test error".to_string()));
        assert!(state.is_complete());
    }

    #[test]
    fn test_execution_state_cancel() {
        let mut state = ExecutionState::new(4);
        state.start();
        state.cancel();

        assert_eq!(state.phase, ExecutionPhase::Cancelled);
        assert!(state.is_complete());
        assert!(!state.is_running());
        assert!(state.end_time.is_some());
        assert!(state.error_message.is_none());
    }

    #[test]
    fn test_execution_state_elapsed() {
        let mut state = ExecutionState::new(5);
        state.start();
        std::thread::sleep(Duration::from_millis(10));
        state.complete();

        let elapsed = state.elapsed().unwrap();
        assert!(elapsed.as_millis() >= 10);
    }

    #[test]
    fn test_execution_state_elapsed_before_start() {
        let state = ExecutionState::new(1);
        assert!(state.elapsed().is_none());
    }

    #[test]
    fn test_update_progress_with_zero_nodes() {
        let mut state = ExecutionState::new(0);
        state.update_progress(0);
        assert_eq!(state.current_node, Some(0));
        assert_eq!(state.progress, 0.0);
    }

    #[test]
    fn test_execution_context_creation() {
        let strategy = ExecutionStrategy::inference();
        let context = ExecutionContext::new(10, strategy);

        assert_eq!(context.state.total_nodes, 10);
        assert_eq!(context.device, DeviceType::CPU);
        assert!(context.profile_data.is_none());
    }

    #[test]
    fn test_execution_context_with_device() {
        let strategy = ExecutionStrategy::inference();
        let context = ExecutionContext::new(10, strategy).with_device(DeviceType::GPU);

        assert_eq!(context.device, DeviceType::GPU);
    }

    #[test]
    fn test_execution_context_with_profiling() {
        let strategy = ExecutionStrategy::inference();
        let context = ExecutionContext::new(10, strategy).with_profiling(true);

        assert!(context.profile_data.is_some());
    }

    #[test]
    fn test_execution_context_metadata() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(10, strategy);

        context.set_metadata("graph_id", "test-123");
        context.set_metadata("user", "test-user");

        assert_eq!(context.get_metadata("graph_id"), Some("test-123"));
        assert_eq!(context.get_metadata("user"), Some("test-user"));
        assert_eq!(context.get_metadata("missing"), None);
    }

    #[test]
    fn test_execution_context_lifecycle() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(5, strategy);

        context.start();
        assert!(context.state.is_running());

        context.begin_node(0);
        context.end_node(0, Duration::from_millis(10));

        context.begin_node(1);
        context.end_node(1, Duration::from_millis(15));

        assert_eq!(context.state.nodes_completed, 2);
        assert!(context.state.progress > 0.0);

        context.complete();
        assert!(context.state.is_complete());
        assert_eq!(context.state.phase, ExecutionPhase::Completed);
    }

    #[test]
    fn test_execution_context_failure() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(5, strategy);

        context.start();
        context.fail("Test error occurred");

        assert_eq!(context.state.phase, ExecutionPhase::Failed);
        assert_eq!(
            context.state.error_message,
            Some("Test error occurred".to_string())
        );
    }

    #[test]
    fn test_execution_context_summary() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(5, strategy);
        context.set_metadata("test_key", "test_value");

        context.start();
        context.begin_node(2);

        let summary = context.summary();
        assert!(summary.contains("Execution Context Summary"));
        assert!(summary.contains("Progress:"));
        assert!(summary.contains("test_key"));
    }

    #[test]
    fn test_logging_hook() {
        let hook = LoggingHook::new();
        assert!(hook.log_phase_changes);
        assert!(!hook.log_node_execution);

        let verbose_hook = LoggingHook::verbose();
        assert!(verbose_hook.log_phase_changes);
        assert!(verbose_hook.log_node_execution);
    }

    #[test]
    fn test_execution_with_hooks() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(3, strategy);

        context.add_hook(Box::new(LoggingHook::new()));

        context.start();
        context.begin_node(0);
        context.end_node(0, Duration::from_millis(10));
        context.complete();

        assert_eq!(context.state.phase, ExecutionPhase::Completed);
    }

    #[test]
    fn test_hooks_receive_lifecycle_events_in_order() {
        let (hook, events) = RecordingHook::new();
        let mut context = ExecutionContext::new(2, ExecutionStrategy::inference());
        context.add_hook(Box::new(hook));

        context.start();
        context.begin_node(0);
        context.end_node(0, Duration::from_millis(5));
        context.complete();

        assert_eq!(
            recorded(&events),
            vec![
                "phase:Executing",
                "node_start:0",
                "node_complete:0",
                "complete",
                "phase:Completed",
            ]
        );
    }

    #[test]
    fn test_hooks_notified_on_failure() {
        let (hook, events) = RecordingHook::new();
        let mut context = ExecutionContext::new(2, ExecutionStrategy::inference());
        context.add_hook(Box::new(hook));

        context.start();
        context.fail("boom");

        assert_eq!(
            recorded(&events),
            vec!["phase:Executing", "error:boom", "phase:Failed"]
        );
    }

    #[test]
    fn test_hooks_notified_on_cancel() {
        let (hook, events) = RecordingHook::new();
        let mut context = ExecutionContext::new(2, ExecutionStrategy::inference());
        context.add_hook(Box::new(hook));

        context.start();
        context.cancel();

        assert_eq!(
            recorded(&events),
            vec!["phase:Executing", "phase:Cancelled"]
        );
        assert!(context.state.is_complete());
    }
}