// rust_logic_graph/core/executor.rs

use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Duration, Instant};

use anyhow::Result;
use tracing::{debug, info, warn};

use crate::cache::{CacheKey, CacheManager};
use crate::core::{Graph, GraphDef};
use crate::node::{AINode, DBNode, Node, RuleNode as ConcreteRuleNode};
use crate::rule::Rule;

/// Timing and outcome record for one node processed by [`Executor::execute`].
#[derive(Clone, Debug)]
pub struct NodeExecutionStats {
    /// Identifier of the node this record describes.
    pub node_id: String,
    /// Wall-clock time spent on the node, including the cache lookup.
    pub duration: Duration,
    /// `true` when the node's output was restored from the cache instead of running it.
    pub cache_hit: bool,
    /// `true` when the node finished without an error.
    pub success: bool,
}

20/// Overall execution metrics
21#[derive(Debug, Clone, Default)]
22pub struct ExecutionMetrics {
23    pub total_duration: Duration,
24    pub nodes_executed: usize,
25    pub nodes_skipped: usize,
26    pub nodes_failed: usize,
27    pub cache_hits: usize,
28    pub node_stats: Vec<NodeExecutionStats>,
29}
30
31/// Executor for running graph nodes in topological order.
32/// 
33/// # Thread Safety
34/// 
35/// The Executor is **NOT thread-safe** for concurrent executions on the same instance.
36/// While `execute()` is async and takes `&self` (shared reference), the underlying
37/// implementation assumes single-threaded access patterns:
38/// 
39/// - `self.nodes` is a regular `HashMap` without synchronization
40/// - Multiple concurrent calls to `execute()` would have data races when accessing nodes
41/// 
42/// ## Safe Usage Patterns
43/// 
44/// 1. **Single execution at a time**: Only call `execute()` once at a time per executor instance
45/// 2. **Clone for parallelism**: Create separate executor instances for parallel graph executions
46/// 3. **Sequential async**: Use `.await` to ensure executions don't overlap
47/// 
48/// ## Future Work
49/// 
50/// For true concurrent execution support, wrap `nodes` in `Arc<RwLock<HashMap>>` or similar.
51pub struct Executor {
52    nodes: HashMap<String, Box<dyn Node>>,
53    cache: Option<CacheManager>,
54    metrics: ExecutionMetrics,
55}
56
57impl Executor {
58    pub fn new() -> Self {
59        Self {
60            nodes: HashMap::new(),
61            cache: None,
62            metrics: ExecutionMetrics::default(),
63        }
64    }
65
66    /// Create a new executor with caching enabled
67    pub fn with_cache(cache: CacheManager) -> Self {
68        Self {
69            nodes: HashMap::new(),
70            cache: Some(cache),
71            metrics: ExecutionMetrics::default(),
72        }
73    }
74
75    /// Enable caching for this executor
76    pub fn set_cache(&mut self, cache: CacheManager) {
77        self.cache = Some(cache);
78    }
79
80    /// Get the cache manager (if enabled)
81    pub fn cache(&self) -> Option<&CacheManager> {
82        self.cache.as_ref()
83    }
84    
85    /// Get execution metrics from last run
86    pub fn metrics(&self) -> &ExecutionMetrics {
87        &self.metrics
88    }
89    
90    /// Reset execution metrics
91    pub fn reset_metrics(&mut self) {
92        self.metrics = ExecutionMetrics::default();
93    }
94
95    /// Build executor from graph definition
96    pub fn from_graph_def(def: &GraphDef) -> Result<Self> {
97        let mut executor = Self::new();
98
99        // Create concrete node instances based on NodeConfig
100        for (node_id, config) in &def.nodes {
101            let node: Box<dyn Node> = match config.node_type {
102                crate::node::NodeType::RuleNode => {
103                    let condition = config.condition.as_deref().unwrap_or("true");
104                    Box::new(ConcreteRuleNode::new(node_id, condition))
105                }
106                crate::node::NodeType::DBNode => {
107                    let query = config.query.clone()
108                        .unwrap_or_else(|| format!("SELECT * FROM {}", node_id));
109                    Box::new(DBNode::new(node_id, query))
110                }
111                crate::node::NodeType::AINode => {
112                    let prompt = config.prompt.clone()
113                        .unwrap_or_else(|| format!("Process data for {}", node_id));
114                    Box::new(AINode::new(node_id, prompt))
115                }
116                crate::node::NodeType::GrpcNode => {
117                    // Parse query field as "service_url#method"
118                    let query = config.query.clone()
119                        .unwrap_or_else(|| format!("http://localhost:50051#{}_method", node_id));
120                    let parts: Vec<&str> = query.split('#').collect();
121                    let service_url = parts.get(0).unwrap_or(&"http://localhost:50051").to_string();
122                    let method = parts.get(1).unwrap_or(&"UnknownMethod").to_string();
123                    Box::new(crate::node::GrpcNode::new(node_id, service_url, method))
124                }
125            };
126
127            executor.register_node(node);
128        }
129
130        Ok(executor)
131    }
132
133    /// Register a node with the executor
134    pub fn register_node(&mut self, node: Box<dyn Node>) {
135        let id = node.id().to_string();
136        self.nodes.insert(id, node);
137    }
138
139    /// Detect cycles in the graph using DFS
140    fn detect_cycles(&self, graph: &Graph) -> Result<()> {
141        let mut visited = HashSet::new();
142        let mut rec_stack = HashSet::new();
143        
144        // Build adjacency list for cycle detection
145        let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
146        for edge in &graph.def.edges {
147            adj_list
148                .entry(edge.from.clone())
149                .or_insert_with(Vec::new)
150                .push(edge.to.clone());
151        }
152        
153        // DFS to detect cycles
154        fn dfs_cycle_check(
155            node: &str,
156            adj_list: &HashMap<String, Vec<String>>,
157            visited: &mut HashSet<String>,
158            rec_stack: &mut HashSet<String>,
159            path: &mut Vec<String>,
160        ) -> Option<Vec<String>> {
161            visited.insert(node.to_string());
162            rec_stack.insert(node.to_string());
163            path.push(node.to_string());
164            
165            if let Some(neighbors) = adj_list.get(node) {
166                for neighbor in neighbors {
167                    if !visited.contains(neighbor) {
168                        if let Some(cycle) = dfs_cycle_check(neighbor, adj_list, visited, rec_stack, path) {
169                            return Some(cycle);
170                        }
171                    } else if rec_stack.contains(neighbor) {
172                        // Found a cycle - return the cycle path
173                        let cycle_start = path.iter().position(|n| n == neighbor).unwrap();
174                        return Some(path[cycle_start..].to_vec());
175                    }
176                }
177            }
178            
179            path.pop();
180            rec_stack.remove(node);
181            None
182        }
183        
184        // Check all nodes
185        for node_id in graph.def.nodes.keys() {
186            if !visited.contains(node_id) {
187                let mut path = Vec::new();
188                if let Some(cycle) = dfs_cycle_check(node_id, &adj_list, &mut visited, &mut rec_stack, &mut path) {
189                    return Err(anyhow::anyhow!(
190                        "Cycle detected in graph: {} -> {}",
191                        cycle.join(" -> "),
192                        cycle.first().unwrap()
193                    ));
194                }
195            }
196        }
197        
198        Ok(())
199    }
200
201    /// Execute the graph in topological order
202    pub async fn execute(&mut self, graph: &mut Graph) -> Result<()> {
203        info!("Executor: Starting graph execution");
204        let execution_start = Instant::now();
205        
206        // Reset metrics
207        self.metrics = ExecutionMetrics::default();
208
209        // Validate graph structure first
210        graph.def.validate()?;
211        
212        // Warn about disconnected components
213        if graph.def.has_disconnected_components() {
214            warn!("Graph has disconnected components - some nodes may not be reachable");
215        }
216
217        // First, detect cycles in the graph
218        self.detect_cycles(graph)?;
219
220        // Build adjacency list and in-degree map
221        let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
222        let mut in_degree: HashMap<String, usize> = HashMap::new();
223
224        // Initialize all nodes with 0 in-degree
225        for node_id in graph.def.nodes.keys() {
226            in_degree.insert(node_id.clone(), 0);
227            adj_list.insert(node_id.clone(), Vec::new());
228        }
229
230        // Build the graph structure
231        for edge in &graph.def.edges {
232            adj_list
233                .entry(edge.from.clone())
234                .or_insert_with(Vec::new)
235                .push(edge.to.clone());
236
237            *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
238        }
239
240        // Find all nodes with in-degree 0 (starting nodes)
241        let mut queue: VecDeque<String> = in_degree
242            .iter()
243            .filter(|(_, &degree)| degree == 0)
244            .map(|(id, _)| id.clone())
245            .collect();
246
247        if queue.is_empty() {
248            return Err(anyhow::anyhow!(
249                "No starting nodes found in graph (all nodes have incoming edges). \
250                This indicates either a cycle (which should have been caught earlier) \
251                or an invalid graph structure. Nodes: {:?}",
252                graph.def.nodes.keys().collect::<Vec<_>>()
253            ));
254        }
255
256        let mut executed = HashSet::new();
257        let mut execution_order = Vec::new();
258
259        // Topological sort & execution
260        while let Some(node_id) = queue.pop_front() {
261            if executed.contains(&node_id) {
262                continue;
263            }
264
265            info!("Executor: Processing node '{}'", node_id);
266
267            // Check if all incoming edges have their rules satisfied
268            // IMPORTANT: Only check edges from nodes that have been executed
269            let incoming_edges: Vec<_> = graph
270                .def
271                .edges
272                .iter()
273                .filter(|e| e.to == node_id && executed.contains(&e.from))
274                .collect();
275
276            let mut should_execute = true;
277
278            for edge in &incoming_edges {
279                if let Some(rule_id) = &edge.rule {
280                    let rule = Rule::new(rule_id, "true"); // Default condition
281
282                    match rule.evaluate(&graph.context.data) {
283                        Ok(result) => {
284                            debug!(
285                                "Rule '{}' for edge {} -> {} evaluated to: {:?}",
286                                rule_id, edge.from, edge.to, result
287                            );
288
289                            if let serde_json::Value::Bool(false) = result {
290                                should_execute = false;
291                                info!(
292                                    "Skipping node '{}' due to failed rule '{}' from executed node '{}'",
293                                    node_id, rule_id, edge.from
294                                );
295                                self.metrics.nodes_skipped += 1;
296                                break;
297                            }
298                        }
299                        Err(e) => {
300                            warn!(
301                                "Rule '{}' evaluation failed: {}. Assuming true.",
302                                rule_id, e
303                            );
304                        }
305                    }
306                }
307            }
308
309            // Execute the node
310            if should_execute {
311                if let Some(node) = self.nodes.get(&node_id) {
312                    let node_start = Instant::now();
313                    let mut cache_hit = false;
314                    
315                    // Create cache key based on node ID and relevant context only
316                    // Only include context keys that this node might depend on
317                    let relevant_context: HashMap<String, serde_json::Value> = incoming_edges
318                        .iter()
319                        .filter_map(|edge| {
320                            graph.context.data.get(&format!("{}_result", edge.from))
321                                .map(|v| (edge.from.clone(), v.clone()))
322                        })
323                        .collect();
324                    
325                    let context_value = serde_json::to_value(&relevant_context)?;
326                    let cache_key = CacheKey::new(&node_id, &context_value);
327
328                    // Check cache first
329                    let cached_result = if let Some(cache) = &self.cache {
330                        cache.get(&cache_key)
331                    } else {
332                        None
333                    };
334
335                    let result = if let Some(cached_value) = cached_result {
336                        info!("Node '{}' result retrieved from cache", node_id);
337                        cache_hit = true;
338                        self.metrics.cache_hits += 1;
339                        
340                        // Merge cached result into context
341                        if let serde_json::Value::Object(cached_obj) = cached_value {
342                            for (k, v) in cached_obj {
343                                graph.context.data.insert(k, v);
344                            }
345                        }
346                        
347                        Ok(serde_json::Value::Null) // Successfully used cache
348                    } else {
349                        // Execute node and cache result
350                        let exec_result = node.run(&mut graph.context).await;
351                        
352                        // Store result in cache if execution succeeded
353                        if exec_result.is_ok() {
354                            if let Some(cache) = &self.cache {
355                                let context_result = serde_json::to_value(&graph.context.data)?;
356                                if let Err(e) = cache.put(cache_key, context_result, None) {
357                                    warn!("Failed to cache result for node '{}': {}", node_id, e);
358                                }
359                            }
360                        }
361                        
362                        exec_result
363                    };
364
365                    match result {
366                        Ok(_) => {
367                            let duration = node_start.elapsed();
368                            info!("Node '{}' executed successfully in {:?}", node_id, duration);
369                            execution_order.push(node_id.clone());
370                            
371                            self.metrics.nodes_executed += 1;
372                            self.metrics.node_stats.push(NodeExecutionStats {
373                                node_id: node_id.clone(),
374                                duration,
375                                cache_hit,
376                                success: true,
377                            });
378                        }
379                        Err(e) => {
380                            let duration = node_start.elapsed();
381                            warn!("Node '{}' execution failed: {:?}", node_id, e);
382                            
383                            self.metrics.nodes_failed += 1;
384                            self.metrics.node_stats.push(NodeExecutionStats {
385                                node_id: node_id.clone(),
386                                duration,
387                                cache_hit,
388                                success: false,
389                            });
390                        }
391                    }
392                } else {
393                    warn!("Node '{}' not found in executor", node_id);
394                }
395            }
396
397            executed.insert(node_id.clone());
398
399            // Add downstream nodes to queue
400            if let Some(neighbors) = adj_list.get(&node_id) {
401                for neighbor in neighbors {
402                    if let Some(degree) = in_degree.get_mut(neighbor) {
403                        *degree = degree.saturating_sub(1);
404                        if *degree == 0 && !executed.contains(neighbor) {
405                            queue.push_back(neighbor.clone());
406                        }
407                    }
408                }
409            }
410        }
411
412        self.metrics.total_duration = execution_start.elapsed();
413        
414        info!(
415            "Executor: Completed execution in {:?}. Executed: {}, Skipped: {}, Failed: {}, Cache hits: {}",
416            self.metrics.total_duration,
417            self.metrics.nodes_executed,
418            self.metrics.nodes_skipped,
419            self.metrics.nodes_failed,
420            self.metrics.cache_hits
421        );
422
423        // Verify all nodes were executed (should not happen with cycle detection)
424        let unexecuted: Vec<_> = graph
425            .def
426            .nodes
427            .keys()
428            .filter(|id| !executed.contains(*id))
429            .collect();
430
431        if !unexecuted.is_empty() {
432            return Err(anyhow::anyhow!(
433                "Some nodes were not executed: {:?}. This indicates a bug in the executor logic.",
434                unexecuted
435            ));
436        }
437
438        Ok(())
439    }
440}
441
442impl Default for Executor {
443    fn default() -> Self {
444        Self::new()
445    }
446}