Skip to main content

tensorlogic_infer/
context.rs

1//! Execution context and state management for coordinated execution.
2
3use std::collections::HashMap;
4use std::time::{Duration, Instant};
5
6use crate::capabilities::DeviceType;
7use crate::profiling::ProfileData;
8use crate::strategy::ExecutionStrategy;
9
/// Lifecycle phase of an execution run.
///
/// `Completed`, `Failed` and `Cancelled` are terminal phases; see
/// [`ExecutionPhase::is_terminal`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecutionPhase {
    /// Preparing for execution (validation, optimization).
    Preparing,
    /// Currently executing.
    Executing,
    /// Waiting for resources or synchronization.
    Waiting,
    /// Completed successfully.
    Completed,
    /// Failed with an error.
    Failed,
    /// Cancelled by the user.
    Cancelled,
}

impl ExecutionPhase {
    /// Returns the phase name, e.g. `"Executing"`.
    ///
    /// The returned string is `'static`, so it may outlive the phase value it
    /// was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            ExecutionPhase::Preparing => "Preparing",
            ExecutionPhase::Executing => "Executing",
            ExecutionPhase::Waiting => "Waiting",
            ExecutionPhase::Completed => "Completed",
            ExecutionPhase::Failed => "Failed",
            ExecutionPhase::Cancelled => "Cancelled",
        }
    }

    /// Returns `true` for the end states `Completed`, `Failed` and `Cancelled`.
    pub fn is_terminal(&self) -> bool {
        matches!(
            self,
            ExecutionPhase::Completed | ExecutionPhase::Failed | ExecutionPhase::Cancelled
        )
    }
}

/// Formats the phase using the same name as [`ExecutionPhase::as_str`].
impl std::fmt::Display for ExecutionPhase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
46
47/// Execution state tracking
48#[derive(Debug, Clone)]
49pub struct ExecutionState {
50    pub phase: ExecutionPhase,
51    pub progress: f64, // 0.0 to 1.0
52    pub current_node: Option<usize>,
53    pub nodes_completed: usize,
54    pub total_nodes: usize,
55    pub start_time: Option<Instant>,
56    pub end_time: Option<Instant>,
57    pub error_message: Option<String>,
58}
59
60impl ExecutionState {
61    pub fn new(total_nodes: usize) -> Self {
62        ExecutionState {
63            phase: ExecutionPhase::Preparing,
64            progress: 0.0,
65            current_node: None,
66            nodes_completed: 0,
67            total_nodes,
68            start_time: None,
69            end_time: None,
70            error_message: None,
71        }
72    }
73
74    pub fn start(&mut self) {
75        self.phase = ExecutionPhase::Executing;
76        self.start_time = Some(Instant::now());
77    }
78
79    pub fn complete(&mut self) {
80        self.phase = ExecutionPhase::Completed;
81        self.end_time = Some(Instant::now());
82        self.progress = 1.0;
83    }
84
85    pub fn fail(&mut self, error: impl Into<String>) {
86        self.phase = ExecutionPhase::Failed;
87        self.end_time = Some(Instant::now());
88        self.error_message = Some(error.into());
89    }
90
91    pub fn cancel(&mut self) {
92        self.phase = ExecutionPhase::Cancelled;
93        self.end_time = Some(Instant::now());
94    }
95
96    pub fn update_progress(&mut self, node_idx: usize) {
97        self.current_node = Some(node_idx);
98        self.nodes_completed = node_idx + 1;
99        self.progress = if self.total_nodes > 0 {
100            self.nodes_completed as f64 / self.total_nodes as f64
101        } else {
102            0.0
103        };
104    }
105
106    pub fn elapsed(&self) -> Option<Duration> {
107        self.start_time.map(|start| {
108            self.end_time
109                .unwrap_or_else(Instant::now)
110                .duration_since(start)
111        })
112    }
113
114    pub fn is_running(&self) -> bool {
115        self.phase == ExecutionPhase::Executing
116    }
117
118    pub fn is_complete(&self) -> bool {
119        self.phase.is_terminal()
120    }
121}
122
123/// Hook for monitoring execution events
124pub trait ExecutionHook: Send {
125    /// Called when execution phase changes
126    fn on_phase_change(&mut self, phase: ExecutionPhase, state: &ExecutionState);
127
128    /// Called when a node starts executing
129    fn on_node_start(&mut self, node_idx: usize, state: &ExecutionState);
130
131    /// Called when a node completes
132    fn on_node_complete(&mut self, node_idx: usize, duration: Duration, state: &ExecutionState);
133
134    /// Called when an error occurs
135    fn on_error(&mut self, error: &str, state: &ExecutionState);
136
137    /// Called when execution completes
138    fn on_complete(&mut self, state: &ExecutionState);
139}
140
/// [`ExecutionHook`] that reports lifecycle events on standard error.
///
/// [`LoggingHook::new`] (also the `Default`) logs phase changes and the final
/// completion time. [`LoggingHook::verbose`] additionally logs each node's
/// start and completion. Errors are always logged.
#[derive(Debug, Clone)]
pub struct LoggingHook {
    /// Log phase transitions and the completion summary.
    log_phase_changes: bool,
    /// Log each node's start and completion.
    log_node_execution: bool,
}

