traverse_graph/
storage_access.rs

1//! # Storage Access Analysis Module
2//!
3//! This module provides functionality to analyze a `CallGraph` and determine
4//! which storage variables are read or written by each public or external
5//! entry point function within the analyzed contracts.
6//!
7//! ## Core Functionality
8//!
9//! The main entry point is the `analyze_storage_access` function. It performs
10//! the following steps:
11//!
12//! 1.  **Identifies Entry Points**: It uses `ReachabilityAnalyzer` to find all
13//!     functions marked as `public` or `external` that are not part of an
14//!     interface definition. These serve as the starting points for the analysis.
15//!
16//! 2.  **Call Graph Traversal**: For each identified entry point, it traverses
17//!     the call graph using a Depth-First Search (DFS) approach, following
18//!     `Call` edges.
19//!
20//! 3.  **Storage Interaction Processing**: When a function, modifier, or
21//!     constructor node is visited during the DFS traversal, this module
22//!     inspects its direct outgoing edges. If these edges are of type
23//!     `StorageRead` or `StorageWrite` and target a `StorageVariable` node,
24//!     the ID of the storage variable is recorded.
25//!
26//! 4.  **Aggregation**: The read and write accesses are aggregated into a
27//!     `StorageAccessSummary` for each entry point. This summary contains
28//!     two `HashSet`s: one for the NodeIds of storage variables read, and
29//!     one for those written.
30//!
31//! ## Output
32//!
33//! The `analyze_storage_access` function returns a `HashMap` where the keys are
34//! the `NodeId`s of the entry point functions, and the values are their
35//! corresponding `StorageAccessSummary` objects.
36//!
37//! This information is crucial for understanding the storage footprint and
38//! side effects of different functions in a smart contract system.
39//!
40use crate::cg::{CallGraph, EdgeType, Node, NodeType};
41use crate::reachability::{NodeId, ReachabilityAnalyzer}; 
42use std::collections::{HashMap, HashSet};
43
44#[derive(Clone, Default, Debug, PartialEq, Eq)]
45pub struct StorageAccessSummary {
46    pub reads: HashSet<NodeId>,  
47    pub writes: HashSet<NodeId>, 
48}
49
50pub fn analyze_storage_access(graph: &CallGraph) -> HashMap<NodeId, StorageAccessSummary> {
51    let analyzer = ReachabilityAnalyzer::new();
52
53    // For storage analysis, every function/modifier/constructor encountered in the
54    // call tree is "of interest" because we need to check its direct storage interactions.
55    let is_function_like_node = |node: &Node| -> bool {
56        matches!(node.node_type, NodeType::Function | NodeType::Modifier | NodeType::Constructor)
57    };
58
59    // This function is called for each `func_node` (a function, modifier, or constructor)
60    // that is visited during the DFS traversal of the call graph.
61    // It inspects `func_node`'s direct outgoing edges to find StorageRead/StorageWrite
62    // interactions and updates the `StorageAccessSummary` state.
63    let process_function_for_storage_interactions = |
64        func_node: &Node, // The function/modifier/constructor currently being processed
65        state: &mut StorageAccessSummary,
66        graph: &CallGraph
67    | {
68        for edge in &graph.edges {
69            // Check if the edge originates from the current function-like node
70            if edge.source_node_id == func_node.id {
71                // Check if the target of the edge is a storage variable
72                if let Some(target_node) = graph.nodes.get(edge.target_node_id) {
73                    if target_node.node_type == NodeType::StorageVariable {
74                        match edge.edge_type {
75                            EdgeType::StorageRead => {
76                                state.reads.insert(target_node.id);
77                            }
78                            EdgeType::StorageWrite => {
79                                state.writes.insert(target_node.id);
80                            }
81                            _ => {} // Other edge types to storage vars are not relevant for this summary
82                        }
83                    }
84                }
85            }
86        }
87    };
88
89    analyzer.analyze_entry_points(
90        graph,
91        &is_function_like_node,
92        &process_function_for_storage_interactions,
93        StorageAccessSummary::default, 
94    )
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    use crate::reachability::tests::create_test_graph_for_reachability;
    use std::collections::HashSet;

    /// Collects a slice of node ids into the set form used by the summaries.
    fn id_set(ids: &[NodeId]) -> HashSet<NodeId> {
        ids.iter().cloned().collect()
    }

    /// Checks the read/write sets reported for both entry points of the shared
    /// reachability test graph.
    #[test]
    fn test_analyze_storage_access_basic() {
        let graph = create_test_graph_for_reachability();
        let results = analyze_storage_access(&graph);

        assert_eq!(results.len(), 2, "Expected 2 entry points (a_pub_func, b_pub_func)");

        // Looks a node up by name; panics if the fixture does not contain it.
        let node_id = |wanted: &str| graph.iter_nodes().find(|n| n.name == wanted).unwrap().id;

        let a_pub_func_id = node_id("a_pub_func");
        let b_pub_func_id = node_id("b_pub_func");

        let var1 = node_id("var1");
        let var2 = node_id("var2");
        let var3 = node_id("var3");

        let summary_a = results.get(&a_pub_func_id).expect("Summary for a_pub_func missing");
        assert_eq!(summary_a.reads, id_set(&[var1, var2]), "Mismatch in reads for a_pub_func");
        assert_eq!(summary_a.writes, id_set(&[var1, var3]), "Mismatch in writes for a_pub_func");

        let summary_b = results.get(&b_pub_func_id).expect("Summary for b_pub_func missing");
        assert_eq!(summary_b.reads, id_set(&[var2]), "Mismatch in reads for b_pub_func");
        assert_eq!(summary_b.writes, id_set(&[var2, var3]), "Mismatch in writes for b_pub_func");
    }
}
131