1use crate::cg::{CallGraph, EdgeType, Node, NodeType, Visibility}; use std::collections::{HashMap, HashSet};
3use tracing::debug;
4
5pub type NodeId = usize; pub struct ReachabilityAnalyzer;
8
9impl ReachabilityAnalyzer {
10 pub fn new() -> Self {
11 Self
12 }
13
14 pub fn analyze_entry_points<S, FNodeOfInterest, FProcessNode>(
45 &self,
46 graph: &CallGraph,
47 is_node_of_interest: &FNodeOfInterest,
48 process_node_of_interest: &FProcessNode,
49 initial_state_factory: impl Fn() -> S,
50 ) -> HashMap<NodeId, S>
51 where
52 FNodeOfInterest: Fn(&Node) -> bool,
53 FProcessNode: Fn(&Node, &mut S, &CallGraph),
54 {
55 let mut results: HashMap<NodeId, S> = HashMap::new();
56
57 let entry_point_nodes: Vec<&Node> = graph.iter_nodes()
58 .filter(|node|
59 node.node_type == NodeType::Function &&
60 (node.visibility == Visibility::Public || node.visibility == Visibility::External) &&
61 !node.contract_name.as_ref().map_or(false, |func_contract_name| {
65 graph.nodes.iter().any(|n| {
66 n.node_type == NodeType::Interface &&
67 n.name == *func_contract_name && n.contract_name.as_deref() == Some(func_contract_name) })
70 })
71 )
72 .collect();
73
74 for entry_node in entry_point_nodes {
75 let mut current_state = initial_state_factory();
76 let mut visited_functions_for_this_entry_point: HashSet<NodeId> = HashSet::new();
78
79 self.dfs_traverse(
80 entry_node.id,
81 graph,
82 is_node_of_interest,
83 process_node_of_interest,
84 &mut current_state,
85 &mut visited_functions_for_this_entry_point,
86 );
87 results.insert(entry_node.id, current_state);
88 }
89 results
90 }
91
92 pub fn dfs_traverse<S, FNodeOfInterest, FProcessNode>(
93 &self,
94 current_node_id: NodeId,
95 graph: &CallGraph,
96 is_node_of_interest: &FNodeOfInterest,
97 process_node_of_interest: &FProcessNode,
98 state: &mut S,
99 visited_functions_for_this_entry_point: &mut HashSet<NodeId>,
100 ) where
101 FNodeOfInterest: Fn(&Node) -> bool,
102 FProcessNode: Fn(&Node, &mut S, &CallGraph),
103 {
104 let current_node = match graph.nodes.get(current_node_id) {
105 Some(node) => node,
106 None => {
107 debug!(
108 "[Reachability DFS] Error: Node ID {} not found in graph.",
109 current_node_id
110 );
111 return;
112 }
113 };
114
115 if matches!(
118 current_node.node_type,
119 NodeType::Function | NodeType::Modifier | NodeType::Constructor
120 ) {
121 if !visited_functions_for_this_entry_point.insert(current_node_id) {
122 return;
123 }
124 }
125
126 if is_node_of_interest(current_node) {
130 process_node_of_interest(current_node, state, graph);
131 }
132
133 for edge in &graph.edges {
135 if edge.source_node_id == current_node_id && edge.edge_type == EdgeType::Call {
136 self.dfs_traverse(
138 edge.target_node_id,
139 graph,
140 is_node_of_interest,
141 process_node_of_interest,
142 state,
143 visited_functions_for_this_entry_point, );
145 }
146 }
147 }
148}
149
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::cg::{CallGraph, EdgeType, NodeType, Visibility};
    use std::collections::HashSet;

    /// Builds a small multi-contract graph shared by reachability tests.
    ///
    /// Functions:
    /// - `a_pub_func`: public, in `ContractA`.
    /// - `a_priv_func`: private, in `ContractA`.
    /// - `b_pub_func`: external, in `ContractB`.
    /// - `b_internal_func`: internal, in `ContractB`.
    /// - `c_internal_func`: internal, in `ContractC`.
    /// - `interface_func`: external, declared in interface `ITest`.
    ///
    /// Storage variables: `var1`, `var2`, `var3`.
    ///
    /// Call edges: `a_pub_func -> a_priv_func -> b_internal_func -> c_internal_func`
    /// and `b_pub_func -> b_internal_func`.
    ///
    /// Storage edges:
    /// - `a_pub_func` reads `var1`.
    /// - `a_priv_func` writes `var1`.
    /// - `b_internal_func` reads `var2`.
    /// - `b_pub_func` writes `var2`.
    /// - `c_internal_func` writes `var3`.
    pub fn create_test_graph_for_reachability() -> CallGraph {
        let mut graph = CallGraph::new();

        let a_pub_func_id = graph.add_node(
            "a_pub_func".to_string(),
            NodeType::Function,
            Some("ContractA".to_string()),
            Visibility::Public,
            (0, 0),
        );
        let a_priv_func_id = graph.add_node(
            "a_priv_func".to_string(),
            NodeType::Function,
            Some("ContractA".to_string()),
            Visibility::Private,
            (0, 0),
        );
        let b_pub_func_id = graph.add_node(
            "b_pub_func".to_string(),
            NodeType::Function,
            Some("ContractB".to_string()),
            Visibility::External,
            (0, 0),
        );
        let b_internal_func_id = graph.add_node(
            "b_internal_func".to_string(),
            NodeType::Function,
            Some("ContractB".to_string()),
            Visibility::Internal,
            (0, 0),
        );
        let c_internal_func_id = graph.add_node(
            "c_internal_func".to_string(),
            NodeType::Function,
            Some("ContractC".to_string()),
            Visibility::Internal,
            (0, 0),
        );
        let _itest_iface_id = graph.add_node(
            "ITest".to_string(),
            NodeType::Interface,
            Some("ITest".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let itest_func_decl_id = graph.add_node(
            "interface_func".to_string(),
            NodeType::Function,
            Some("ITest".to_string()),
            Visibility::External,
            (0, 0),
        );

        let storage_var1_id = graph.add_node(
            "var1".to_string(),
            NodeType::StorageVariable,
            Some("ContractA".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let storage_var2_id = graph.add_node(
            "var2".to_string(),
            NodeType::StorageVariable,
            Some("ContractB".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let storage_var3_id = graph.add_node(
            "var3".to_string(),
            NodeType::StorageVariable,
            Some("ContractC".to_string()),
            Visibility::Default,
            (0, 0),
        );

        // Call edges.
        graph.add_edge(
            a_pub_func_id,
            a_priv_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            a_priv_func_id,
            b_internal_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_pub_func_id,
            b_internal_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_internal_func_id,
            c_internal_func_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );

        // Storage edges. The traversal must not follow these.
        graph.add_edge(
            a_pub_func_id,
            storage_var1_id,
            EdgeType::StorageRead,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            a_priv_func_id,
            storage_var1_id,
            EdgeType::StorageWrite,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_internal_func_id,
            storage_var2_id,
            EdgeType::StorageRead,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            c_internal_func_id,
            storage_var3_id,
            EdgeType::StorageWrite,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            b_pub_func_id,
            storage_var2_id,
            EdgeType::StorageWrite,
            (0, 0),
            None,
            2,
            None,
            None,
            None,
            None,
        );

        assert_eq!(graph.nodes[a_pub_func_id].name, "a_pub_func");
        assert_eq!(graph.nodes[itest_func_decl_id].name, "interface_func");
        assert_eq!(graph.nodes[storage_var1_id].name, "var1");

        graph
    }

    #[test]
    fn test_analyze_entry_points_no_entry_points() {
        let mut graph = CallGraph::new();
        graph.add_node(
            "internal_func".to_string(),
            NodeType::Function,
            Some("ContractA".to_string()),
            Visibility::Internal,
            (0, 0),
        );
        let analyzer = ReachabilityAnalyzer::new();
        let results = analyzer.analyze_entry_points(&graph, &|_| true, &|_, _, _| {}, || ());
        assert!(
            results.is_empty(),
            "Expected no results for a graph with no public/external entry points"
        );
    }

    #[test]
    fn test_analyze_entry_points_with_cycle() {
        let mut graph = CallGraph::new();
        let func1_id = graph.add_node(
            "func1".to_string(),
            NodeType::Function,
            Some("CycleContract".to_string()),
            Visibility::Public,
            (0, 0),
        );
        let func2_id = graph.add_node(
            "func2".to_string(),
            NodeType::Function,
            Some("CycleContract".to_string()),
            Visibility::Private,
            (0, 0),
        );
        graph.add_edge(
            func1_id,
            func2_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );
        graph.add_edge(
            func2_id,
            func1_id,
            EdgeType::Call,
            (0, 0),
            None,
            1,
            None,
            None,
            None,
            None,
        );

        let analyzer = ReachabilityAnalyzer::new();
        let results = analyzer.analyze_entry_points(
            &graph,
            &|node| {
                matches!(
                    node.node_type,
                    NodeType::Function | NodeType::Modifier | NodeType::Constructor
                )
            },
            &|node, state: &mut HashSet<NodeId>, _graph_ref| {
                state.insert(node.id);
            },
            HashSet::new,
        );

        assert_eq!(results.len(), 1, "Expected one entry point result");
        let summary = results.get(&func1_id).unwrap();
        let expected_processed: HashSet<NodeId> = [func1_id, func2_id].iter().cloned().collect();
        assert_eq!(
            *summary, expected_processed,
            "Both functions in cycle should be processed once for the entry point"
        );
    }

    #[test]
    fn test_interface_function_declarations_are_not_entry_points() {
        let mut graph = CallGraph::new();

        let iface_node_id = graph.add_node(
            "IMyInterface".to_string(),
            NodeType::Interface,
            Some("IMyInterface".to_string()),
            Visibility::Default,
            (0, 0),
        );
        let iface_func_id = graph.add_node(
            "doSomething".to_string(),
            NodeType::Function,
            Some("IMyInterface".to_string()),
            Visibility::External,
            (0, 0),
        );

        let regular_pub_func_id = graph.add_node(
            "regularPublic".to_string(),
            NodeType::Function,
            Some("MyContract".to_string()),
            Visibility::Public,
            (0, 0),
        );

        let analyzer = ReachabilityAnalyzer::new();
        let results = analyzer.analyze_entry_points(&graph, &|_| true, &|_, _, _| {}, || ());

        assert_eq!(
            results.len(),
            1,
            "Only regularPublic should be an entry point"
        );
        assert!(
            results.contains_key(&regular_pub_func_id),
            "regularPublic should be an entry point"
        );
        assert!(
            !results.contains_key(&iface_func_id),
            "Interface function declaration should not be an entry point"
        );
        assert!(
            !results.contains_key(&iface_node_id),
            "Interface node itself should not be an entry point"
        );
    }

    #[test]
    fn test_analyze_entry_points_on_shared_graph_follows_calls_in_preorder() {
        let graph = create_test_graph_for_reachability();
        let id_of = |name: &str| {
            graph
                .nodes
                .iter()
                .find(|n| n.name == name)
                .map(|n| n.id)
                .expect("node should exist in the shared test graph")
        };

        let analyzer = ReachabilityAnalyzer::new();
        let results = analyzer.analyze_entry_points(
            &graph,
            &|node| {
                matches!(
                    node.node_type,
                    NodeType::Function | NodeType::Modifier | NodeType::Constructor
                )
            },
            &|node, state: &mut Vec<NodeId>, _graph_ref| {
                state.push(node.id);
            },
            Vec::new,
        );

        assert_eq!(
            results.len(),
            2,
            "Only a_pub_func and b_pub_func should be entry points"
        );
        assert!(
            !results.contains_key(&id_of("interface_func")),
            "Interface function declaration should not be an entry point"
        );
        assert_eq!(
            results[&id_of("a_pub_func")],
            vec![
                id_of("a_pub_func"),
                id_of("a_priv_func"),
                id_of("b_internal_func"),
                id_of("c_internal_func"),
            ],
            "a_pub_func should reach its whole call chain in DFS pre-order"
        );
        assert_eq!(
            results[&id_of("b_pub_func")],
            vec![
                id_of("b_pub_func"),
                id_of("b_internal_func"),
                id_of("c_internal_func"),
            ],
            "b_pub_func should reach b_internal_func and c_internal_func via call edges only"
        );
    }
}