impl LoggingHook {
    /// Creates a hook that logs phase changes but not individual nodes.
    pub fn new() -> Self {
        LoggingHook {
            log_phase_changes: true,
            log_node_execution: false,
        }
    }

    /// Creates a hook that logs both phase changes and every node.
    pub fn verbose() -> Self {
        LoggingHook {
            log_phase_changes: true,
            log_node_execution: true,
        }
    }
}

impl Default for LoggingHook {
    /// Same as [`LoggingHook::new`].
    fn default() -> Self {
        Self::new()
    }
}
168
169impl ExecutionHook for LoggingHook {
170    fn on_phase_change(&mut self, phase: ExecutionPhase, _state: &ExecutionState) {
171        if self.log_phase_changes {
172            eprintln!("[ExecutionHook] Phase changed to: {}", phase.as_str());
173        }
174    }
175
176    fn on_node_start(&mut self, node_idx: usize, _state: &ExecutionState) {
177        if self.log_node_execution {
178            eprintln!("[ExecutionHook] Starting node {}", node_idx);
179        }
180    }
181
182    fn on_node_complete(&mut self, node_idx: usize, duration: Duration, _state: &ExecutionState) {
183        if self.log_node_execution {
184            eprintln!(
185                "[ExecutionHook] Completed node {} in {:.3}ms",
186                node_idx,
187                duration.as_secs_f64() * 1000.0
188            );
189        }
190    }
191
192    fn on_error(&mut self, error: &str, _state: &ExecutionState) {
193        eprintln!("[ExecutionHook] Error: {}", error);
194    }
195
196    fn on_complete(&mut self, state: &ExecutionState) {
197        if self.log_phase_changes {
198            if let Some(elapsed) = state.elapsed() {
199                eprintln!(
200                    "[ExecutionHook] Execution completed in {:.3}s",
201                    elapsed.as_secs_f64()
202                );
203            }
204        }
205    }
206}
207
208/// Execution context for coordinated execution
209pub struct ExecutionContext {
210    pub state: ExecutionState,
211    pub strategy: ExecutionStrategy,
212    pub device: DeviceType,
213    pub profile_data: Option<ProfileData>,
214    pub metadata: HashMap<String, String>,
215    hooks: Vec<Box<dyn ExecutionHook>>,
216}
217
218impl ExecutionContext {
219    pub fn new(total_nodes: usize, strategy: ExecutionStrategy) -> Self {
220        ExecutionContext {
221            state: ExecutionState::new(total_nodes),
222            strategy,
223            device: DeviceType::CPU,
224            profile_data: None,
225            metadata: HashMap::new(),
226            hooks: Vec::new(),
227        }
228    }
229
230    pub fn with_device(mut self, device: DeviceType) -> Self {
231        self.device = device;
232        self
233    }
234
235    pub fn with_profiling(mut self, enable: bool) -> Self {
236        if enable {
237            self.profile_data = Some(ProfileData::new());
238        }
239        self
240    }
241
242    pub fn add_hook(&mut self, hook: Box<dyn ExecutionHook>) {
243        self.hooks.push(hook);
244    }
245
246    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
247        self.metadata.insert(key.into(), value.into());
248    }
249
250    pub fn get_metadata(&self, key: &str) -> Option<&str> {
251        self.metadata.get(key).map(|s| s.as_str())
252    }
253
254    // Lifecycle methods
255    pub fn start(&mut self) {
256        self.state.start();
257        self.notify_phase_change(ExecutionPhase::Executing);
258    }
259
260    pub fn complete(&mut self) {
261        self.state.complete();
262        self.notify_complete();
263        self.notify_phase_change(ExecutionPhase::Completed);
264    }
265
266    pub fn fail(&mut self, error: impl Into<String>) {
267        let error_msg = error.into();
268        self.notify_error(&error_msg);
269        self.state.fail(error_msg);
270        self.notify_phase_change(ExecutionPhase::Failed);
271    }
272
273    pub fn cancel(&mut self) {
274        self.state.cancel();
275        self.notify_phase_change(ExecutionPhase::Cancelled);
276    }
277
278    pub fn begin_node(&mut self, node_idx: usize) {
279        self.state.update_progress(node_idx);
280        self.notify_node_start(node_idx);
281    }
282
283    pub fn end_node(&mut self, node_idx: usize, duration: Duration) {
284        self.notify_node_complete(node_idx, duration);
285    }
286
287    // Hook notifications
288    fn notify_phase_change(&mut self, phase: ExecutionPhase) {
289        for hook in &mut self.hooks {
290            hook.on_phase_change(phase, &self.state);
291        }
292    }
293
294    fn notify_node_start(&mut self, node_idx: usize) {
295        for hook in &mut self.hooks {
296            hook.on_node_start(node_idx, &self.state);
297        }
298    }
299
300    fn notify_node_complete(&mut self, node_idx: usize, duration: Duration) {
301        for hook in &mut self.hooks {
302            hook.on_node_complete(node_idx, duration, &self.state);
303        }
304    }
305
306    fn notify_error(&mut self, error: &str) {
307        for hook in &mut self.hooks {
308            hook.on_error(error, &self.state);
309        }
310    }
311
312    fn notify_complete(&mut self) {
313        for hook in &mut self.hooks {
314            hook.on_complete(&self.state);
315        }
316    }
317
318    pub fn summary(&self) -> String {
319        let mut summary = String::new();
320        summary.push_str("Execution Context Summary\n");
321        summary.push_str("=========================\n\n");
322        summary.push_str(&format!("Phase: {}\n", self.state.phase.as_str()));
323        summary.push_str(&format!("Progress: {:.1}%\n", self.state.progress * 100.0));
324        summary.push_str(&format!(
325            "Nodes: {}/{}\n",
326            self.state.nodes_completed, self.state.total_nodes
327        ));
328
329        if let Some(elapsed) = self.state.elapsed() {
330            summary.push_str(&format!("Elapsed: {:.3}s\n", elapsed.as_secs_f64()));
331        }
332
333        summary.push_str(&format!("Device: {}\n", self.device.as_str()));
334        summary.push_str(&format!("Strategy: {:?}\n", self.strategy.mode));
335
336        if let Some(error) = &self.state.error_message {
337            summary.push_str(&format!("\nError: {}\n", error));
338        }
339
340        if !self.metadata.is_empty() {
341            summary.push_str("\nMetadata:\n");
342            for (key, value) in &self.metadata {
343                summary.push_str(&format!("  {}: {}\n", key, value));
344            }
345        }
346
347        summary
348    }
349}
350
#[cfg(test)]
mod tests {
    //! Unit tests for the phase/state machine and `ExecutionContext`.

