Skip to main content

tldr_core/inheritance/
mod.rs

//! Inheritance analysis module for class hierarchy extraction
//!
//! This module provides class hierarchy extraction and analysis for:
//! - Python classes (with ABC, Protocol, metaclass support - A12)
//! - TypeScript and JavaScript classes and interfaces (both use the TypeScript extractor)
//! - Go struct embedding (modeled as composition - A14)
//! - Rust trait impl blocks (A16)
//! - Java classes, interfaces, enums, and records
//! - Kotlin classes, interfaces, objects, and data classes
//! - Scala classes, traits, objects, and case classes
//! - Swift classes, protocols, structs, and enums
//! - C# classes, interfaces, and structs
//! - Ruby classes and modules
//! - PHP classes, interfaces, and traits
//!
//! # Architecture
//!
//! 1. Extract classes from source files using tree-sitter
//! 2. Build inheritance graph with edges for extends/implements/embeds
//! 3. Detect patterns: ABC/Protocol, mixins, diamonds
//! 4. Resolve external bases (stdlib vs project vs unresolved)
//!
//! # Mitigations Addressed
//!
//! - A2: Diamond detection using BFS + set intersection (O(|ancestors|) not O(n^3))
//! - A12: Python metaclass extraction via keywords
//! - A14: Go struct embedding as Embeds edges
//! - A16: Rust trait impl blocks as Implements edges
//! - A17: Validation that rejects --depth unless --class is also given
//! - A19: DOT output escaping for special characters
//! - A39: DOT output node limit (`max_nodes`) and per-file clustering (`cluster_by_file`)
//!
//! # Example
//!
//! ```rust,ignore
//! use std::path::Path;
//! use tldr_core::inheritance::{extract_inheritance, InheritanceOptions};
//!
//! let options = InheritanceOptions::default();
//! let report = extract_inheritance(Path::new("src"), Some(Language::Python), &options)?;
//! println!("Found {} classes", report.count);
//! ```
41
42pub mod csharp;
43pub mod filter;
44pub mod format;
45pub mod go;
46pub mod java;
47pub mod kotlin;
48pub mod patterns;
49pub mod php;
50pub mod python;
51pub mod resolve;
52pub mod ruby;
53pub mod rust;
54pub mod scala;
55pub mod swift;
56pub mod typescript;
57
58use std::collections::HashSet;
59use std::path::{Path, PathBuf};
60use std::time::Instant;
61
62use walkdir::WalkDir;
63
64use crate::ast::parser::ParserPool;
65use crate::error::TldrError;
66use crate::types::{
67    BaseResolution, InheritanceEdge, InheritanceGraph, InheritanceReport, Language,
68};
69use crate::TldrResult;
70
71pub use filter::{filter_by_class, get_fuzzy_suggestions};
72pub use format::{escape_dot_string, format_dot, format_text};
73pub use patterns::{detect_abc_protocol, detect_diamonds, detect_mixins};
74pub use resolve::{is_stdlib_class, resolve_base, PYTHON_STDLIB_CLASSES};
75
/// Configuration for [`extract_inheritance`].
///
/// Every field defaults to `None`/`false`. With those defaults, all classes are
/// reported, external bases are resolved, and pattern detection runs.
#[derive(Clone, Debug, Default)]
pub struct InheritanceOptions {
    /// Only report this class together with its ancestors and descendants.
    pub class_filter: Option<String>,
    /// Maximum traversal depth from `class_filter`; rejected by
    /// [`InheritanceOptions::validate`] when `class_filter` is unset.
    pub depth: Option<usize>,
    /// When `true`, external base resolution is skipped.
    pub no_external: bool,
    /// When `true`, ABC/Protocol, mixin, and diamond detection are skipped.
    pub no_patterns: bool,
    /// Upper bound on the number of nodes emitted in DOT output (A39).
    pub max_nodes: Option<usize>,
    /// Group DOT output nodes into one cluster per source file (A39).
    pub cluster_by_file: bool,
}
92
93impl InheritanceOptions {
94    /// Validate options - depth requires class_filter (A17)
95    pub fn validate(&self) -> TldrResult<()> {
96        if self.depth.is_some() && self.class_filter.is_none() {
97            return Err(TldrError::InvalidArgs {
98                arg: "--depth".to_string(),
99                message: "--depth requires --class. Use --class <NAME> --depth N to limit traversal depth.".to_string(),
100                suggestion: Some("To scan entire project without depth limit, omit --depth.".to_string()),
101            });
102        }
103        Ok(())
104    }
105}
106
107/// Main entry point for inheritance analysis
108pub fn extract_inheritance(
109    path: &Path,
110    lang: Option<Language>,
111    options: &InheritanceOptions,
112) -> TldrResult<InheritanceReport> {
113    // Validate options first (A17)
114    options.validate()?;
115
116    let start = Instant::now();
117    let parser_pool = ParserPool::new();
118
119    // Collect files matching language filter
120    let files = collect_source_files(path, lang);
121    if files.is_empty() {
122        return Ok(InheritanceReport::new(path.to_path_buf()));
123    }
124
125    // Build inheritance graph
126    let mut graph = InheritanceGraph::new();
127    let mut languages_seen = HashSet::new();
128
129    for file_path in &files {
130        let file_lang = Language::from_path(file_path).unwrap_or(Language::Python);
131
132        // Skip if language filter is specified and doesn't match
133        if let Some(filter_lang) = lang {
134            if file_lang != filter_lang {
135                continue;
136            }
137        }
138
139        languages_seen.insert(file_lang);
140
141        // Extract classes based on language
142        let source = match std::fs::read_to_string(file_path) {
143            Ok(s) => s,
144            Err(_) => continue, // Skip unreadable files
145        };
146
147        let classes = match file_lang {
148            Language::Python => python::extract_classes(&source, file_path, &parser_pool)?,
149            Language::TypeScript | Language::JavaScript => {
150                typescript::extract_classes(&source, file_path, &parser_pool)?
151            }
152            Language::Go => go::extract_classes(&source, file_path, &parser_pool)?,
153            Language::Rust => rust::extract_classes(&source, file_path, &parser_pool)?,
154            Language::Java => java::extract_classes(&source, file_path, &parser_pool)?,
155            Language::Kotlin => kotlin::extract_classes(&source, file_path, &parser_pool)?,
156            Language::Scala => scala::extract_classes(&source, file_path, &parser_pool)?,
157            Language::Swift => swift::extract_classes(&source, file_path, &parser_pool)?,
158            Language::CSharp => csharp::extract_classes(&source, file_path, &parser_pool)?,
159            Language::Ruby => ruby::extract_classes(&source, file_path, &parser_pool)?,
160            Language::Php => php::extract_classes(&source, file_path, &parser_pool)?,
161            _ => Vec::new(), // Unsupported language
162        };
163
164        // Add classes to graph
165        for class in classes {
166            let class_name = class.name.clone();
167            let bases = class.bases.clone();
168
169            graph.add_node(class);
170
171            // Add edges for each base
172            for base in bases {
173                graph.add_edge(&class_name, &base);
174            }
175        }
176    }
177
178    // Resolve external bases unless disabled
179    if !options.no_external {
180        resolve::resolve_all_bases(&mut graph, path)?;
181    }
182
183    // Detect patterns unless disabled
184    let diamonds = if options.no_patterns {
185        Vec::new()
186    } else {
187        // Detect ABC/Protocol/Interface
188        patterns::detect_abc_protocol(&mut graph);
189        // Detect mixins
190        patterns::detect_mixins(&mut graph);
191        // Detect diamonds
192        patterns::detect_diamonds(&graph)
193    };
194
195    // Apply class filter if specified
196    let filtered_graph = if let Some(ref class_name) = options.class_filter {
197        filter::filter_by_class(&graph, class_name, options.depth)?
198    } else {
199        graph
200    };
201
202    // Build report
203    let mut report = InheritanceReport::new(path.to_path_buf());
204    report.count = filtered_graph.nodes.len();
205    report.languages = languages_seen.into_iter().collect();
206    report.scan_time_ms = start.elapsed().as_millis() as u64;
207    report.diamonds = diamonds;
208
209    // Convert graph to edges and nodes for report
210    report.nodes = filtered_graph.nodes.values().cloned().collect();
211    report.edges = build_edges(&filtered_graph, path);
212    report.roots = filtered_graph.find_roots();
213    report.leaves = filtered_graph.find_leaves();
214
215    Ok(report)
216}
217
218/// Collect source files matching the optional language filter
219fn collect_source_files(path: &Path, lang: Option<Language>) -> Vec<PathBuf> {
220    let mut files = Vec::new();
221
222    if path.is_file() {
223        // Single file
224        if let Some(file_lang) = Language::from_path(path) {
225            if lang.is_none() || lang == Some(file_lang) {
226                files.push(path.to_path_buf());
227            }
228        }
229        return files;
230    }
231
232    // Walk directory
233    for entry in WalkDir::new(path)
234        .follow_links(true)
235        .into_iter()
236        .filter_map(|e| e.ok())
237    {
238        let entry_path = entry.path();
239
240        // Skip hidden files and directories
241        if entry_path
242            .file_name()
243            .map(|n| n.to_string_lossy().starts_with('.'))
244            .unwrap_or(false)
245        {
246            continue;
247        }
248
249        // Skip non-files
250        if !entry_path.is_file() {
251            continue;
252        }
253
254        // Check language
255        if let Some(file_lang) = Language::from_path(entry_path) {
256            if lang.is_none() || lang == Some(file_lang) {
257                files.push(entry_path.to_path_buf());
258            }
259        }
260    }
261
262    files
263}
264
265/// Build InheritanceEdge structs from graph
266fn build_edges(graph: &InheritanceGraph, _project_root: &Path) -> Vec<InheritanceEdge> {
267    let mut edges = Vec::new();
268
269    for (child_name, parents) in &graph.parents {
270        let child_node = match graph.nodes.get(child_name) {
271            Some(n) => n,
272            None => continue,
273        };
274
275        for parent_name in parents {
276            let parent_node = graph.nodes.get(parent_name);
277            let (resolution, external) = if parent_node.is_some() {
278                (BaseResolution::Project, false)
279            } else if resolve::is_stdlib_class(parent_name, child_node.language) {
280                (BaseResolution::Stdlib, true)
281            } else {
282                (BaseResolution::Unresolved, true)
283            };
284
285            let edge = if external {
286                if resolution == BaseResolution::Stdlib {
287                    InheritanceEdge::stdlib(
288                        child_name,
289                        parent_name,
290                        child_node.file.clone(),
291                        child_node.line,
292                    )
293                } else {
294                    InheritanceEdge::unresolved(
295                        child_name,
296                        parent_name,
297                        child_node.file.clone(),
298                        child_node.line,
299                    )
300                }
301            } else {
302                let pn = parent_node.unwrap();
303                InheritanceEdge::project(
304                    child_name,
305                    parent_name,
306                    child_node.file.clone(),
307                    child_node.line,
308                    pn.file.clone(),
309                    pn.line,
310                )
311            };
312
313            edges.push(edge);
314        }
315    }
316
317    edges
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Write `content` to `name` (relative to `dir`), creating any missing
    /// parent directories, and return the full path of the written file.
    fn create_test_file(dir: &TempDir, name: &str, content: &str) -> PathBuf {
        let target = dir.path().join(name);
        if let Some(parent_dir) = target.parent() {
            std::fs::create_dir_all(parent_dir).expect("failed to create parent directories");
        }
        std::fs::write(&target, content).expect("failed to write test file");
        target
    }

    #[test]
    fn test_options_validation_depth_without_class() {
        let err = InheritanceOptions {
            class_filter: None,
            depth: Some(3),
            ..Default::default()
        }
        .validate()
        .expect_err("--depth without --class must be rejected");

        assert!(err.to_string().contains("--depth requires --class"));
    }

    #[test]
    fn test_options_validation_depth_with_class() {
        let options = InheritanceOptions {
            class_filter: Some(String::from("MyClass")),
            depth: Some(3),
            ..Default::default()
        };

        options
            .validate()
            .expect("--depth together with --class should be accepted");
    }

    #[test]
    fn test_extract_empty_project() {
        let dir = TempDir::new().unwrap();
        create_test_file(&dir, "empty.py", "# No classes here\npass\n");

        let report = extract_inheritance(
            dir.path(),
            Some(Language::Python),
            &InheritanceOptions::default(),
        )
        .expect("analysis of a class-free project should succeed");

        assert_eq!(report.count, 0);
        assert!(report.nodes.is_empty());
        assert!(report.edges.is_empty());
    }
}