Skip to main content

tldr_core/patterns/
import_patterns.rs

1//! Import organization pattern detection
2//!
3//! Detects import patterns:
4//! - Absolute vs relative imports
5//! - Import grouping (stdlib, third-party, local)
6//! - Star import usage
7//! - Common alias conventions (np, pd, etc.)
8
9use std::collections::HashMap;
10
11use super::signals::PatternSignals;
12use crate::types::{
13    AliasConvention, Evidence, ImportGrouping, ImportPattern, ImportStyle, StarImportUsage,
14};
15
16/// Convert signals to import pattern
17pub fn signals_to_pattern(
18    signals: &PatternSignals,
19    evidence_limit: usize,
20) -> Option<ImportPattern> {
21    let import_patterns = &signals.import_patterns;
22
23    if !import_patterns.has_signals() {
24        return None;
25    }
26
27    // Determine absolute vs relative preference
28    let absolute_count = import_patterns.absolute_imports.len();
29    let relative_count = import_patterns.relative_imports.len();
30    let total_imports = absolute_count + relative_count;
31
32    let absolute_vs_relative = if total_imports == 0 {
33        ImportStyle::Mixed
34    } else {
35        let ratio = absolute_count as f64 / total_imports as f64;
36        if ratio >= 0.8 {
37            ImportStyle::Absolute
38        } else if ratio <= 0.2 {
39            ImportStyle::Relative
40        } else {
41            ImportStyle::Mixed
42        }
43    };
44
45    // Determine star import usage
46    let star_import_count = import_patterns.star_imports.len();
47    let star_imports = if star_import_count == 0 {
48        StarImportUsage::None
49    } else if star_import_count <= 2 {
50        StarImportUsage::Rare
51    } else {
52        StarImportUsage::Common
53    };
54
55    // Detect grouping style from collected groupings
56    let grouping_style = detect_grouping_style(&import_patterns.groupings);
57
58    // Convert aliases to AliasConvention
59    let alias_conventions = convert_aliases(&import_patterns.aliases);
60
61    // Collect evidence (limited)
62    let evidence: Vec<Evidence> = import_patterns
63        .star_imports
64        .iter()
65        .take(evidence_limit)
66        .cloned()
67        .collect();
68
69    Some(ImportPattern {
70        grouping_style,
71        absolute_vs_relative,
72        star_imports,
73        alias_conventions,
74        evidence,
75    })
76}
77
78/// Detect the import grouping style from collected groupings
79fn detect_grouping_style(groupings: &[super::signals::ImportGrouping]) -> ImportGrouping {
80    if groupings.is_empty() {
81        return ImportGrouping::Ungrouped;
82    }
83
84    // Count patterns observed across files
85    let mut stdlib_first_count = 0;
86    let mut local_first_count = 0;
87    let mut third_party_first_count = 0;
88
89    for grouping in groupings {
90        // Determine which type appears first (non-empty)
91        if !grouping.stdlib_imports.is_empty() {
92            if grouping.third_party_imports.is_empty() || !grouping.local_imports.is_empty() {
93                stdlib_first_count += 1;
94            }
95        } else if !grouping.third_party_imports.is_empty() {
96            third_party_first_count += 1;
97        } else if !grouping.local_imports.is_empty() {
98            local_first_count += 1;
99        }
100    }
101
102    // Determine majority pattern
103    if stdlib_first_count >= third_party_first_count && stdlib_first_count >= local_first_count {
104        if stdlib_first_count > 0 {
105            ImportGrouping::StdlibFirst
106        } else {
107            ImportGrouping::Ungrouped
108        }
109    } else if third_party_first_count >= local_first_count {
110        ImportGrouping::ThirdPartyFirst
111    } else {
112        ImportGrouping::LocalFirst
113    }
114}
115
116/// Convert alias map to AliasConvention list, filtering out identity aliases
117/// where the alias name equals the original module name (e.g. `echo -> echo`).
118fn convert_aliases(aliases: &HashMap<String, String>) -> Vec<AliasConvention> {
119    aliases
120        .iter()
121        .filter(|(module, alias)| module != alias)
122        .map(|(module, alias)| AliasConvention {
123            module: module.clone(),
124            alias: alias.clone(),
125            count: 1,
126        })
127        .collect()
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// File name attached to every import recorded by these fixtures.
    const FILE: &str = "file.py";

    /// Record `module` as an absolute import.
    fn add_absolute(signals: &mut PatternSignals, module: &str) {
        signals
            .import_patterns
            .absolute_imports
            .push((module.to_string(), FILE.to_string()));
    }

    /// Record `module` as a relative import.
    fn add_relative(signals: &mut PatternSignals, module: &str) {
        signals
            .import_patterns
            .relative_imports
            .push((module.to_string(), FILE.to_string()));
    }

    /// Record that `module` is imported under the name `alias`.
    fn add_alias(signals: &mut PatternSignals, module: &str, alias: &str) {
        signals
            .import_patterns
            .aliases
            .insert(module.to_string(), alias.to_string());
    }

    #[test]
    fn test_no_signals_returns_none() {
        let empty = PatternSignals::default();
        assert!(signals_to_pattern(&empty, 3).is_none());
    }

    #[test]
    fn test_absolute_imports_preferred() {
        let mut signals = PatternSignals::default();
        // 8 absolute vs 2 relative imports: exactly the 80% threshold.
        for i in 0..8 {
            add_absolute(&mut signals, &format!("module_{}", i));
        }
        add_relative(&mut signals, ".local");
        add_relative(&mut signals, ".utils");

        let pattern = signals_to_pattern(&signals, 3).unwrap();
        assert_eq!(pattern.absolute_vs_relative, ImportStyle::Absolute);
    }

    #[test]
    fn test_relative_imports_preferred() {
        let mut signals = PatternSignals::default();
        // 2 absolute vs 8 relative imports: exactly the 20% threshold.
        add_absolute(&mut signals, "os");
        add_absolute(&mut signals, "sys");
        for i in 0..8 {
            add_relative(&mut signals, &format!(".module_{}", i));
        }

        let pattern = signals_to_pattern(&signals, 3).unwrap();
        assert_eq!(pattern.absolute_vs_relative, ImportStyle::Relative);
    }

    #[test]
    fn test_star_imports_detected() {
        let mut signals = PatternSignals::default();
        add_absolute(&mut signals, "module");
        // A single star import falls in the "rare" bucket.
        let star = Evidence::new("file.py", 5, "from module import *");
        signals.import_patterns.star_imports.push(star);

        let pattern = signals_to_pattern(&signals, 3).unwrap();
        assert_eq!(pattern.star_imports, StarImportUsage::Rare);
    }

    #[test]
    fn test_alias_conventions_detected() {
        let mut signals = PatternSignals::default();
        add_absolute(&mut signals, "numpy");
        add_alias(&mut signals, "numpy", "np");

        let pattern = signals_to_pattern(&signals, 3).unwrap();
        assert!(!pattern.alias_conventions.is_empty());
        let first = &pattern.alias_conventions[0];
        assert_eq!(first.module, "numpy");
        assert_eq!(first.alias, "np");
    }

    #[test]
    fn test_identity_aliases_filtered_out() {
        let mut signals = PatternSignals::default();
        add_absolute(&mut signals, "click");
        // Identity aliases (module name == alias name) must be filtered out.
        add_alias(&mut signals, "echo", "echo");
        add_alias(&mut signals, "style", "style");
        // Genuine aliases (module name != alias name) must be kept.
        add_alias(&mut signals, "typing", "t");
        add_alias(&mut signals, "collections.abc", "cabc");

        let pattern = signals_to_pattern(&signals, 3).unwrap();
        let conventions = &pattern.alias_conventions;

        // Only the two genuine aliases survive.
        assert_eq!(conventions.len(), 2);
        assert!(conventions.iter().any(|c| c.module == "typing"));
        assert!(conventions.iter().any(|c| c.module == "collections.abc"));
        // The identity aliases are gone.
        assert!(conventions.iter().all(|c| c.module != "echo"));
        assert!(conventions.iter().all(|c| c.module != "style"));
    }

    #[test]
    fn test_all_identity_aliases_results_in_empty_list() {
        let mut signals = PatternSignals::default();
        add_absolute(&mut signals, "click");
        // Nothing but identity aliases.
        add_alias(&mut signals, "echo", "echo");
        add_alias(&mut signals, "option", "option");

        let pattern = signals_to_pattern(&signals, 3).unwrap();
        assert!(pattern.alias_conventions.is_empty());
    }
}
284}