sem_core/parser/
registry.rs1use std::collections::HashMap;
2use std::path::Path;
3
4use super::plugin::SemanticParserPlugin;
5
6pub struct ParserRegistry {
7 plugins: Vec<Box<dyn SemanticParserPlugin>>,
8 extension_map: HashMap<String, usize>, }
10
11impl ParserRegistry {
12 pub fn new() -> Self {
13 Self {
14 plugins: Vec::new(),
15 extension_map: HashMap::new(),
16 }
17 }
18
19 pub fn register(&mut self, plugin: Box<dyn SemanticParserPlugin>) {
20 let idx = self.plugins.len();
21 for ext in plugin.extensions() {
22 self.extension_map.insert(ext.to_string(), idx);
23 }
24 self.plugins.push(plugin);
25 }
26
27 pub fn get_plugin(&self, file_path: &str) -> Option<&dyn SemanticParserPlugin> {
28 for ext in get_extensions(file_path) {
29 if let Some(&idx) = self.extension_map.get(&ext) {
30 return Some(self.plugins[idx].as_ref());
31 }
32 }
33 self.get_plugin_by_id("fallback")
35 }
36
37 pub fn get_plugin_with_content(&self, file_path: &str, content: &str) -> Option<&dyn SemanticParserPlugin> {
40 for ext in get_extensions(file_path) {
42 if let Some(&idx) = self.extension_map.get(&ext) {
43 return Some(self.plugins[idx].as_ref());
44 }
45 }
46 if let Some(plugin) = self.detect_from_shebang(content) {
48 return Some(plugin);
49 }
50 self.get_plugin_by_id("fallback")
52 }
53
54 fn detect_from_shebang(&self, content: &str) -> Option<&dyn SemanticParserPlugin> {
55 if let Some(ext) = detect_ext_from_content(content) {
56 if let Some(&idx) = self.extension_map.get(ext.as_str()) {
57 return Some(self.plugins[idx].as_ref());
58 }
59 }
60 None
61 }
62
63 pub fn get_plugin_by_id(&self, id: &str) -> Option<&dyn SemanticParserPlugin> {
64 self.plugins
65 .iter()
66 .find(|p| p.id() == id)
67 .map(|p| p.as_ref())
68 }
69}
70
/// Lists every dotted suffix of the lowercased file name in `file_path`,
/// longest first: `+page.svelte.ts` yields `[".svelte.ts", ".ts"]`.
///
/// Returns an empty list when the path has no (UTF-8) file-name component.
fn get_extensions(file_path: &str) -> Vec<String> {
    let name = match Path::new(file_path).file_name().and_then(|n| n.to_str()) {
        Some(n) => n.to_lowercase(),
        None => return Vec::new(),
    };

    // Each '.' starts one candidate suffix running to the end of the name.
    name.match_indices('.')
        .map(|(start, _)| name[start..].to_owned())
        .collect()
}
90
/// Language keyword -> file extension of the plugin that should handle it.
///
/// Keywords are compared exactly against a lowercased shebang interpreter
/// name (e.g. `python`, `sh`, `node`) or Vim `filetype`/`ft` value
/// (e.g. `python`, `sh`, `cpp`).
const LANG_MAPPING: &[(&str, &str)] = &[
    ("perl", ".pl"),
    ("python", ".py"),
    ("ruby", ".rb"),
    ("bash", ".sh"),
    ("sh", ".sh"),
    ("node", ".js"),
    ("javascript", ".js"),
    ("typescript", ".ts"),
    ("swift", ".swift"),
    ("elixir", ".ex"),
    ("rust", ".rs"),
    ("go", ".go"),
    ("kotlin", ".kt"),
    ("dart", ".dart"),
    ("php", ".php"),
    ("java", ".java"),
    ("c", ".c"),
    ("cpp", ".cpp"),
    ("cs", ".cs"),
    ("csharp", ".cs"),
    ("fortran", ".f90"),
    ("terraform", ".tf"),
    ("hcl", ".hcl"),
    ("ocaml", ".ml"),
    ("eruby", ".erb"),
    ("vue", ".vue"),
    ("svelte", ".svelte"),
];

