// traverse_graph/natspec/extract.rs

/*
    This module focuses on extracting comments from Solidity source code files
    and associating them with the relevant code items (like contracts, functions,
    state variables, etc.). It uses `tree-sitter` to parse the Solidity source
    into a syntax tree.

    Key functionalities include:
    - Defining `SourceItemKind` to categorize Solidity code elements (e.g.,
      Contract, Function, StateVariable).
    - Defining `SourceComment` to store the extracted comment text, its source
      span, the kind and name of the associated code item, the item's span,
      and a flag indicating whether it is a NatSpec comment.
    - Using a `tree-sitter` query to identify comments (both block and line
      comments) that are positioned immediately before a recognizable
      Solidity construct.
    - Extracting details for each matched comment and its corresponding code
      item, including the item's kind and name. The name is taken from the
      item's `name` field as captured by the query; a `using` directive has
      no such capture, so its full source text is used as the name instead.
    - The main function `extract_source_comments` takes Solidity source code
      as input and returns a vector of `SourceComment` structs.

    This module acts as a bridge between raw Solidity code and structured
    comment information, which can then be further processed, for instance,
    by parsing the `text` field of `SourceComment` using the `natspec/mod.rs`
    parsers if `is_natspec` is true.
*/
28use crate::parser::get_node_text;
29use anyhow::{Context, Result};
30use crate::parser::get_solidity_language;
31use streaming_iterator::StreamingIterator;
32use tree_sitter::{Node, Parser, Query, QueryCursor};
33use serde::{Serialize, Deserialize};
34
35use super::{TextIndex, TextRange};
36
/// Category of the Solidity item that an extracted comment is attached to.
///
/// Each variant (except `Unknown`) corresponds to one node kind matched by the
/// comment query used in `extract_source_comments`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SourceItemKind {
    /// A `contract_declaration` node.
    Contract,
    /// An `interface_declaration` node.
    Interface,
    /// A `library_declaration` node.
    Library,
    /// A `struct_declaration` node.
    Struct,
    /// An `enum_declaration` node.
    Enum,
    /// A `function_definition` node.
    Function,
    /// A `modifier_definition` node.
    Modifier,
    /// An `event_definition` node.
    Event,
    /// An `error_declaration` node.
    Error,
    /// A `state_variable_declaration` node.
    StateVariable,
    /// A `using_directive` node.
    UsingDirective,
    /// Fallback for any node kind the extractor does not recognize.
    Unknown,
}
52
/// A comment found in Solidity source together with the item it precedes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SourceComment {
    /// The comment's source text, including its `//`, `///` or `/* */` delimiters.
    pub text: String,
    /// Location of the comment node in the source.
    pub raw_comment_span: TextRange,
    /// Kind of the item the comment is attached to.
    pub item_kind: SourceItemKind,
    /// Name of the attached item, when one is available. For a `using`
    /// directive this holds the directive's full source text.
    pub item_name: Option<String>,
    /// Location of the attached item's node in the source.
    pub item_span: TextRange,
    /// `true` when the comment text begins with a NatSpec marker (`///` or `/**`).
    pub is_natspec: bool,
}
62
63fn node_to_text_range(node: &tree_sitter::Node) -> TextRange {
64    TextRange {
65        start: TextIndex {
66            utf8: node.start_byte(),
67            line: node.start_position().row,
68            column: node.start_position().column,
69        },
70        end: TextIndex {
71            utf8: node.end_byte(),
72            line: node.end_position().row,
73            column: node.end_position().column,
74        },
75    }
76}
77
/// Tree-sitter query that pairs a comment with the Solidity item following it.
///
/// Captures:
/// - `@comment`: the comment node.
/// - `@item`: the matched declaration/definition node (or `using_directive`).
/// - `@item_name`: the item's `name` identifier; not captured for `using_directive`.
///
/// The `.` anchor between the two patterns restricts matches to a comment that
/// is the immediately preceding sibling of the item.
const SOURCE_ITEM_COMMENT_QUERY: &str = r#"
(
  (comment) @comment
  .
  [
    (contract_declaration name: (identifier) @item_name)
    (interface_declaration name: (identifier) @item_name)
    (library_declaration name: (identifier) @item_name)
    (struct_declaration name: (identifier) @item_name)
    (enum_declaration name: (identifier) @item_name)
    (function_definition name: (identifier) @item_name)
    (modifier_definition name: (identifier) @item_name)
    (event_definition name: (identifier) @item_name)
    (error_declaration name: (identifier) @item_name)
    (state_variable_declaration name: (identifier) @item_name)
    (using_directive)
  ] @item
)
"#;
97
98pub fn extract_source_comments(source: &str) -> Result<Vec<SourceComment>> {
99    let solidity_lang = get_solidity_language();
100    let mut parser = Parser::new();
101    parser
102        .set_language(&solidity_lang)
103        .context("Failed to set language for Solidity parser")?;
104
105    let tree = parser
106        .parse(source, None)
107        .context("Failed to parse Solidity source")?;
108
109    let query = Query::new(&solidity_lang, SOURCE_ITEM_COMMENT_QUERY)
110        .context("Failed to create source item comment query")?;
111
112    let mut query_cursor = QueryCursor::new();
113    let mut matches = query_cursor.matches(&query, tree.root_node(), source.as_bytes());
114
115    let mut source_comments = Vec::new();
116
117    matches.advance();
118    while let Some(mat) = matches.get() {
119        let mut comment_node: Option<Node> = None;
120        let mut item_node: Option<Node> = None;
121        let mut item_name_node: Option<Node> = None;
122
123        for capture in mat.captures {
124            let capture_name = &query.capture_names()[capture.index as usize];
125            match *capture_name {
126                "comment" => comment_node = Some(capture.node),
127                "item" => item_node = Some(capture.node),
128                "item_name" => item_name_node = Some(capture.node),
129                _ => {}
130            }
131        }
132
133        if let (Some(comment_n), Some(item_n)) = (comment_node, item_node) {
134            let comment_text_str = get_node_text(&comment_n, source);
135            let is_natspec =
136                comment_text_str.starts_with("///") || comment_text_str.starts_with("/**");
137
138            let item_kind_str = item_n.kind();
139            let (item_kind, extracted_name) = match item_kind_str {
140                "contract_declaration" => (
141                    SourceItemKind::Contract,
142                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
143                ),
144                "interface_declaration" => (
145                    SourceItemKind::Interface,
146                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
147                ),
148                "library_declaration" => (
149                    SourceItemKind::Library,
150                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
151                ),
152                "struct_declaration" => (
153                    SourceItemKind::Struct,
154                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
155                ),
156                "enum_declaration" => (
157                    SourceItemKind::Enum,
158                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
159                ),
160                "function_definition" => (
161                    SourceItemKind::Function,
162                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
163                ),
164                "modifier_definition" => (
165                    SourceItemKind::Modifier,
166                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
167                ),
168                "event_definition" => (
169                    SourceItemKind::Event,
170                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
171                ),
172                "error_declaration" => (
173                    SourceItemKind::Error,
174                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
175                ),
176                "state_variable_declaration" => (
177                    SourceItemKind::StateVariable,
178                    item_name_node.map(|n| get_node_text(&n, source).to_string()),
179                ),
180                "using_directive" => (
181                    SourceItemKind::UsingDirective,
182                    Some(get_node_text(&item_n, source).to_string()), // Name for using_directive is the full text
183                ),
184                _ => (SourceItemKind::Unknown, None),
185            };
186
187            source_comments.push(SourceComment {
188                text: comment_text_str.to_string(),
189                raw_comment_span: node_to_text_range(&comment_n),
190                item_kind,
191                item_name: extracted_name,
192                item_span: node_to_text_range(&item_n),
193                is_natspec,
194            });
195        }
196        matches.advance();
197    }
198
199    Ok(source_comments)
200}
201
#[cfg(test)]
mod source_comment_extraction_tests {
    use super::*;

