rust_diff_analyzer/analysis/
ast_visitor.rs1use proc_macro2::Span;
5use syn::{
6 Attribute, File, ImplItem, ItemConst, ItemEnum, ItemFn, ItemImpl, ItemMacro, ItemMod,
7 ItemStatic, ItemStruct, ItemTrait, ItemType, TraitItem, Visibility as SynVisibility,
8 spanned::Spanned, visit::Visit,
9};
10
11use crate::types::{LineSpan, SemanticUnit, SemanticUnitKind, Visibility};
12
13pub struct SemanticUnitVisitor {
15 units: Vec<SemanticUnit>,
16 in_test_module: bool,
17}
18
19impl SemanticUnitVisitor {
20 pub fn new() -> Self {
34 Self {
35 units: Vec::new(),
36 in_test_module: false,
37 }
38 }
39
40 pub fn extract(file: &File) -> Vec<SemanticUnit> {
61 let mut visitor = Self::new();
62 visitor.visit_file(file);
63 visitor.units
64 }
65
66 fn span_to_line_span(&self, span: Span) -> LineSpan {
67 let start = span.start();
68 let end = span.end();
69 LineSpan::new(start.line, end.line)
70 }
71
72 fn convert_visibility(&self, vis: &SynVisibility) -> Visibility {
73 match vis {
74 SynVisibility::Public(_) => Visibility::Public,
75 SynVisibility::Restricted(r) => {
76 if r.path.is_ident("crate") {
77 Visibility::Crate
78 } else {
79 Visibility::Restricted
80 }
81 }
82 SynVisibility::Inherited => Visibility::Private,
83 }
84 }
85
86 fn extract_attributes(&self, attrs: &[Attribute]) -> Vec<String> {
87 attrs
88 .iter()
89 .filter_map(|attr| attr.path().get_ident().map(|ident| ident.to_string()))
90 .collect()
91 }
92
93 fn has_test_attribute(&self, attrs: &[Attribute]) -> bool {
94 attrs.iter().any(|attr| {
95 let path = attr.path();
96 if path.is_ident("test") || path.is_ident("bench") {
97 return true;
98 }
99 if path.is_ident("cfg")
100 && let Ok(meta) = attr.meta.require_list()
101 {
102 let tokens = meta.tokens.to_string();
103 if tokens.contains("test") {
104 return true;
105 }
106 }
107 false
108 })
109 }
110
111 fn is_test_module(&self, attrs: &[Attribute]) -> bool {
112 attrs.iter().any(|attr| {
113 if attr.path().is_ident("cfg")
114 && let Ok(meta) = attr.meta.require_list()
115 {
116 let tokens = meta.tokens.to_string();
117 return tokens.contains("test");
118 }
119 false
120 })
121 }
122
123 fn add_unit(
124 &mut self,
125 kind: SemanticUnitKind,
126 name: String,
127 visibility: Visibility,
128 span: Span,
129 attrs: &[Attribute],
130 ) {
131 let mut attributes = self.extract_attributes(attrs);
132
133 if self.in_test_module && !attributes.contains(&"cfg_test".to_string()) {
134 attributes.push("cfg_test".to_string());
135 }
136
137 if self.has_test_attribute(attrs) && !attributes.contains(&"test".to_string()) {
138 attributes.push("test".to_string());
139 }
140
141 let unit = SemanticUnit::new(
142 kind,
143 name,
144 visibility,
145 self.span_to_line_span(span),
146 attributes,
147 );
148 self.units.push(unit);
149 }
150}
151
152impl Default for SemanticUnitVisitor {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
158impl<'ast> Visit<'ast> for SemanticUnitVisitor {
159 fn visit_item_fn(&mut self, node: &'ast ItemFn) {
160 self.add_unit(
161 SemanticUnitKind::Function,
162 node.sig.ident.to_string(),
163 self.convert_visibility(&node.vis),
164 node.span(),
165 &node.attrs,
166 );
167 syn::visit::visit_item_fn(self, node);
168 }
169
170 fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
171 self.add_unit(
172 SemanticUnitKind::Struct,
173 node.ident.to_string(),
174 self.convert_visibility(&node.vis),
175 node.span(),
176 &node.attrs,
177 );
178 syn::visit::visit_item_struct(self, node);
179 }
180
181 fn visit_item_enum(&mut self, node: &'ast ItemEnum) {
182 self.add_unit(
183 SemanticUnitKind::Enum,
184 node.ident.to_string(),
185 self.convert_visibility(&node.vis),
186 node.span(),
187 &node.attrs,
188 );
189 syn::visit::visit_item_enum(self, node);
190 }
191
192 fn visit_item_trait(&mut self, node: &'ast ItemTrait) {
193 self.add_unit(
194 SemanticUnitKind::Trait,
195 node.ident.to_string(),
196 self.convert_visibility(&node.vis),
197 node.span(),
198 &node.attrs,
199 );
200 syn::visit::visit_item_trait(self, node);
201 }
202
203 fn visit_item_impl(&mut self, node: &'ast ItemImpl) {
204 let name = if let Some((_, path, _)) = &node.trait_ {
205 format!(
206 "{} for {}",
207 path.segments
208 .last()
209 .map(|s| s.ident.to_string())
210 .unwrap_or_default(),
211 type_to_string(&node.self_ty)
212 )
213 } else {
214 type_to_string(&node.self_ty)
215 };
216
217 self.add_unit(
218 SemanticUnitKind::Impl,
219 name,
220 Visibility::Private,
221 node.span(),
222 &node.attrs,
223 );
224
225 for item in &node.items {
226 match item {
227 ImplItem::Fn(method) => {
228 self.add_unit(
229 SemanticUnitKind::Function,
230 method.sig.ident.to_string(),
231 self.convert_visibility(&method.vis),
232 method.span(),
233 &method.attrs,
234 );
235 }
236 ImplItem::Const(c) => {
237 self.add_unit(
238 SemanticUnitKind::Const,
239 c.ident.to_string(),
240 self.convert_visibility(&c.vis),
241 c.span(),
242 &c.attrs,
243 );
244 }
245 ImplItem::Type(t) => {
246 self.add_unit(
247 SemanticUnitKind::TypeAlias,
248 t.ident.to_string(),
249 self.convert_visibility(&t.vis),
250 t.span(),
251 &t.attrs,
252 );
253 }
254 _ => {}
255 }
256 }
257 }
258
259 fn visit_item_const(&mut self, node: &'ast ItemConst) {
260 self.add_unit(
261 SemanticUnitKind::Const,
262 node.ident.to_string(),
263 self.convert_visibility(&node.vis),
264 node.span(),
265 &node.attrs,
266 );
267 }
268
269 fn visit_item_static(&mut self, node: &'ast ItemStatic) {
270 self.add_unit(
271 SemanticUnitKind::Static,
272 node.ident.to_string(),
273 self.convert_visibility(&node.vis),
274 node.span(),
275 &node.attrs,
276 );
277 }
278
279 fn visit_item_type(&mut self, node: &'ast ItemType) {
280 self.add_unit(
281 SemanticUnitKind::TypeAlias,
282 node.ident.to_string(),
283 self.convert_visibility(&node.vis),
284 node.span(),
285 &node.attrs,
286 );
287 }
288
289 fn visit_item_macro(&mut self, node: &'ast ItemMacro) {
290 if let Some(ident) = &node.ident {
291 self.add_unit(
292 SemanticUnitKind::Macro,
293 ident.to_string(),
294 Visibility::Private,
295 node.span(),
296 &node.attrs,
297 );
298 }
299 }
300
301 fn visit_item_mod(&mut self, node: &'ast ItemMod) {
302 let is_test = self.is_test_module(&node.attrs) || node.ident == "tests";
303
304 self.add_unit(
305 SemanticUnitKind::Module,
306 node.ident.to_string(),
307 self.convert_visibility(&node.vis),
308 node.span(),
309 &node.attrs,
310 );
311
312 if let Some((_, items)) = &node.content {
313 let was_in_test = self.in_test_module;
314 self.in_test_module = is_test || was_in_test;
315
316 for item in items {
317 self.visit_item(item);
318 }
319
320 self.in_test_module = was_in_test;
321 }
322 }
323
324 fn visit_trait_item(&mut self, node: &'ast TraitItem) {
325 match node {
326 TraitItem::Fn(method) => {
327 self.add_unit(
328 SemanticUnitKind::Function,
329 method.sig.ident.to_string(),
330 Visibility::Public,
331 method.span(),
332 &method.attrs,
333 );
334 }
335 TraitItem::Const(c) => {
336 self.add_unit(
337 SemanticUnitKind::Const,
338 c.ident.to_string(),
339 Visibility::Public,
340 c.span(),
341 &c.attrs,
342 );
343 }
344 TraitItem::Type(t) => {
345 self.add_unit(
346 SemanticUnitKind::TypeAlias,
347 t.ident.to_string(),
348 Visibility::Public,
349 t.span(),
350 &t.attrs,
351 );
352 }
353 _ => {}
354 }
355 syn::visit::visit_trait_item(self, node);
356 }
357}
358
359fn type_to_string(ty: &syn::Type) -> String {
360 match ty {
361 syn::Type::Path(p) => p
362 .path
363 .segments
364 .last()
365 .map(|s| s.ident.to_string())
366 .unwrap_or_else(|| "Unknown".to_string()),
367 _ => "Unknown".to_string(),
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as a Rust source file and returns the units extracted from it.
    fn units_of(code: &str) -> Vec<SemanticUnit> {
        let file = syn::parse_file(code).expect("parse failed");
        SemanticUnitVisitor::extract(&file)
    }

    /// Returns the first unit called `name`, panicking with "`name` not found" otherwise.
    fn unit_named<'a>(units: &'a [SemanticUnit], name: &str) -> &'a SemanticUnit {
        units
            .iter()
            .find(|u| u.name == name)
            .unwrap_or_else(|| panic!("{name} not found"))
    }

    #[test]
    fn test_extract_function() {
        let units = units_of("pub fn hello() {}");

        assert_eq!(units.len(), 1);
        let hello = &units[0];
        assert_eq!(hello.name, "hello");
        assert!(matches!(hello.kind, SemanticUnitKind::Function));
        assert!(matches!(hello.visibility, Visibility::Public));
    }

    #[test]
    fn test_extract_struct() {
        let units = units_of("struct Point { x: i32, y: i32 }");

        assert_eq!(units.len(), 1);
        let point = &units[0];
        assert_eq!(point.name, "Point");
        assert!(matches!(point.kind, SemanticUnitKind::Struct));
    }

    #[test]
    fn test_extract_test_function() {
        let units = units_of(
            r#"
            #[test]
            fn test_something() {}
            "#,
        );

        assert_eq!(units.len(), 1);
        assert!(units[0].has_attribute("test"));
    }

    #[test]
    fn test_extract_impl_block() {
        let units = units_of(
            r#"
            struct Foo;
            impl Foo {
                pub fn new() -> Self { Foo }
            }
            "#,
        );
        let has = |name: &str, is_kind: fn(&SemanticUnitKind) -> bool| {
            units.iter().any(|u| u.name == name && is_kind(&u.kind))
        };

        assert_eq!(units.len(), 3);
        assert!(has("Foo", |k| matches!(k, SemanticUnitKind::Struct)));
        assert!(has("Foo", |k| matches!(k, SemanticUnitKind::Impl)));
        assert!(has("new", |k| matches!(k, SemanticUnitKind::Function)));
    }

    #[test]
    fn test_extract_test_module() {
        let units = units_of(
            r#"
            fn production() {}

            #[cfg(test)]
            mod tests {
                fn helper() {}

                #[test]
                fn test_it() {}
            }
            "#,
        );

        assert!(!unit_named(&units, "production").has_attribute("cfg_test"));
        assert!(unit_named(&units, "helper").has_attribute("cfg_test"));

        let test_it = unit_named(&units, "test_it");
        assert!(test_it.has_attribute("test"));
        assert!(test_it.has_attribute("cfg_test"));
    }
}