    use super::*;

    #[test]
    fn test_execution_phase() {
        assert_eq!(ExecutionPhase::Preparing.as_str(), "Preparing");
        assert!(!ExecutionPhase::Executing.is_terminal());
        assert!(!ExecutionPhase::Waiting.is_terminal());
        assert!(ExecutionPhase::Completed.is_terminal());
        assert!(ExecutionPhase::Failed.is_terminal());
        assert!(ExecutionPhase::Cancelled.is_terminal());
    }

    #[test]
    fn test_execution_state_lifecycle() {
        let mut state = ExecutionState::new(10);

        assert_eq!(state.phase, ExecutionPhase::Preparing);
        assert_eq!(state.progress, 0.0);

        state.start();
        assert_eq!(state.phase, ExecutionPhase::Executing);
        assert!(state.is_running());

        state.update_progress(5);
        assert_eq!(state.current_node, Some(5));
        assert_eq!(state.progress, 0.6);

        state.complete();
        assert_eq!(state.phase, ExecutionPhase::Completed);
        assert!(state.is_complete());
        assert_eq!(state.progress, 1.0);
    }

    #[test]
    fn test_execution_state_failure() {
        let mut state = ExecutionState::new(10);
        state.start();
        state.fail("Test error");

        assert_eq!(state.phase, ExecutionPhase::Failed);
        assert_eq!(state.error_message, Some("Test error".to_string()));
        assert!(state.is_complete());
    }