/// Guesses a file extension (with leading dot, e.g. `".py"`) from content.
///
/// Two signals are consulted, in order:
/// 1. A shebang on the first line. The interpreter name is extracted
///    (skipping `env`, its options and `NAME=value` assignments, dropping the
///    directory and any trailing version such as `3.11`) and looked up
///    exactly; for hyphenated names (`ts-node`, `rust-script`) each
///    `-`-separated part is tried as well.
/// 2. A Vim modeline (`vim: set ft=python:`) in the first or last five lines.
///
/// Returns `None` when neither signal maps to a known language.
pub fn detect_ext_from_content(content: &str) -> Option<String> {
    // Same window as Vim's default `modelines=5`.
    const MODELINE_WINDOW: usize = 5;

    let from_shebang = content
        .lines()
        .next()
        .and_then(shebang_interpreter)
        .and_then(|interpreter| ext_for_interpreter(&interpreter));

    from_shebang
        .or_else(|| {
            // `Lines` is double-ended, so the tail is read without
            // collecting every line of the file.
            content
                .lines()
                .take(MODELINE_WINDOW)
                .chain(content.lines().rev().take(MODELINE_WINDOW))
                .filter_map(extract_vim_filetype)
                .find_map(|ft| ext_for_keyword(&ft.to_lowercase()))
        })
        .map(str::to_owned)
}

/// Exact (already lowercased) keyword lookup in [`LANG_MAPPING`].
fn ext_for_keyword(keyword: &str) -> Option<&'static str> {
    LANG_MAPPING
        .iter()
        .find(|(name, _)| *name == keyword)
        .map(|(_, ext)| *ext)
}

/// Looks up an interpreter name, then each of its `-`-separated parts, so
/// wrappers such as `ts-node` or `rust-script` still resolve.
fn ext_for_interpreter(interpreter: &str) -> Option<&'static str> {
    std::iter::once(interpreter)
        .chain(interpreter.split('-'))
        .find_map(ext_for_keyword)
}

/// Extracts the normalized interpreter name from a shebang line.
///
/// `#!/usr/bin/env -S python3.11 -u` yields `python`; `#!/bin/sh` yields
/// `sh`. Returns `None` if `line` is not a shebang or names no program.
fn shebang_interpreter(line: &str) -> Option<String> {
    let mut tokens = line.strip_prefix("#!")?.split_whitespace();
    let mut program = program_basename(tokens.next()?);
    if program.eq_ignore_ascii_case("env") {
        program = program_basename(env_command(&mut tokens)?);
    }
    // Drop version suffixes: python3 / python3.11 / ruby2.7 -> base name.
    let name = program
        .to_lowercase()
        .trim_end_matches(|c: char| c.is_ascii_digit() || c == '.')
        .to_owned();
    (!name.is_empty()).then_some(name)
}

/// Final `/`-separated component of an interpreter path.
fn program_basename(path: &str) -> &str {
    path.rsplit('/').next().unwrap_or(path)
}

/// Skips `env` options and `NAME=value` assignments, returning the command
/// `env` would execute.
fn env_command<'a>(args: &mut impl Iterator<Item = &'a str>) -> Option<&'a str> {
    while let Some(arg) = args.next() {
        if matches!(arg, "-u" | "--unset" | "-C" | "--chdir") {
            // These options consume the following argument.
            args.next();
        } else if !arg.starts_with('-') && !arg.contains('=') {
            return Some(arg);
        }
    }
    None
}

/// Extracts the `ft=`/`filetype=` value from a Vim modeline in `line`.
///
/// Options may be separated by whitespace or `:`, which covers both
/// `vim: ft=python:ts=4` and `vim: set ft=python :`. Empty values are skipped.
fn extract_vim_filetype(line: &str) -> Option<&str> {
    let (_, options) = line.split_once("vim:")?;
    options
        .split(|c: char| c.is_whitespace() || c == ':')
        .find_map(|opt| {
            opt.strip_prefix("ft=")
                .or_else(|| opt.strip_prefix("filetype="))
                .filter(|value| !value.is_empty())
        })
}
169
#[cfg(test)]
mod tests {
    use crate::parser::plugins::create_default_registry;

    /// Resolves `path` through the default registry and checks which plugin
    /// was chosen.
    fn assert_plugin_id(path: &str, expected_id: &str) {
        let registry = create_default_registry();
        let plugin = registry.get_plugin(path).expect("plugin should exist");
        assert_eq!(plugin.id(), expected_id);
    }

    #[test]
    fn test_registry_matches_compound_svelte_typescript_suffix() {
        assert_plugin_id("src/routes/+page.svelte.ts", "svelte");
    }

    #[test]
    fn test_registry_matches_compound_svelte_javascript_suffix() {
        assert_plugin_id("src/routes/+layout.svelte.js", "svelte");
    }

    #[test]
    fn test_registry_matches_svelte_test_suffix() {
        assert_plugin_id("src/lib/multiplier.svelte.test.js", "svelte");
    }

    #[test]
    fn test_registry_prefers_svelte_plugin_for_component_files() {
        assert_plugin_id("src/lib/Component.svelte", "svelte");
    }

    #[test]
    fn test_registry_matches_typescript_module_suffix() {
        assert_plugin_id("src/lib/index.mts", "code");
    }

    #[test]
    fn test_registry_matches_typescript_commonjs_suffix() {
        assert_plugin_id("src/lib/index.cts", "code");
    }
}
233}