traverse_graph/
reachability.rs

1use crate::cg::{CallGraph, EdgeType, Node, NodeType, Visibility}; // Assuming NodeId is usize internally
2use std::collections::{HashMap, HashSet};
3use tracing::debug;
4
/// Identifier of a node in the call graph.
///
/// Within this module it is used as the key of analysis results, as the element
/// type of the per-entry-point visited set, and as an index into `graph.nodes`.
pub type NodeId = usize; // Ensure this is pub if storage_access needs it directly from here.
6
/// Stateless driver for call-graph reachability analysis.
///
/// The type carries no data; it only groups the traversal routines. It derives
/// `Default` so it can be constructed wherever a `Default` bound is required
/// (equivalent to calling `new()`), and `Debug`/`Clone`/`Copy` because a unit
/// struct can provide them for free.
#[derive(Debug, Clone, Copy, Default)]
pub struct ReachabilityAnalyzer;
8
9impl ReachabilityAnalyzer {
10    pub fn new() -> Self {
11        Self
12    }
13
14    /// Performs reachability analysis starting from public/external functions.
15    ///
16    /// Traverses the call graph via `Call` edges. For each function node visited
17    /// during this traversal, if `is_node_of_interest` returns true for that function node,
18    /// `process_node_of_interest` is called.
19    ///
20    /// The `process_node_of_interest` function is responsible for inspecting the
21    /// current function node (e.g., its direct edges like `StorageRead`/`StorageWrite`)
22    /// and updating the state `S`.
23    ///
24    /// # Type Parameters
25    ///
26    /// * `S`: The type of the state to be accumulated for each entry point.
27    /// * `FNodeOfInterest`: A predicate `Fn(&Node) -> bool` that determines if a visited
28    ///   function/modifier/constructor node in the call graph is of interest.
29    /// * `FProcessNode`: An action `Fn(&Node, &mut S, &CallGraph)` called when a node of interest
30    ///   is visited. It receives the node of interest, the mutable state for the current
31    ///   entry point, and a reference to the full graph for context (e.g., to inspect edges).
32    ///
33    /// # Arguments
34    ///
35    /// * `graph`: The `CallGraph` to analyze.
36    /// * `is_node_of_interest`: Predicate to identify interesting function/modifier/constructor nodes.
37    /// * `process_node_of_interest`: Action to perform on interesting nodes.
38    /// * `initial_state_factory`: A function that creates an initial state `S` for each entry point.
39    ///
40    /// # Returns
41    ///
42    /// A `HashMap` mapping the `NodeId` of each public/external entry point to the
43    /// accumulated state `S` for that entry point.
44    pub fn analyze_entry_points<S, FNodeOfInterest, FProcessNode>(
45        &self,
46        graph: &CallGraph,
47        is_node_of_interest: &FNodeOfInterest,
48        process_node_of_interest: &FProcessNode,
49        initial_state_factory: impl Fn() -> S,
50    ) -> HashMap<NodeId, S>
51    where
52        FNodeOfInterest: Fn(&Node) -> bool,
53        FProcessNode: Fn(&Node, &mut S, &CallGraph),
54    {
55        let mut results: HashMap<NodeId, S> = HashMap::new();
56
57        let entry_point_nodes: Vec<&Node> = graph.iter_nodes()
58            .filter(|node|
59                node.node_type == NodeType::Function &&
60                (node.visibility == Visibility::Public || node.visibility == Visibility::External) &&
61                // Exclude interface function declarations:
62                // A function is an interface declaration if its contract_name (e.g., "IMyInterface")
63                // matches the name of an actual Interface node in the graph.
64                !node.contract_name.as_ref().map_or(false, |func_contract_name| {
65                    graph.nodes.iter().any(|n| {
66                        n.node_type == NodeType::Interface &&
67                        n.name == *func_contract_name && // The Interface node's name
68                        n.contract_name.as_deref() == Some(func_contract_name) // Interface node's scope is itself
69                    })
70                })
71            )
72            .collect();
73
74        for entry_node in entry_point_nodes {
75            let mut current_state = initial_state_factory();
76            // This set tracks visited functions *within the traversal for a single entry point*.
77            let mut visited_functions_for_this_entry_point: HashSet<NodeId> = HashSet::new();
78
79            self.dfs_traverse(
80                entry_node.id,
81                graph,
82                is_node_of_interest,
83                process_node_of_interest,
84                &mut current_state,
85                &mut visited_functions_for_this_entry_point,
86            );
87            results.insert(entry_node.id, current_state);
88        }
89        results
90    }
91
92    pub fn dfs_traverse<S, FNodeOfInterest, FProcessNode>(
93        &self,
94        current_node_id: NodeId,
95        graph: &CallGraph,
96        is_node_of_interest: &FNodeOfInterest,
97        process_node_of_interest: &FProcessNode,
98        state: &mut S,
99        visited_functions_for_this_entry_point: &mut HashSet<NodeId>,
100    ) where
101        FNodeOfInterest: Fn(&Node) -> bool,
102        FProcessNode: Fn(&Node, &mut S, &CallGraph),
103    {
104        let current_node = match graph.nodes.get(current_node_id) {
105            Some(node) => node,
106            None => {
107                debug!(
108                    "[Reachability DFS] Error: Node ID {} not found in graph.",
109                    current_node_id
110                );
111                return;
112            }
113        };
114
115        // Check if this function/modifier/constructor node has already been processed
116        // for this particular entry point's traversal to avoid redundant work and cycles.
117        if matches!(
118            current_node.node_type,
119            NodeType::Function | NodeType::Modifier | NodeType::Constructor
120        ) {
121            if !visited_functions_for_this_entry_point.insert(current_node_id) {
122                return;
123            }
124        }
125
126        // Process the current node (function/modifier/constructor) if it's of interest.
127        // The `process_node_of_interest` function will then typically look at this
128        // node's direct interactions (e.g., StorageRead/Write edges).
129        if is_node_of_interest(current_node) {
130            process_node_of_interest(current_node, state, graph);
131        }
132
133        // Traverse outgoing 'Call' edges to explore the call graph further.
134        for edge in &graph.edges {
135            if edge.source_node_id == current_node_id && edge.edge_type == EdgeType::Call {
136                // The target of this edge is the callee.
137                self.dfs_traverse(
138                    edge.target_node_id,
139                    graph,
140                    is_node_of_interest,
141                    process_node_of_interest,
142                    state,
143                    visited_functions_for_this_entry_point, // Pass the same set
144                );
145            }
146        }
147    }
148}
149
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::cg::{CallGraph, EdgeType, NodeType, Visibility};
    use std::collections::HashSet;

    /// Builds the shared fixture graph used by reachability tests.
    ///
    /// Call edges:
    /// `a_pub_func -> a_priv_func -> b_internal_func -> c_internal_func` and
    /// `b_pub_func -> b_internal_func`.
    ///
    /// Storage edges:
    /// `a_pub_func` reads `var1`, `a_priv_func` writes `var1`, `b_internal_func` reads
    /// `var2`, `c_internal_func` writes `var3`, and `b_pub_func` writes `var2`.
    ///
    /// The graph also contains an `ITest` interface node with one external function
    /// declaration. Nodes and edges are added in a fixed order.
    pub fn create_test_graph_for_reachability() -> CallGraph {
        let mut graph = CallGraph::new();

        let a_pub_func_id = graph.add_node(
            "a_pub_func".to_string(),
            NodeType::Function,
            Some("ContractA".to_string()),
            Visibility::Public,
            (0, 0),
        );
        let a_priv_func_id = graph.add_node(
            "a_priv_func".to_string(),
            NodeType::Function,
            Some("ContractA".to_string()),
            Visibility::Private,
            (0, 0),
        );
        let b_pub_func_id = graph.add_node(
            "b_pub_func".to_string(),
            NodeType::Function,
            Some("ContractB".to_string()),
            Visibility::External,
            (0, 0),
        );
        let b_internal_func_id = graph.add_node(
            "b_internal_func".to_string(),
            NodeType::Function,
            Some("ContractB".to_string()),
            Visibility::Internal,
            (0, 0),
        );
        let c_internal_func_id = graph.add_node(
            "c_internal_func".to_string(),
            NodeType::Function,
            Some("ContractC".to_string()),
            Visibility::Internal,
            (0, 0),
        );
        let _itest_iface_id = graph.add_node(
            "ITest".to_string(),
            NodeType::Interface,
            Some("ITest".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let itest_func_decl_id = graph.add_node(
            "interface_func".to_string(),
            NodeType::Function,
            Some("ITest".to_string()),
            Visibility::External,
            (0, 0),
        );

        let storage_var1_id = graph.add_node(
            "var1".to_string(),
            NodeType::StorageVariable,
            Some("ContractA".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let storage_var2_id = graph.add_node(
            "var2".to_string(),
            NodeType::StorageVariable,
            Some("ContractB".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let storage_var3_id = graph.add_node(
            "var3".to_string(),
            NodeType::StorageVariable,
            Some("ContractC".to_string()),
            Visibility::Default,
            (0, 0),
        );

        // Call edges.
        graph.add_edge(
            a_pub_func_id,
            a_priv_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            a_priv_func_id,
            b_internal_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_pub_func_id,
            b_internal_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_internal_func_id,
            c_internal_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );

        // Storage access edges.
        graph.add_edge(
            a_pub_func_id,
            storage_var1_id,
            EdgeType::StorageRead,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            a_priv_func_id,
            storage_var1_id,
            EdgeType::StorageWrite,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_internal_func_id,
            storage_var2_id,
            EdgeType::StorageRead,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            c_internal_func_id,
            storage_var3_id,
            EdgeType::StorageWrite,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_pub_func_id,
            storage_var2_id,
            EdgeType::StorageWrite,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );

        // Sanity checks: the returned ids index into `graph.nodes`.
        assert_eq!(graph.nodes[a_pub_func_id].name, "a_pub_func");
        assert_eq!(graph.nodes[itest_func_decl_id].name, "interface_func");
        assert_eq!(graph.nodes[storage_var1_id].name, "var1");

        graph
    }

    /// A graph containing only an internal function yields no entry-point results.
    #[test]
    fn test_analyze_entry_points_no_entry_points() {
        let mut graph = CallGraph::new();
        graph.add_node(
            "internal_func".to_string(),
            NodeType::Function,
            Some("ContractA".to_string()),
            Visibility::Internal,
            (0, 0),
        );
        let analyzer = ReachabilityAnalyzer::new();
        let results = analyzer.analyze_entry_points(
            &graph,
            &|_| true,     // Interested in all nodes for processing
            &|_, _, _| {}, // No-op process
            || (),         // Dummy state
        );
        assert!(
            results.is_empty(),
            "Expected no results for a graph with no public/external entry points"
        );
    }

    /// A two-function call cycle terminates and each function is processed exactly once.
    #[test]
    fn test_analyze_entry_points_with_cycle() {
        let mut graph = CallGraph::new();
        let func1_id = graph.add_node(
            "func1".to_string(),
            NodeType::Function,
            Some("CycleContract".to_string()),
            Visibility::Public,
            (0, 0),
        );
        let func2_id = graph.add_node(
            "func2".to_string(),
            NodeType::Function,
            Some("CycleContract".to_string()),
            Visibility::Private,
            (0, 0),
        );
        graph.add_edge(
            func1_id,
            func2_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            func2_id,
            func1_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );

        let analyzer = ReachabilityAnalyzer::new();

        // Record every callback invocation in a `Vec` (rather than a set) so that a node
        // processed more than once would show up as a duplicate entry.
        let results = analyzer.analyze_entry_points(
            &graph,
            &|node| {
                matches!(
                    node.node_type,
                    NodeType::Function | NodeType::Modifier | NodeType::Constructor
                )
            },
            &|node, state: &mut Vec<NodeId>, _graph_ref| {
                state.push(node.id);
            },
            Vec::new,
        );

        assert_eq!(results.len(), 1, "Expected one entry point result");
        let processed = results
            .get(&func1_id)
            .expect("func1 should be reported as an entry point");
        let distinct_processed: HashSet<NodeId> = processed.iter().copied().collect();
        // Both func1 and func2 should have been processed once due to the cycle detection
        // within the scope of a single entry point's traversal.
        let expected_processed: HashSet<NodeId> = [func1_id, func2_id].iter().cloned().collect();
        assert_eq!(
            distinct_processed, expected_processed,
            "Both functions in cycle should be processed once for the entry point"
        );
        assert_eq!(
            processed.len(),
            expected_processed.len(),
            "No function in the cycle should be processed more than once"
        );
    }

    /// External functions declared on an interface are excluded from the entry points,
    /// while a public function of a regular contract is included.
    #[test]
    fn test_interface_function_declarations_are_not_entry_points() {
        let mut graph = CallGraph::new();

        let iface_node_id = graph.add_node(
            "IMyInterface".to_string(),
            NodeType::Interface,
            Some("IMyInterface".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let iface_func_id = graph.add_node(
            "doSomething".to_string(),
            NodeType::Function,
            Some("IMyInterface".to_string()),
            Visibility::External,
            (0, 0),
        );

        let regular_pub_func_id = graph.add_node(
            "regularPublic".to_string(),
            NodeType::Function,
            Some("MyContract".to_string()),
            Visibility::Public,
            (0, 0),
        );

        let analyzer = ReachabilityAnalyzer::new();
        let results = analyzer.analyze_entry_points(&graph, &|_| true, &|_, _, _| {}, || ());

        assert_eq!(
            results.len(),
            1,
            "Only regularPublic should be an entry point"
        );
        assert!(
            results.contains_key(&regular_pub_func_id),
            "regularPublic should be an entry point"
        );
        assert!(
            !results.contains_key(&iface_func_id),
            "Interface function declaration should not be an entry point"
        );
        assert!(
            !results.contains_key(&iface_node_id),
            "Interface node itself should not be an entry point"
        );
    }
}