1use proc_macro2::Span;
5use syn::{
6 Attribute, File, ImplItem, ItemConst, ItemEnum, ItemFn, ItemImpl, ItemMacro, ItemMod,
7 ItemStatic, ItemStruct, ItemTrait, ItemType, TraitItem, Visibility as SynVisibility,
8 spanned::Spanned, visit::Visit,
9};
10
11use crate::types::{LineSpan, SemanticUnit, SemanticUnitKind, Visibility};
12
13pub struct SemanticUnitVisitor {
15 units: Vec<SemanticUnit>,
16 in_test_module: bool,
17 current_impl_name: Option<String>,
18}
19
20impl SemanticUnitVisitor {
21 pub fn new() -> Self {
35 Self {
36 units: Vec::new(),
37 in_test_module: false,
38 current_impl_name: None,
39 }
40 }
41
42 pub fn extract(file: &File) -> Vec<SemanticUnit> {
63 let mut visitor = Self::new();
64 visitor.visit_file(file);
65 visitor.units
66 }
67
68 fn span_to_line_span(&self, span: Span) -> LineSpan {
69 let start = span.start();
70 let end = span.end();
71 LineSpan::new(start.line, end.line)
72 }
73
74 fn convert_visibility(&self, vis: &SynVisibility) -> Visibility {
75 match vis {
76 SynVisibility::Public(_) => Visibility::Public,
77 SynVisibility::Restricted(r) => {
78 if r.path.is_ident("crate") {
79 Visibility::Crate
80 } else {
81 Visibility::Restricted
82 }
83 }
84 SynVisibility::Inherited => Visibility::Private,
85 }
86 }
87
88 fn extract_attributes(&self, attrs: &[Attribute]) -> Vec<String> {
89 attrs
90 .iter()
91 .filter_map(|attr| attr.path().get_ident().map(|ident| ident.to_string()))
92 .collect()
93 }
94
95 fn has_test_attribute(&self, attrs: &[Attribute]) -> bool {
96 attrs.iter().any(|attr| {
97 let path = attr.path();
98 if path.is_ident("test") || path.is_ident("bench") {
99 return true;
100 }
101 if path.is_ident("cfg")
102 && let Ok(meta) = attr.meta.require_list()
103 {
104 let tokens = meta.tokens.to_string();
105 if tokens.contains("test") {
106 return true;
107 }
108 }
109 false
110 })
111 }
112
113 fn is_test_module(&self, attrs: &[Attribute]) -> bool {
114 attrs.iter().any(|attr| {
115 if attr.path().is_ident("cfg")
116 && let Ok(meta) = attr.meta.require_list()
117 {
118 let tokens = meta.tokens.to_string();
119 return tokens.contains("test");
120 }
121 false
122 })
123 }
124
125 fn add_unit(
126 &mut self,
127 kind: SemanticUnitKind,
128 name: String,
129 visibility: Visibility,
130 span: Span,
131 attrs: &[Attribute],
132 ) {
133 let mut attributes = self.extract_attributes(attrs);
134
135 if self.in_test_module && !attributes.contains(&"cfg_test".to_string()) {
136 attributes.push("cfg_test".to_string());
137 }
138
139 if self.has_test_attribute(attrs) && !attributes.contains(&"test".to_string()) {
140 attributes.push("test".to_string());
141 }
142
143 let unit = match &self.current_impl_name {
144 Some(impl_name) => SemanticUnit::with_impl(
145 kind,
146 name,
147 impl_name.clone(),
148 visibility,
149 self.span_to_line_span(span),
150 attributes,
151 ),
152 None => SemanticUnit::new(
153 kind,
154 name,
155 visibility,
156 self.span_to_line_span(span),
157 attributes,
158 ),
159 };
160 self.units.push(unit);
161 }
162}
163
164impl Default for SemanticUnitVisitor {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170impl<'ast> Visit<'ast> for SemanticUnitVisitor {
171 fn visit_item_fn(&mut self, node: &'ast ItemFn) {
172 self.add_unit(
173 SemanticUnitKind::Function,
174 node.sig.ident.to_string(),
175 self.convert_visibility(&node.vis),
176 node.span(),
177 &node.attrs,
178 );
179 syn::visit::visit_item_fn(self, node);
180 }
181
182 fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
183 self.add_unit(
184 SemanticUnitKind::Struct,
185 node.ident.to_string(),
186 self.convert_visibility(&node.vis),
187 node.span(),
188 &node.attrs,
189 );
190 syn::visit::visit_item_struct(self, node);
191 }
192
193 fn visit_item_enum(&mut self, node: &'ast ItemEnum) {
194 self.add_unit(
195 SemanticUnitKind::Enum,
196 node.ident.to_string(),
197 self.convert_visibility(&node.vis),
198 node.span(),
199 &node.attrs,
200 );
201 syn::visit::visit_item_enum(self, node);
202 }
203
204 fn visit_item_trait(&mut self, node: &'ast ItemTrait) {
205 self.add_unit(
206 SemanticUnitKind::Trait,
207 node.ident.to_string(),
208 self.convert_visibility(&node.vis),
209 node.span(),
210 &node.attrs,
211 );
212 syn::visit::visit_item_trait(self, node);
213 }
214
215 fn visit_item_impl(&mut self, node: &'ast ItemImpl) {
216 let impl_name = if let Some((_, path, _)) = &node.trait_ {
217 format!(
218 "{} for {}",
219 path.segments
220 .last()
221 .map(|s| s.ident.to_string())
222 .unwrap_or_default(),
223 type_to_string(&node.self_ty)
224 )
225 } else {
226 type_to_string(&node.self_ty)
227 };
228
229 self.add_unit(
230 SemanticUnitKind::Impl,
231 impl_name.clone(),
232 Visibility::Private,
233 node.span(),
234 &node.attrs,
235 );
236
237 let previous_impl_name = self.current_impl_name.take();
238 self.current_impl_name = Some(impl_name);
239
240 for item in &node.items {
241 match item {
242 ImplItem::Fn(method) => {
243 self.add_unit(
244 SemanticUnitKind::Function,
245 method.sig.ident.to_string(),
246 self.convert_visibility(&method.vis),
247 method.span(),
248 &method.attrs,
249 );
250 }
251 ImplItem::Const(c) => {
252 self.add_unit(
253 SemanticUnitKind::Const,
254 c.ident.to_string(),
255 self.convert_visibility(&c.vis),
256 c.span(),
257 &c.attrs,
258 );
259 }
260 ImplItem::Type(t) => {
261 self.add_unit(
262 SemanticUnitKind::TypeAlias,
263 t.ident.to_string(),
264 self.convert_visibility(&t.vis),
265 t.span(),
266 &t.attrs,
267 );
268 }
269 _ => {}
270 }
271 }
272
273 self.current_impl_name = previous_impl_name;
274 }
275
276 fn visit_item_const(&mut self, node: &'ast ItemConst) {
277 self.add_unit(
278 SemanticUnitKind::Const,
279 node.ident.to_string(),
280 self.convert_visibility(&node.vis),
281 node.span(),
282 &node.attrs,
283 );
284 }
285
286 fn visit_item_static(&mut self, node: &'ast ItemStatic) {
287 self.add_unit(
288 SemanticUnitKind::Static,
289 node.ident.to_string(),
290 self.convert_visibility(&node.vis),
291 node.span(),
292 &node.attrs,
293 );
294 }
295
296 fn visit_item_type(&mut self, node: &'ast ItemType) {
297 self.add_unit(
298 SemanticUnitKind::TypeAlias,
299 node.ident.to_string(),
300 self.convert_visibility(&node.vis),
301 node.span(),
302 &node.attrs,
303 );
304 }
305
306 fn visit_item_macro(&mut self, node: &'ast ItemMacro) {
307 if let Some(ident) = &node.ident {
308 self.add_unit(
309 SemanticUnitKind::Macro,
310 ident.to_string(),
311 Visibility::Private,
312 node.span(),
313 &node.attrs,
314 );
315 }
316 }
317
318 fn visit_item_mod(&mut self, node: &'ast ItemMod) {
319 let is_test = self.is_test_module(&node.attrs) || node.ident == "tests";
320
321 self.add_unit(
322 SemanticUnitKind::Module,
323 node.ident.to_string(),
324 self.convert_visibility(&node.vis),
325 node.span(),
326 &node.attrs,
327 );
328
329 if let Some((_, items)) = &node.content {
330 let was_in_test = self.in_test_module;
331 self.in_test_module = is_test || was_in_test;
332
333 for item in items {
334 self.visit_item(item);
335 }
336
337 self.in_test_module = was_in_test;
338 }
339 }
340
341 fn visit_trait_item(&mut self, node: &'ast TraitItem) {
342 match node {
343 TraitItem::Fn(method) => {
344 self.add_unit(
345 SemanticUnitKind::Function,
346 method.sig.ident.to_string(),
347 Visibility::Public,
348 method.span(),
349 &method.attrs,
350 );
351 }
352 TraitItem::Const(c) => {
353 self.add_unit(
354 SemanticUnitKind::Const,
355 c.ident.to_string(),
356 Visibility::Public,
357 c.span(),
358 &c.attrs,
359 );
360 }
361 TraitItem::Type(t) => {
362 self.add_unit(
363 SemanticUnitKind::TypeAlias,
364 t.ident.to_string(),
365 Visibility::Public,
366 t.span(),
367 &t.attrs,
368 );
369 }
370 _ => {}
371 }
372 syn::visit::visit_trait_item(self, node);
373 }
374}
375
376fn type_to_string(ty: &syn::Type) -> String {
377 match ty {
378 syn::Type::Path(p) => p
379 .path
380 .segments
381 .last()
382 .map(|s| s.ident.to_string())
383 .unwrap_or_else(|| "Unknown".to_string()),
384 _ => "Unknown".to_string(),
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as a Rust file and runs the extractor over it.
    fn units_of(code: &str) -> Vec<SemanticUnit> {
        let file = syn::parse_file(code).expect("parse failed");
        SemanticUnitVisitor::extract(&file)
    }

    /// Returns the first unit called `name`, panicking with `"<name> not found"` otherwise.
    fn unit_named<'a>(units: &'a [SemanticUnit], name: &str) -> &'a SemanticUnit {
        units
            .iter()
            .find(|u| u.name == name)
            .unwrap_or_else(|| panic!("{name} not found"))
    }

    /// True if some unit is called `name` and its kind satisfies `kind_matches`.
    fn contains(
        units: &[SemanticUnit],
        name: &str,
        kind_matches: impl Fn(&SemanticUnitKind) -> bool,
    ) -> bool {
        units.iter().any(|u| u.name == name && kind_matches(&u.kind))
    }

    #[test]
    fn test_extract_function() {
        let units = units_of("pub fn hello() {}");

        assert_eq!(units.len(), 1);
        let hello = &units[0];
        assert_eq!(hello.name, "hello");
        assert!(matches!(hello.kind, SemanticUnitKind::Function));
        assert!(matches!(hello.visibility, Visibility::Public));
    }

    #[test]
    fn test_extract_struct() {
        let units = units_of("struct Point { x: i32, y: i32 }");

        assert_eq!(units.len(), 1);
        let point = &units[0];
        assert_eq!(point.name, "Point");
        assert!(matches!(point.kind, SemanticUnitKind::Struct));
    }

    #[test]
    fn test_extract_test_function() {
        let units = units_of(
            r#"
            #[test]
            fn test_something() {}
            "#,
        );

        assert_eq!(units.len(), 1);
        assert!(units[0].has_attribute("test"));
    }

    #[test]
    fn test_extract_impl_block() {
        let units = units_of(
            r#"
            struct Foo;
            impl Foo {
                pub fn new() -> Self { Foo }
            }
            "#,
        );

        assert_eq!(units.len(), 3);
        assert!(contains(&units, "Foo", |k| matches!(k, SemanticUnitKind::Struct)));
        assert!(contains(&units, "Foo", |k| matches!(k, SemanticUnitKind::Impl)));
        assert!(contains(&units, "new", |k| matches!(k, SemanticUnitKind::Function)));
    }

    #[test]
    fn test_extract_test_module() {
        let units = units_of(
            r#"
            fn production() {}

            #[cfg(test)]
            mod tests {
                fn helper() {}

                #[test]
                fn test_it() {}
            }
            "#,
        );

        assert!(!unit_named(&units, "production").has_attribute("cfg_test"));
        assert!(unit_named(&units, "helper").has_attribute("cfg_test"));

        let test_it = unit_named(&units, "test_it");
        assert!(test_it.has_attribute("test"));
        assert!(test_it.has_attribute("cfg_test"));
    }
}