// traverse_graph/storage_access.rs
//! # Storage Access Analysis Module
//!
//! This module provides functionality to analyze a `CallGraph` and determine
//! which storage variables are read or written by each public or external
//! entry point function within the analyzed contracts.
//!
//! ## Core Functionality
//!
//! The main entry point is the `analyze_storage_access` function. It performs
//! the following steps:
//!
//! 1.  **Identifies Entry Points**: It uses `ReachabilityAnalyzer` to find all
//!     functions marked as `public` or `external` that are not part of an
//!     interface definition. These serve as the starting points for the analysis.
//!
//! 2.  **Call Graph Traversal**: For each identified entry point, it traverses
//!     the call graph using a Depth-First Search (DFS) approach, following
//!     `Call` edges.
//!
//! 3.  **Storage Interaction Processing**: When a function, modifier, or
//!     constructor node is visited during the DFS traversal, this module
//!     inspects its direct outgoing edges. If these edges are of type
//!     `StorageRead` or `StorageWrite` and target a `StorageVariable` node,
//!     the ID of the storage variable is recorded.
//!
//! 4.  **Aggregation**: The read and write accesses are aggregated into a
//!     `StorageAccessSummary` for each entry point. This summary contains
//!     two `HashSet`s: one for the NodeIds of storage variables read, and
//!     one for those written.
//!
//! ## Output
//!
//! The `analyze_storage_access` function returns a `HashMap` where the keys are
//! the `NodeId`s of the entry point functions, and the values are their
//! corresponding `StorageAccessSummary` objects.
//!
//! This information is crucial for understanding the storage footprint and
//! side effects of different functions in a smart contract system.
//!
40use crate::cg::{CallGraph, EdgeType, Node, NodeType};
41use crate::reachability::{NodeId, ReachabilityAnalyzer};
42use std::collections::{HashMap, HashSet};
43
44#[derive(Clone, Default, Debug, PartialEq, Eq)]
45pub struct StorageAccessSummary {
46 pub reads: HashSet<NodeId>,
47 pub writes: HashSet<NodeId>,
48}
49
50pub fn analyze_storage_access(graph: &CallGraph) -> HashMap<NodeId, StorageAccessSummary> {
51 let analyzer = ReachabilityAnalyzer::new();
52
53 // For storage analysis, every function/modifier/constructor encountered in the
54 // call tree is "of interest" because we need to check its direct storage interactions.
55 let is_function_like_node = |node: &Node| -> bool {
56 matches!(node.node_type, NodeType::Function | NodeType::Modifier | NodeType::Constructor)
57 };
58
59 // This function is called for each `func_node` (a function, modifier, or constructor)
60 // that is visited during the DFS traversal of the call graph.
61 // It inspects `func_node`'s direct outgoing edges to find StorageRead/StorageWrite
62 // interactions and updates the `StorageAccessSummary` state.
63 let process_function_for_storage_interactions = |
64 func_node: &Node, // The function/modifier/constructor currently being processed
65 state: &mut StorageAccessSummary,
66 graph: &CallGraph
67 | {
68 for edge in &graph.edges {
69 // Check if the edge originates from the current function-like node
70 if edge.source_node_id == func_node.id {
71 // Check if the target of the edge is a storage variable
72 if let Some(target_node) = graph.nodes.get(edge.target_node_id) {
73 if target_node.node_type == NodeType::StorageVariable {
74 match edge.edge_type {
75 EdgeType::StorageRead => {
76 state.reads.insert(target_node.id);
77 }
78 EdgeType::StorageWrite => {
79 state.writes.insert(target_node.id);
80 }
81 _ => {} // Other edge types to storage vars are not relevant for this summary
82 }
83 }
84 }
85 }
86 }
87 };
88
89 analyzer.analyze_entry_points(
90 graph,
91 &is_function_like_node,
92 &process_function_for_storage_interactions,
93 StorageAccessSummary::default,
94 )
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    use crate::reachability::tests::create_test_graph_for_reachability;
    use std::collections::HashSet;

    /// Runs the analysis on the shared reachability fixture and checks the
    /// exact read/write sets reported for its two entry points.
    #[test]
    fn test_analyze_storage_access_basic() {
        let graph = create_test_graph_for_reachability();
        let results = analyze_storage_access(&graph);

        assert_eq!(results.len(), 2, "Expected 2 entry points (a_pub_func, b_pub_func)");

        // Looks a fixture node up by name; a missing node is a broken fixture,
        // so fail loudly and say which one.
        let id_of = |name: &str| -> NodeId {
            graph
                .iter_nodes()
                .find(|n| n.name == name)
                .unwrap_or_else(|| panic!("fixture graph has no node named {}", name))
                .id
        };
        // Builds an expected id set from a slice of ids.
        let set_of = |ids: &[NodeId]| -> HashSet<NodeId> { ids.iter().cloned().collect() };

        let a_pub_func_id = id_of("a_pub_func");
        let b_pub_func_id = id_of("b_pub_func");

        let storage_var1_id = id_of("var1");
        let storage_var2_id = id_of("var2");
        let storage_var3_id = id_of("var3");

        let summary_a = results.get(&a_pub_func_id).expect("Summary for a_pub_func missing");
        assert_eq!(
            summary_a.reads,
            set_of(&[storage_var1_id, storage_var2_id]),
            "Mismatch in reads for a_pub_func"
        );
        assert_eq!(
            summary_a.writes,
            set_of(&[storage_var1_id, storage_var3_id]),
            "Mismatch in writes for a_pub_func"
        );

        let summary_b = results.get(&b_pub_func_id).expect("Summary for b_pub_func missing");
        assert_eq!(
            summary_b.reads,
            set_of(&[storage_var2_id]),
            "Mismatch in reads for b_pub_func"
        );
        assert_eq!(
            summary_b.writes,
            set_of(&[storage_var2_id, storage_var3_id]),
            "Mismatch in writes for b_pub_func"
        );
    }
}
131