1use crate::parser::get_node_text;
29use anyhow::{Context, Result};
30use crate::parser::get_solidity_language;
31use streaming_iterator::StreamingIterator;
32use tree_sitter::{Node, Parser, Query, QueryCursor};
33use serde::{Serialize, Deserialize};
34
35use super::{TextIndex, TextRange};
36
37#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
38pub enum SourceItemKind {
39 Contract,
40 Interface,
41 Library,
42 Struct,
43 Enum,
44 Function,
45 Modifier,
46 Event,
47 Error,
48 StateVariable,
49 UsingDirective,
50 Unknown,
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
54pub struct SourceComment {
55 pub text: String,
56 pub raw_comment_span: TextRange,
57 pub item_kind: SourceItemKind,
58 pub item_name: Option<String>,
59 pub item_span: TextRange,
60 pub is_natspec: bool,
61}
62
63fn node_to_text_range(node: &tree_sitter::Node) -> TextRange {
64 TextRange {
65 start: TextIndex {
66 utf8: node.start_byte(),
67 line: node.start_position().row,
68 column: node.start_position().column,
69 },
70 end: TextIndex {
71 utf8: node.end_byte(),
72 line: node.end_position().row,
73 column: node.end_position().column,
74 },
75 }
76}
77
/// Tree-sitter query that pairs a `comment` node with the item node that
/// immediately follows it (the `.` anchor requires the two to be adjacent
/// siblings).
///
/// Captures:
/// - `@comment`: the comment node.
/// - `@item`: the declaration/definition/directive following the comment.
/// - `@item_name`: the item's `name` identifier; not captured for
///   `using_directive`, which has no `name` field in this query.
const SOURCE_ITEM_COMMENT_QUERY: &str = r#"
(
 (comment) @comment
 .
 [
 (contract_declaration name: (identifier) @item_name)
 (interface_declaration name: (identifier) @item_name)
 (library_declaration name: (identifier) @item_name)
 (struct_declaration name: (identifier) @item_name)
 (enum_declaration name: (identifier) @item_name)
 (function_definition name: (identifier) @item_name)
 (modifier_definition name: (identifier) @item_name)
 (event_definition name: (identifier) @item_name)
 (error_declaration name: (identifier) @item_name)
 (state_variable_declaration name: (identifier) @item_name)
 (using_directive)
 ] @item
)
"#;
97
98pub fn extract_source_comments(source: &str) -> Result<Vec<SourceComment>> {
99 let solidity_lang = get_solidity_language();
100 let mut parser = Parser::new();
101 parser
102 .set_language(&solidity_lang)
103 .context("Failed to set language for Solidity parser")?;
104
105 let tree = parser
106 .parse(source, None)
107 .context("Failed to parse Solidity source")?;
108
109 let query = Query::new(&solidity_lang, SOURCE_ITEM_COMMENT_QUERY)
110 .context("Failed to create source item comment query")?;
111
112 let mut query_cursor = QueryCursor::new();
113 let mut matches = query_cursor.matches(&query, tree.root_node(), source.as_bytes());
114
115 let mut source_comments = Vec::new();
116
117 matches.advance();
118 while let Some(mat) = matches.get() {
119 let mut comment_node: Option<Node> = None;
120 let mut item_node: Option<Node> = None;
121 let mut item_name_node: Option<Node> = None;
122
123 for capture in mat.captures {
124 let capture_name = &query.capture_names()[capture.index as usize];
125 match *capture_name {
126 "comment" => comment_node = Some(capture.node),
127 "item" => item_node = Some(capture.node),
128 "item_name" => item_name_node = Some(capture.node),
129 _ => {}
130 }
131 }
132
133 if let (Some(comment_n), Some(item_n)) = (comment_node, item_node) {
134 let comment_text_str = get_node_text(&comment_n, source);
135 let is_natspec =
136 comment_text_str.starts_with("///") || comment_text_str.starts_with("/**");
137
138 let item_kind_str = item_n.kind();
139 let (item_kind, extracted_name) = match item_kind_str {
140 "contract_declaration" => (
141 SourceItemKind::Contract,
142 item_name_node.map(|n| get_node_text(&n, source).to_string()),
143 ),
144 "interface_declaration" => (
145 SourceItemKind::Interface,
146 item_name_node.map(|n| get_node_text(&n, source).to_string()),
147 ),
148 "library_declaration" => (
149 SourceItemKind::Library,
150 item_name_node.map(|n| get_node_text(&n, source).to_string()),
151 ),
152 "struct_declaration" => (
153 SourceItemKind::Struct,
154 item_name_node.map(|n| get_node_text(&n, source).to_string()),
155 ),
156 "enum_declaration" => (
157 SourceItemKind::Enum,
158 item_name_node.map(|n| get_node_text(&n, source).to_string()),
159 ),
160 "function_definition" => (
161 SourceItemKind::Function,
162 item_name_node.map(|n| get_node_text(&n, source).to_string()),
163 ),
164 "modifier_definition" => (
165 SourceItemKind::Modifier,
166 item_name_node.map(|n| get_node_text(&n, source).to_string()),
167 ),
168 "event_definition" => (
169 SourceItemKind::Event,
170 item_name_node.map(|n| get_node_text(&n, source).to_string()),
171 ),
172 "error_declaration" => (
173 SourceItemKind::Error,
174 item_name_node.map(|n| get_node_text(&n, source).to_string()),
175 ),
176 "state_variable_declaration" => (
177 SourceItemKind::StateVariable,
178 item_name_node.map(|n| get_node_text(&n, source).to_string()),
179 ),
180 "using_directive" => (
181 SourceItemKind::UsingDirective,
182 Some(get_node_text(&item_n, source).to_string()), ),
184 _ => (SourceItemKind::Unknown, None),
185 };
186
187 source_comments.push(SourceComment {
188 text: comment_text_str.to_string(),
189 raw_comment_span: node_to_text_range(&comment_n),
190 item_kind,
191 item_name: extracted_name,
192 item_span: node_to_text_range(&item_n),
193 is_natspec,
194 });
195 }
196 matches.advance();
197 }
198
199 Ok(source_comments)
200}
201
#[cfg(test)]
mod source_comment_extraction_tests {
    use super::*;

    /// Returns the extracted comment attached to the item called `name`,
    /// panicking with a descriptive message when there is none.
    fn comment_for<'a>(comments: &'a [SourceComment], name: &str) -> &'a SourceComment {
        comments
            .iter()
            .find(|c| c.item_name.as_deref() == Some(name))
            .unwrap_or_else(|| panic!("no comment attached to item `{}`", name))
    }

    #[test]
    fn test_extract_simple_contract_comment() {
        let source = r#"
 /// This is a contract
 contract MyContract {}
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 1);
        let comment = &comments[0];
        assert_eq!(comment.text, "/// This is a contract");
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::Contract);
        assert_eq!(comment.item_name, Some("MyContract".to_string()));
        // The fixture starts with an empty line, so the comment sits on row 1
        // and the contract on row 2.
        assert_eq!(comment.raw_comment_span.start.line, 1);
        assert_eq!(comment.item_span.start.line, 2);
    }

    #[test]
    fn test_extract_function_comment() {
        let source = r#"
 /**
 * This is a function.
 * @param x an integer
 */
 function myFunction(uint x) public {}
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 1);
        let comment = &comments[0];
        assert_eq!(
            comment.text,
            "/**\n * This is a function.\n * @param x an integer\n */"
        );
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::Function);
        assert_eq!(comment.item_name, Some("myFunction".to_string()));
    }

    #[test]
    fn test_extract_state_variable_comment() {
        let source = r#"
 contract TestContract {
 /// The counter value
 uint256 public count;
 }
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 1);
        let comment = &comments[0];
        assert_eq!(comment.text, "/// The counter value");
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::StateVariable);
        assert_eq!(comment.item_name, Some("count".to_string()));
    }

    #[test]
    fn test_extract_multiple_comments() {
        let source = r#"
 /// Contract C
 contract C {
 /// Var V
 uint public v;

 /** Func F */
 function f() public {}
 }
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 3);

        let contract_comment = comment_for(&comments, "C");
        assert_eq!(contract_comment.text, "/// Contract C");
        assert_eq!(contract_comment.item_kind, SourceItemKind::Contract);

        let var_comment = comment_for(&comments, "v");
        assert_eq!(var_comment.text, "/// Var V");
        assert_eq!(var_comment.item_kind, SourceItemKind::StateVariable);

        let func_comment = comment_for(&comments, "f");
        assert_eq!(func_comment.text, "/** Func F */");
        assert_eq!(func_comment.item_kind, SourceItemKind::Function);
    }

    #[test]
    fn test_no_comment() {
        let source = "contract NoComment {}";
        let comments = extract_source_comments(source).unwrap();
        assert!(comments.is_empty());
    }

    #[test]
    fn test_regular_comment_not_natspec() {
        let source = r#"
 // A regular comment
 function test() public {}
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 1);
        assert_eq!(comments[0].text, "// A regular comment");
        assert!(!comments[0].is_natspec);
        assert_eq!(comments[0].item_kind, SourceItemKind::Function);
        assert_eq!(comments[0].item_name, Some("test".to_string()));
    }

    #[test]
    fn test_using_directive_comment() {
        let source = r#"
 contract TestContract {
 /// @title Using SafeMath for uint256
 using SafeMath for uint256;
 }
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 1);
        let comment = &comments[0];
        assert_eq!(comment.text, "/// @title Using SafeMath for uint256");
        assert!(comment.is_natspec);
        assert_eq!(comment.item_kind, SourceItemKind::UsingDirective);
        // Using directives carry their full text as the item name.
        assert_eq!(
            comment.item_name,
            Some("using SafeMath for uint256;".to_string())
        );
    }

    #[test]
    fn test_state_variable_complex_declaration() {
        let source = r#"
 contract TestContract {
 /// Stores the owner of the contract
 address payable public owner;
 }
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 1);
        let comment = &comments[0];
        assert_eq!(comment.text, "/// Stores the owner of the contract");
        assert_eq!(comment.item_kind, SourceItemKind::StateVariable);
        assert_eq!(comment.item_name, Some("owner".to_string()));
    }

    /// A mapping-typed state variable still yields its identifier as the name.
    #[test]
    fn test_state_variable_mapping_name_extracted() {
        let source = r#"
 contract Test {
 /// This is a mapping
 mapping(address => uint) public balances;
 }
 "#;
        let comments = extract_source_comments(source).unwrap();
        let mapping_comment = comments
            .iter()
            .find(|c| c.text == "/// This is a mapping")
            .expect("mapping comment should be extracted");
        assert_eq!(mapping_comment.item_kind, SourceItemKind::StateVariable);
        assert_eq!(mapping_comment.item_name, Some("balances".to_string()));
    }

    #[test]
    fn test_extract_struct_and_event_comments() {
        let source = r#"
 /// Defines a new proposal.
 struct Proposal {
 address proposer;
 string description;
 uint voteCount;
 }

 /** @dev Emitted when a new proposal is created.
 * @param proposalId The ID of the new proposal.
 */
 event ProposalCreated(uint proposalId);
 "#;
        let comments = extract_source_comments(source).unwrap();
        assert_eq!(comments.len(), 2);

        let struct_comment = comment_for(&comments, "Proposal");
        assert_eq!(struct_comment.text, "/// Defines a new proposal.");
        assert_eq!(struct_comment.item_kind, SourceItemKind::Struct);

        let event_comment = comment_for(&comments, "ProposalCreated");
        assert_eq!(event_comment.text, "/** @dev Emitted when a new proposal is created.\n * @param proposalId The ID of the new proposal.\n */");
        assert_eq!(event_comment.item_kind, SourceItemKind::Event);
    }
}