    #[test]
    fn test_execution_state_cancel() {
        let mut state = ExecutionState::new(3);
        state.start();
        state.cancel();

        assert_eq!(state.phase, ExecutionPhase::Cancelled);
        assert!(state.is_complete());
        assert!(!state.is_running());
        assert!(state.end_time.is_some());
        assert!(state.error_message.is_none());
    }

    #[test]
    fn test_execution_state_zero_nodes() {
        // With no nodes, progress cannot be expressed as a fraction and stays 0.
        let mut state = ExecutionState::new(0);
        state.update_progress(0);

        assert_eq!(state.nodes_completed, 1);
        assert_eq!(state.progress, 0.0);
    }

    #[test]
    fn test_execution_state_elapsed() {
        let mut state = ExecutionState::new(5);
        state.start();
        std::thread::sleep(Duration::from_millis(10));
        state.complete();

        let elapsed = state.elapsed().unwrap();
        assert!(elapsed.as_millis() >= 10);
    }

    #[test]
    fn test_execution_context_creation() {
        let strategy = ExecutionStrategy::inference();
        let context = ExecutionContext::new(10, strategy);

        assert_eq!(context.state.total_nodes, 10);
        assert_eq!(context.device, DeviceType::CPU);
        assert!(context.profile_data.is_none());
    }

    #[test]
    fn test_execution_context_with_device() {
        let strategy = ExecutionStrategy::inference();
        let context = ExecutionContext::new(10, strategy).with_device(DeviceType::GPU);

        assert_eq!(context.device, DeviceType::GPU);
    }

    #[test]
    fn test_execution_context_with_profiling() {
        let strategy = ExecutionStrategy::inference();
        let context = ExecutionContext::new(10, strategy).with_profiling(true);

        assert!(context.profile_data.is_some());
    }

    #[test]
    fn test_execution_context_metadata() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(10, strategy);

        context.set_metadata("graph_id", "test-123");
        context.set_metadata("user", "test-user");

        assert_eq!(context.get_metadata("graph_id"), Some("test-123"));
        assert_eq!(context.get_metadata("user"), Some("test-user"));
        assert_eq!(context.get_metadata("missing"), None);
    }

    #[test]
    fn test_execution_context_lifecycle() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(5, strategy);

        context.start();
        assert!(context.state.is_running());

        context.begin_node(0);
        context.end_node(0, Duration::from_millis(10));

        context.begin_node(1);
        context.end_node(1, Duration::from_millis(15));

        assert_eq!(context.state.nodes_completed, 2);
        assert!(context.state.progress > 0.0);

        context.complete();
        assert!(context.state.is_complete());
        assert_eq!(context.state.phase, ExecutionPhase::Completed);
    }

    #[test]
    fn test_execution_context_failure() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(5, strategy);

        context.start();
        context.fail("Test error occurred");

        assert_eq!(context.state.phase, ExecutionPhase::Failed);
        assert_eq!(
            context.state.error_message,
            Some("Test error occurred".to_string())
        );
    }

    #[test]
    fn test_execution_context_cancel() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(2, strategy);

        context.start();
        context.cancel();

        assert_eq!(context.state.phase, ExecutionPhase::Cancelled);
        assert!(context.state.is_complete());
    }

    #[test]
    fn test_execution_context_summary() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(5, strategy);
        context.set_metadata("test_key", "test_value");

        context.start();
        context.begin_node(2);

        let summary = context.summary();
        assert!(summary.contains("Execution Context Summary"));
        assert!(summary.contains("Progress:"));
        assert!(summary.contains("test_key"));
    }

    #[test]
    fn test_logging_hook() {
        let hook = LoggingHook::new();
        assert!(hook.log_phase_changes);
        assert!(!hook.log_node_execution);

        let verbose_hook = LoggingHook::verbose();
        assert!(verbose_hook.log_phase_changes);
        assert!(verbose_hook.log_node_execution);
    }

    #[test]
    fn test_execution_with_hooks() {
        let strategy = ExecutionStrategy::inference();
        let mut context = ExecutionContext::new(3, strategy);

        // A logging hook only produces stderr output; this checks that hook
        // dispatch does not disturb the lifecycle.
        context.add_hook(Box::new(LoggingHook::new()));

        context.start();
        context.begin_node(0);
        context.end_node(0, Duration::from_millis(10));
        context.complete();

        assert_eq!(context.state.phase, ExecutionPhase::Completed);
    }
}