    /// Runs the extractor on `source`, panicking if extraction fails.
    fn extract(source: &str) -> Vec<SourceComment> {
        extract_source_comments(source).expect("comment extraction should succeed")
    }

    /// Returns the only comment extracted from `source`, asserting there is exactly one.
    fn extract_single(source: &str) -> SourceComment {
        let mut comments = extract(source);
        assert_eq!(comments.len(), 1);
        comments.remove(0)
    }

    /// Finds the comment attached to the item called `name`, panicking if absent.
    fn find_by_item_name<'a>(comments: &'a [SourceComment], name: &str) -> &'a SourceComment {
        comments
            .iter()
            .find(|c| c.item_name.as_deref() == Some(name))
            .unwrap_or_else(|| panic!("no comment attached to item `{}`", name))
    }

    #[test]
    fn test_extract_simple_contract_comment() {
        let source = r#"
        /// This is a contract
        contract MyContract {}
        "#;
        let comment = extract_single(source);
        assert_eq!(comment.text, "/// This is a contract");
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::Contract);
        assert_eq!(comment.item_name, Some("MyContract".to_string()));
    }

    #[test]
    fn test_extract_function_comment() {
        let source = r#"
        /**
         * This is a function.
         * @param x an integer
         */
        function myFunction(uint x) public {}
        "#;
        let comment = extract_single(source);
        assert_eq!(
            comment.text,
            "/**\n         * This is a function.\n         * @param x an integer\n         */"
        );
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::Function);
        assert_eq!(comment.item_name, Some("myFunction".to_string()));
    }

    #[test]
    fn test_extract_state_variable_comment() {
        let source = r#"
        contract TestContract {
            /// The counter value
            uint256 public count;
        }
        "#;
        let comment = extract_single(source);
        assert_eq!(comment.text, "/// The counter value");
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::StateVariable);
        assert_eq!(comment.item_name, Some("count".to_string()));
    }

    #[test]
    fn test_extract_multiple_comments() {
        let source = r#"
        /// Contract C
        contract C {
            /// Var V
            uint public v;

            /** Func F */
            function f() public {}
        }
        "#;
        let comments = extract(source);
        assert_eq!(comments.len(), 3);

        let contract_comment = find_by_item_name(&comments, "C");
        assert_eq!(contract_comment.text, "/// Contract C");
        assert_eq!(contract_comment.item_kind, SourceItemKind::Contract);

        let var_comment = find_by_item_name(&comments, "v");
        assert_eq!(var_comment.text, "/// Var V");
        assert_eq!(var_comment.item_kind, SourceItemKind::StateVariable);

        let func_comment = find_by_item_name(&comments, "f");
        assert_eq!(func_comment.text, "/** Func F */");
        assert_eq!(func_comment.item_kind, SourceItemKind::Function);
    }

    #[test]
    fn test_no_comment() {
        let source = "contract NoComment {}";
        let comments = extract(source);
        assert!(comments.is_empty());
    }

    #[test]
    fn test_regular_comment_not_natspec() {
        let source = r#"
        // A regular comment
        function test() public {}
        "#;
        let comment = extract_single(source);
        assert_eq!(comment.text, "// A regular comment");
        assert!(!comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::Function);
        assert_eq!(comment.item_name, Some("test".to_string()));
    }

    #[test]
    fn test_using_directive_comment() {
        let source = r#"
        contract TestContract {
            /// @title Using SafeMath for uint256
            using SafeMath for uint256;
        }
        "#;
        let comment = extract_single(source);
        assert_eq!(comment.text, "/// @title Using SafeMath for uint256");
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::UsingDirective);
        // A `using` directive has no name, so its full text stands in for one.
        assert_eq!(
            comment.item_name,
            Some("using SafeMath for uint256;".to_string())
        );
    }

    #[test]
    fn test_state_variable_complex_declaration() {
        let source = r#"
        contract TestContract {
            /// Stores the owner of the contract
            address payable public owner;
        }
        "#;
        let comment = extract_single(source);
        assert_eq!(comment.text, "/// Stores the owner of the contract");
        assert_eq!(comment.item_kind, SourceItemKind::StateVariable);
        assert_eq!(comment.item_name, Some("owner".to_string()));
    }

    /// A mapping-typed state variable must still report its declared name.
    #[test]
    fn test_state_variable_mapping_name() {
        let source = r#"
        contract Test {
            /// This is a mapping
            mapping(address => uint) public balances;
        }
        "#;
        let comments = extract(source);
        let mapping_comment = comments
            .iter()
            .find(|c| c.text == "/// This is a mapping")
            .expect("mapping comment should be extracted");
        assert_eq!(mapping_comment.item_kind, SourceItemKind::StateVariable);
        assert_eq!(mapping_comment.item_name, Some("balances".to_string()));
    }

    #[test]
    fn test_extract_struct_and_event_comments() {
        let source = r#"
        /// Defines a new proposal.
        struct Proposal {
            address proposer;
            string description;
            uint voteCount;
        }

        /** @dev Emitted when a new proposal is created.
          * @param proposalId The ID of the new proposal.
          */
        event ProposalCreated(uint proposalId);
        "#;
        let comments = extract(source);
        assert_eq!(comments.len(), 2);

        let struct_comment = find_by_item_name(&comments, "Proposal");
        assert_eq!(struct_comment.text, "/// Defines a new proposal.");
        assert_eq!(struct_comment.item_kind, SourceItemKind::Struct);

        let event_comment = find_by_item_name(&comments, "ProposalCreated");
        assert_eq!(event_comment.text, "/** @dev Emitted when a new proposal is created.\n          * @param proposalId The ID of the new proposal.\n          */");
        assert_eq!(event_comment.item_kind, SourceItemKind::Event);
    }
}