1use std::{collections::HashMap, iter};
2
3use crate::{
4 cg::{
5 CallGraph, CallGraphGeneratorContext, CallGraphGeneratorInput, CallGraphGeneratorStep,
6 NodeInfo, NodeType, Visibility,
7 },
8 parser::get_node_text,
9};
10use anyhow::{anyhow, Context, Result};
11use streaming_iterator::StreamingIterator;
12use tree_sitter::{Node as TsNode, Query, QueryCursor};
13use tracing::debug;
14
/// Call-graph generation step that registers Solidity contracts, interfaces and libraries,
/// together with their members (functions, modifiers, constructors, state variables),
/// inheritance relationships and `using ... for` directives.
#[derive(Debug, Default)]
pub struct ContractHandling {
    /// Step configuration as supplied through `CallGraphGeneratorStep::config`.
    /// It is stored, but `generate` does not currently act on any key.
    config: HashMap<String, String>,
}
20
21impl CallGraphGeneratorStep for ContractHandling {
23 fn name(&self) -> &'static str {
24 "Contract-Handling"
25 }
26
27 fn config(&mut self, config: &HashMap<String, String>) {
28 self.config = config.clone(); }
30
31 fn generate(
32 &self,
33 input: CallGraphGeneratorInput,
34 ctx: &mut CallGraphGeneratorContext, graph: &mut CallGraph,
36 ) -> Result<()> {
37 let _config = &self.config; let mut definition_cursor = QueryCursor::new();
41
42 let definition_query_str = r#"
43 ; Contract identifier (captures name node for ALL contracts)
44 (contract_declaration
45 name: (identifier) @contract_identifier_node @contract_name_for_map
46 ) @contract_def_item
47
48 ; Interface definition
49 (interface_declaration
50 name: (identifier) @interface_name
51 ) @interface_def_item
52
53 ; Library definition
54 (library_declaration
55 name: (identifier) @library_name
56 ) @library_def_item
57
58 ; Contract inheritance
59 (contract_declaration
60 name: (identifier) @contract_name_for_inheritance
61 (inheritance_specifier
62 ancestor: (user_defined_type
63 (identifier) @inherited_name_for_contract
64 )
65 )
66 ) @contract_inheritance_item
67
68 ; Interface inheritance
69 (interface_declaration
70 name: (identifier) @interface_name_for_inheritance
71 (inheritance_specifier
72 ancestor: (user_defined_type
73 (identifier) @inherited_name_for_interface
74 )
75 )
76 ) @interface_inheritance_item
77
78 ; Function within an interface
79 (interface_declaration
80 (identifier) @interface_scope_for_func
81 (contract_body
82 (function_definition
83 (identifier) @function_name
84 (_)?
85 [(visibility) @visibility_node]?
86 ) @function_def_item
87 )
88 )
89
90 ; Function within a library
91 (library_declaration
92 (identifier) @library_scope_for_func
93 (contract_body
94 (function_definition
95 (identifier) @function_name
96 (_)?
97 [(visibility) @visibility_node]?
98 ) @function_def_item
99 )
100 )
101
102 ; Function within a contract
103 (contract_declaration
104 (identifier) @contract_scope_for_func
105 (contract_body
106 (function_definition
107 (identifier) @function_name
108 (_)?
109 [(visibility) @visibility_node]?
110 ) @function_def_item
111 )
112 )
113
114 ; Modifier within a contract
115 (contract_declaration
116 (identifier) @contract_scope_for_modifier
117 (contract_body
118 (modifier_definition
119 (identifier) @modifier_name
120 (_)?
121 [(visibility) @visibility_node]?
122 ) @modifier_def_item
123 )
124 )
125
126 ; Constructor within a contract
127 (contract_declaration
128 (identifier) @contract_scope_for_constructor
129 (contract_body
130 (constructor_definition
131 (_)?
132 [(visibility) @visibility_node]?
133 ) @constructor_def_item
134 )
135 )
136
137 ; Top-level function
138 (source_file
139 (function_definition
140 (identifier) @function_name
141 (_)?
142 [(visibility) @visibility_node]?
143 ) @function_def_item
144 )
145
146 ; State variable within a contract
147 (contract_declaration
148 name: (identifier) @contract_scope_for_var
149 (contract_body
150 ; Capture the whole state_variable_declaration node
151 ; Type, name, and visibility will be extracted from its children
152 (state_variable_declaration) @state_var_node_capture
153 )
154 ) @state_var_item
155
156 ; Using directive within a contract
157 (contract_declaration
158 name: (identifier) @contract_scope_for_using
159 (contract_body
160 (using_directive
161 (type_alias (identifier) @using_library_name)
162 source: (_) @using_type_or_wildcard_node
163 ) @using_directive_item
164 )
165 )
166 "#;
167 let definition_query = Query::new(&input.solidity_lang, definition_query_str)
168 .context("Failed to create definition query")?;
169
170 let root_node = input.tree.root_node();
171 let source_bytes = input.source.as_bytes();
172
173 debug!("[ContractHandling] Pass 1: Identifying Contracts, Interfaces, Libraries...");
175 let mut matches_pass1 =
176 definition_cursor.matches(&definition_query, root_node, |node: TsNode| {
177 iter::once(&source_bytes[node.byte_range()])
178 });
179 matches_pass1.advance();
180 while let Some(match_) = matches_pass1.get() {
181 for capture in match_.captures {
182 let capture_name = &definition_query.capture_names()[capture.index as usize];
183 let captured_ts_node = capture.node;
184 let _text = get_node_text(&captured_ts_node, &input.source);
185
186 match *capture_name {
187 "contract_def_item" => {
188 let name_node = captured_ts_node.child_by_field_name("name").unwrap();
189 let contract_name = get_node_text(&name_node, &input.source).to_string();
190 if !ctx.all_contracts.contains_key(&contract_name) {
191 let node_info = NodeInfo {
192 span: (name_node.start_byte(), name_node.end_byte()),
193 kind: name_node.kind().to_string(),
194 };
195 ctx.all_contracts.insert(contract_name.clone(), node_info);
196 }
200 }
201 "interface_def_item" => {
202 let name_node = captured_ts_node.child_by_field_name("name").unwrap();
203 let interface_name = get_node_text(&name_node, &input.source).to_string();
204 if !ctx.all_interfaces.contains_key(&interface_name) {
205 let node_info = NodeInfo {
206 span: (name_node.start_byte(), name_node.end_byte()),
207 kind: name_node.kind().to_string(),
208 };
209 ctx.all_interfaces.insert(interface_name.clone(), node_info.clone());
210 let node_id = graph.add_node(
211 interface_name.clone(),
212 NodeType::Interface,
213 Some(interface_name.clone()),
214 Visibility::Default,
215 node_info.span,
216 );
217 ctx.definition_nodes_info.push((node_id, node_info, Some(interface_name)));
218 }
219 }
220 "library_def_item" => {
221 let name_node = captured_ts_node.child_by_field_name("name").unwrap();
222 let library_name = get_node_text(&name_node, &input.source).to_string();
223 if !ctx.all_libraries.contains_key(&library_name) {
224 let node_info = NodeInfo {
225 span: (name_node.start_byte(), name_node.end_byte()),
226 kind: name_node.kind().to_string(),
227 };
228 ctx.all_libraries.insert(library_name.clone(), node_info.clone());
229 let node_id = graph.add_node(
230 library_name.clone(),
231 NodeType::Library,
232 Some(library_name.clone()),
233 Visibility::Default,
234 node_info.span,
235 );
236 ctx.definition_nodes_info.push((node_id, node_info, Some(library_name)));
237 }
238 }
239 _ => {}
240 }
241 }
242 matches_pass1.advance();
243 }
244 debug!("[ContractHandling] Pass 1: Found {} contracts, {} interfaces, {} libraries.", ctx.all_contracts.len(), ctx.all_interfaces.len(), ctx.all_libraries.len());
245
246 debug!("[ContractHandling] Pass 2: Processing members and relationships...");
248 let mut matches_pass2 =
249 definition_cursor.matches(&definition_query, root_node, |node: TsNode| {
250 iter::once(&source_bytes[node.byte_range()])
251 });
252 matches_pass2.advance();
253 while let Some(match_) = matches_pass2.get() {
254 let mut item_node_kind_opt: Option<&str> = None;
255 if let Some(item_capture) = match_.captures.iter().find(|cap| {
256 let cap_name = &definition_query.capture_names()[cap.index as usize];
257 cap_name.ends_with("_item") }) {
259 item_node_kind_opt = Some(definition_query.capture_names()[item_capture.index as usize]);
260 }
261
262 let mut captures_map: HashMap<String, TsNode> = HashMap::new();
263 for capture in match_.captures {
264 captures_map.insert(definition_query.capture_names()[capture.index as usize].to_string(), capture.node);
265 }
266
267 if let Some(item_kind_name) = item_node_kind_opt {
268 match item_kind_name {
269 "contract_inheritance_item" => {
270 if let (Some(contract_name_node), Some(inherited_name_node)) = (
271 captures_map.get("contract_name_for_inheritance"),
272 captures_map.get("inherited_name_for_contract"),
273 ) {
274 let contract_name = get_node_text(contract_name_node, &input.source).to_string();
275 let inherited_name = get_node_text(inherited_name_node, &input.source).to_string();
276 if ctx.all_interfaces.contains_key(&inherited_name) {
277 ctx.contract_implements.entry(contract_name.clone()).or_default().push(inherited_name.clone());
278 debug!("[ContractHandling] Contract '{}' implements interface '{}'", contract_name, inherited_name);
279 }
280 ctx.contract_inherits.entry(contract_name).or_default().push(inherited_name);
281 }
282 }
283 "interface_inheritance_item" => {
284 if let (Some(iface_name_node), Some(inherited_name_node)) = (
285 captures_map.get("interface_name_for_inheritance"),
286 captures_map.get("inherited_name_for_interface"),
287 ) {
288 let iface_name = get_node_text(iface_name_node, &input.source).to_string();
289 let inherited_name = get_node_text(inherited_name_node, &input.source).to_string();
290 ctx.interface_inherits.entry(iface_name).or_default().push(inherited_name);
291 }
292 }
293 "function_def_item" | "modifier_def_item" | "constructor_def_item" => {
294 let def_node = captures_map.get(item_kind_name).unwrap(); let node_type = match item_kind_name {
296 "function_def_item" => NodeType::Function,
297 "modifier_def_item" => NodeType::Modifier,
298 "constructor_def_item" => NodeType::Constructor,
299 _ => unreachable!(),
300 };
301
302 let name_opt = captures_map.get("function_name")
303 .or_else(|| captures_map.get("modifier_name"))
304 .map(|n| get_node_text(n, &input.source).to_string());
305
306 let scope_name_opt = captures_map.get("contract_scope_for_func")
307 .or_else(|| captures_map.get("library_scope_for_func"))
308 .or_else(|| captures_map.get("interface_scope_for_func"))
309 .or_else(|| captures_map.get("contract_scope_for_modifier"))
310 .or_else(|| captures_map.get("contract_scope_for_constructor"))
311 .map(|n| get_node_text(n, &input.source).to_string());
312
313 let final_name = match node_type {
314 NodeType::Constructor => scope_name_opt.clone().unwrap_or_default(),
315 _ => name_opt.unwrap_or_default(),
316 };
317
318 if final_name.is_empty() && node_type != NodeType::Constructor { debug!("Warning: Empty name for {:?} at span {:?}", node_type, def_node.byte_range());
320 matches_pass2.advance();
321 continue;
322 }
323
324 if node_type == NodeType::Constructor {
325 if let Some(c_name) = &scope_name_opt {
326 ctx.contracts_with_explicit_constructors.insert(c_name.clone());
327 }
328 }
329
330 let visibility = captures_map.get("visibility_node").map_or_else(
331 || match node_type { NodeType::Constructor => Visibility::Public,
333 _ => Visibility::Internal,
334 },
335 |vn| match get_node_text(vn, &input.source) {
336 "public" => Visibility::Public,
337 "private" => Visibility::Private,
338 "internal" => Visibility::Internal,
339 "external" => Visibility::External,
340 _ => Visibility::Internal, },
342 );
343
344 let node_id = graph.add_node(
345 final_name.clone(),
346 node_type.clone(),
347 scope_name_opt.clone(),
348 visibility,
349 (def_node.start_byte(), def_node.end_byte()),
350 );
351 let params = crate::cg::extract_function_parameters(*def_node, &input.source);
353 if let Some(graph_node_mut) = graph.nodes.get_mut(node_id) {
354 graph_node_mut.parameters = params;
355 }
356
357 let node_info = NodeInfo {
358 span: (def_node.start_byte(), def_node.end_byte()),
359 kind: def_node.kind().to_string(),
360 };
361 ctx.definition_nodes_info.push((node_id, node_info, scope_name_opt.clone()));
362
363 if node_type == NodeType::Function {
364 if let Some(scope) = &scope_name_opt {
365 if ctx.all_interfaces.contains_key(scope) {
366 ctx.interface_functions.entry(scope.clone()).or_default().push(final_name);
367 }
368 }
369 }
370 }
371 "state_var_item" => {
372 let state_var_decl_node = captures_map.get("state_var_node_capture")
373 .ok_or_else(|| anyhow!("state_var_item missing state_var_node_capture capture"))?;
374 let contract_name_node = captures_map.get("contract_scope_for_var")
375 .ok_or_else(|| anyhow!("state_var_item missing contract_scope_for_var capture"))?;
376 let contract_name = get_node_text(contract_name_node, &input.source).to_string();
377
378 let type_node_opt = state_var_decl_node.child_by_field_name("type");
380 let name_node_opt = state_var_decl_node.child_by_field_name("name");
381 let mut visibility_text_opt: Option<String> = None;
382
383 let mut child_cursor = state_var_decl_node.walk();
384 for child in state_var_decl_node.children(&mut child_cursor) {
385 if child.kind() == "visibility" {
386 visibility_text_opt = Some(get_node_text(&child, &input.source).to_string());
387 break;
388 }
389 }
390
391 if let (Some(var_type_node), Some(var_name_node)) = (type_node_opt, name_node_opt) {
392 let var_name_str = get_node_text(&var_name_node, &input.source).to_string();
393 let visibility = match visibility_text_opt.as_deref() {
394 Some("public") => Visibility::Public,
395 Some("internal") => Visibility::Internal,
396 Some("private") => Visibility::Private,
397 _ => Visibility::Internal, };
399
400
401 let mut extracted_key_types = Vec::new();
402 match parse_mapping_recursive(var_type_node, &input.source, &mut extracted_key_types) {
405 Ok((final_value_type, full_type_str, is_mapping)) => {
406 ctx.state_var_types.insert((contract_name.clone(), var_name_str.clone()), full_type_str.clone());
408 debug!("[ContractHandling DEBUG] Adding to state_var_types (any type): Key=({}, {}), Value={}", contract_name, var_name_str, full_type_str);
409
410 if is_mapping {
411 let mapping_info = crate::cg::MappingInfo {
413 name: var_name_str.clone(),
414 visibility: visibility.clone(),
415 key_types: extracted_key_types, value_type: final_value_type, span: (state_var_decl_node.start_byte(), state_var_decl_node.end_byte()),
418 full_type_str: full_type_str.clone(), };
420 ctx.contract_mappings.insert((contract_name.clone(), var_name_str.clone()), mapping_info.clone());
421 debug!("[ContractHandling] Added mapping info for {}.{}: Name='{}', Visibility='{:?}', Keys='{:?}', ValueType='{}', FullType='{}'",
422 contract_name, var_name_str,
423 mapping_info.name, mapping_info.visibility, mapping_info.key_types, mapping_info.value_type, mapping_info.full_type_str);
424 } else {
425 debug!("[ContractHandling] State variable {}.{} is not a mapping. Type: {}", contract_name, var_name_str, full_type_str);
428 }
429 }
430 Err(e) => {
431 debug!("Error parsing type for {}.{}: {}", contract_name, var_name_str, e);
432 let raw_type_str = get_node_text(&var_type_node, &input.source).to_string();
434 ctx.state_var_types.insert((contract_name.clone(), var_name_str.clone()), raw_type_str);
435
436 }
437 }
438 let node_id = graph.add_node(
443 var_name_str.clone(),
444 NodeType::StorageVariable,
445 Some(contract_name.clone()),
446 visibility, (state_var_decl_node.start_byte(), state_var_decl_node.end_byte()),
448 );
449 ctx.storage_var_nodes.insert((Some(contract_name), var_name_str), node_id);
450
451 } else {
452 debug!("Warning: Could not extract type or name for state variable in contract '{}' at span {:?}. Type found: {}, Name found: {}",
453 contract_name,
454 state_var_decl_node.byte_range(),
455 type_node_opt.is_some(),
456 name_node_opt.is_some()
457 );
458 }
459 }
460 "using_directive_item" => {
461 if let (Some(scope_node), Some(lib_name_node), Some(type_node)) = (
462 captures_map.get("contract_scope_for_using"),
463 captures_map.get("using_library_name"),
464 captures_map.get("using_type_or_wildcard_node"),
465 ) {
466 let contract_name = get_node_text(scope_node, &input.source).to_string();
467 let library_name = get_node_text(lib_name_node, &input.source).to_string();
468 let type_text = get_node_text(type_node, &input.source).to_string();
469 ctx.using_for_directives.entry((Some(contract_name), type_text)).or_default().push(library_name);
470 }
471 }
472 "contract_def_item" | "interface_def_item" | "library_def_item" => { }
474 _ => {
475 debug!("Warning: Unhandled item kind in Pass 2: {}", item_kind_name);
476 }
477 }
478 }
479 matches_pass2.advance();
480 }
481
482 for (contract_name, identifier_node_info) in &ctx.all_contracts {
485 if !ctx.contracts_with_explicit_constructors.contains(contract_name) {
486 let span = identifier_node_info.span;
487 let constructor_name = contract_name.clone();
488 let _node_id = graph.add_node( constructor_name.clone(),
490 NodeType::Constructor,
491 Some(contract_name.clone()),
492 Visibility::Public,
493 span,
494 );
495 }
497 }
498 debug!("[ContractHandling] Pass 2: Processing complete.");
499 Ok(())
500 }
501
502}
503
504fn parse_mapping_recursive(
506 current_node: TsNode, source: &str,
508 key_types: &mut Vec<String>, ) -> Result<(String, String, bool)> { if let (Some(key_type_field_node), Some(value_type_field_node)) = (
514 current_node.child_by_field_name("key_type"),
515 current_node.child_by_field_name("value_type"),
516 ) {
517 let key_type_str = get_node_text(&key_type_field_node, source).to_string();
521 key_types.push(key_type_str.clone());
522
523 let (final_value_type, nested_value_str, _is_nested_mapping) =
526 parse_mapping_recursive(value_type_field_node, source, key_types)?;
527
528 let current_level_full_str = format!("mapping({} => {})", key_type_str, nested_value_str);
529 return Ok((final_value_type, current_level_full_str, true));
530 }
531
532 let type_str = get_node_text(¤t_node, source).to_string();
535 Ok((type_str.clone(), type_str, false))
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cg::{CallGraph, CallGraphGeneratorContext, CallGraphGeneratorInput, Visibility};
    use crate::parser::get_solidity_language;

    use tree_sitter::Parser;

    /// Parses `source_code` with the Solidity grammar, then runs a default-configured
    /// `ContractHandling` step over it. Returns the populated graph and context.
    fn run_contract_handling(
        source_code: &str,
    ) -> Result<(CallGraph, CallGraphGeneratorContext)> {
        let sol_lang = get_solidity_language();

        let mut parser = Parser::new();
        parser
            .set_language(&sol_lang)
            .expect("Error loading Solidity grammar");
        let tree = parser
            .parse(source_code, None)
            .context("Failed to parse source code")?;

        let input = CallGraphGeneratorInput {
            source: source_code.to_string(),
            tree,
            solidity_lang: sol_lang,
        };

        let mut graph = CallGraph::new();
        let mut ctx = CallGraphGeneratorContext::default();
        ContractHandling::default().generate(input, &mut ctx, &mut graph)?;
        Ok((graph, ctx))
    }

    #[test]
    fn test_mapping_state_variables() -> Result<()> {
        let source_code = r#"
            contract TestMappings {
                mapping(address => uint) public balanceOf;
                mapping(address => mapping(address => uint)) public allowance;
                mapping(bytes32 => mapping(uint256 => mapping(address => bool))) internal nestedMap;
                mapping(address => UserStruct) userInfos; // Assuming UserStruct is defined elsewhere or not relevant for type string
                struct UserStruct { uint id; }
            }
        "#;

        let (_graph, ctx) = run_contract_handling(source_code)?;
        let key = |var: &str| ("TestMappings".to_string(), var.to_string());

        // Full type string recorded for each state variable.
        let expected_types = [
            ("balanceOf", "mapping(address => uint)"),
            ("allowance", "mapping(address => mapping(address => uint))"),
            (
                "nestedMap",
                "mapping(bytes32 => mapping(uint256 => mapping(address => bool)))",
            ),
            ("userInfos", "mapping(address => UserStruct)"),
        ];
        for &(var, expected) in &expected_types {
            assert_eq!(
                ctx.state_var_types.get(&key(var)),
                Some(&expected.to_string()),
                "state_var_types entry for {}",
                var
            );
        }

        // Decomposed MappingInfo record for each mapping.
        let assert_mapping =
            |var: &str, visibility: Visibility, keys: &[&str], value: &str, full: &str| {
                let info = ctx
                    .contract_mappings
                    .get(&key(var))
                    .unwrap_or_else(|| panic!("no MappingInfo recorded for {}", var));
                assert_eq!(info.name, var);
                assert_eq!(info.visibility, visibility);
                let expected_keys: Vec<String> = keys.iter().map(|k| k.to_string()).collect();
                assert_eq!(info.key_types, expected_keys);
                assert_eq!(info.value_type, value);
                assert_eq!(info.full_type_str, full);
            };

        assert_mapping(
            "balanceOf",
            Visibility::Public,
            &["address"],
            "uint",
            "mapping(address => uint)",
        );
        assert_mapping(
            "allowance",
            Visibility::Public,
            &["address", "address"],
            "uint",
            "mapping(address => mapping(address => uint))",
        );
        assert_mapping(
            "nestedMap",
            Visibility::Internal,
            &["bytes32", "uint256", "address"],
            "bool",
            "mapping(bytes32 => mapping(uint256 => mapping(address => bool)))",
        );
        assert_mapping(
            "userInfos",
            Visibility::Internal,
            &["address"],
            "UserStruct",
            "mapping(address => UserStruct)",
        );

        Ok(())
    }
}