1use std::collections::HashMap;
2use std::path::Path;
3use rayon::prelude::*;
4
5use crate::model::entity::SemanticEntity;
6use super::plugin::SemanticParserPlugin;
7
/// Registry of semantic parser plugins, indexed by the file extensions they handle.
pub struct ParserRegistry {
    /// Registered plugins, in registration order.
    plugins: Vec<Box<dyn SemanticParserPlugin>>,
    /// Maps an extension key (e.g. `".rs"`) to an index into `plugins`.
    extension_map: HashMap<String, usize>,
    /// Maps a user-defined extension to the registered extension it aliases
    /// (filled in by `add_extension_mapping`).
    custom_ext_canonical: HashMap<String, String>,
}
13
14impl ParserRegistry {
15 pub fn new() -> Self {
16 Self {
17 plugins: Vec::new(),
18 extension_map: HashMap::new(),
19 custom_ext_canonical: HashMap::new(),
20 }
21 }
22
23 pub fn register(&mut self, plugin: Box<dyn SemanticParserPlugin>) {
24 let idx = self.plugins.len();
25 for ext in plugin.extensions() {
26 self.extension_map.insert(ext.to_string(), idx);
27 }
28 self.plugins.push(plugin);
29 }
30
31 pub fn get_plugin(&self, file_path: &str) -> Option<&dyn SemanticParserPlugin> {
32 for ext in get_extensions(file_path) {
33 if let Some(&idx) = self.extension_map.get(&ext) {
34 return Some(self.plugins[idx].as_ref());
35 }
36 }
37 self.get_plugin_by_id("fallback")
39 }
40
41 pub fn get_plugin_with_content(&self, file_path: &str, content: &str) -> Option<&dyn SemanticParserPlugin> {
44 for ext in get_extensions(file_path) {
46 if let Some(&idx) = self.extension_map.get(&ext) {
47 return Some(self.plugins[idx].as_ref());
48 }
49 }
50 if let Some(plugin) = self.detect_from_shebang(content) {
52 return Some(plugin);
53 }
54 self.get_plugin_by_id("fallback")
56 }
57
58 fn detect_from_shebang(&self, content: &str) -> Option<&dyn SemanticParserPlugin> {
59 if let Some(ext) = detect_ext_from_content(content) {
60 if let Some(&idx) = self.extension_map.get(ext.as_str()) {
61 return Some(self.plugins[idx].as_ref());
62 }
63 }
64 None
65 }
66
67 pub fn get_plugin_by_id(&self, id: &str) -> Option<&dyn SemanticParserPlugin> {
68 self.plugins
69 .iter()
70 .find(|p| p.id() == id)
71 .map(|p| p.as_ref())
72 }
73
74 pub fn add_extension_mapping(&mut self, ext: &str, language: &str) -> bool {
77 let ext = if ext.starts_with('.') {
78 ext.to_lowercase()
79 } else {
80 format!(".{}", ext.to_lowercase())
81 };
82
83 let target_ext = LANG_MAPPING
85 .iter()
86 .find(|(kw, _)| *kw == language.to_lowercase())
87 .map(|(_, e)| *e);
88
89 if let Some(target) = target_ext {
90 if let Some(&idx) = self.extension_map.get(target) {
91 self.custom_ext_canonical.insert(ext.clone(), target.to_string());
92 self.extension_map.insert(ext, idx);
93 return true;
94 }
95 }
96
97 let direct_ext = format!(".{}", language.to_lowercase());
99 if let Some(&idx) = self.extension_map.get(&direct_ext) {
100 self.custom_ext_canonical.insert(ext.clone(), direct_ext);
101 self.extension_map.insert(ext, idx);
102 return true;
103 }
104
105 false
106 }
107
108 pub fn load_semrc(&mut self, root: &Path) {
115 let semrc_path = root.join(".semrc");
116 if !semrc_path.exists() {
117 return;
118 }
119 let content = match std::fs::read_to_string(&semrc_path) {
120 Ok(c) => c,
121 Err(_) => return,
122 };
123 for line in content.lines() {
124 let line = line.trim();
125 if line.is_empty() || line.starts_with('#') {
126 continue;
127 }
128 if let Some((ext, lang)) = line.split_once('=') {
129 self.add_extension_mapping(ext.trim(), lang.trim());
130 }
131 }
132 }
133
134 pub fn load_gitattributes(&mut self, root: &Path) {
138 let ga_path = root.join(".gitattributes");
139 if !ga_path.exists() {
140 return;
141 }
142 let content = match std::fs::read_to_string(&ga_path) {
143 Ok(c) => c,
144 Err(_) => return,
145 };
146 for line in content.lines() {
147 let line = line.trim();
148 if line.is_empty() || line.starts_with('#') {
149 continue;
150 }
151 let mut parts = line.split_whitespace();
152 let pattern = match parts.next() {
153 Some(p) => p,
154 None => continue,
155 };
156 let ext = match pattern.strip_prefix("*.") {
158 Some(e) => e,
159 None => continue,
160 };
161 let ext_key = format!(".{}", ext.to_lowercase());
163 if self.custom_ext_canonical.contains_key(&ext_key) {
164 continue;
165 }
166 for attr in parts {
168 if let Some(lang) = attr.strip_prefix("diff=") {
169 self.add_extension_mapping(ext, lang);
170 break;
171 }
172 if let Some(lang) = attr.strip_prefix("linguist-language=") {
173 self.add_extension_mapping(ext, lang);
174 break;
175 }
176 }
177 }
178 }
179
180 pub fn resolve_file_path(&self, file_path: &str) -> Option<String> {
184 let path = Path::new(file_path);
185 let ext = path
186 .extension()
187 .and_then(|e| e.to_str())
188 .map(|e| format!(".{}", e.to_lowercase()))?;
189
190 let canonical = self.custom_ext_canonical.get(&ext)?;
191 let stem = path.file_stem().and_then(|s| s.to_str())?;
192
193 if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
194 Some(format!("{}/{}{}", parent.display(), stem, canonical))
195 } else {
196 Some(format!("{}{}", stem, canonical))
197 }
198 }
199
200 pub fn extract_entities(&self, file_path: &str, content: &str) -> Vec<SemanticEntity> {
204 let resolved = self.resolve_file_path(file_path);
205 let detection_path = resolved.as_deref().unwrap_or(file_path);
206
207 let plugin = match self.get_plugin_with_content(detection_path, content) {
208 Some(p) => p,
209 None => return Vec::new(),
210 };
211
212 let mut entities = plugin.extract_entities(content, detection_path);
213 if let Some(ref rp) = resolved {
214 fix_entity_paths(&mut entities, file_path, rp);
215 }
216 entities
217 }
218
219 pub fn extract_entities_with_tree(
221 &self,
222 file_path: &str,
223 content: &str,
224 ) -> Option<(Vec<SemanticEntity>, Option<tree_sitter::Tree>)> {
225 let resolved = self.resolve_file_path(file_path);
226 let detection_path = resolved.as_deref().unwrap_or(file_path);
227
228 let plugin = self.get_plugin_with_content(detection_path, content)?;
229 let (mut entities, tree) = plugin.extract_entities_with_tree(content, detection_path);
230 if let Some(ref rp) = resolved {
231 fix_entity_paths(&mut entities, file_path, rp);
232 }
233 Some((entities, tree))
234 }
235
236 pub fn extract_all_entities(
238 &self,
239 root: &Path,
240 file_paths: &[String],
241 ) -> Vec<SemanticEntity> {
242 file_paths
243 .par_iter()
244 .flat_map(|fp| {
245 let full = root.join(fp);
246 let content = match std::fs::read_to_string(&full) {
247 Ok(c) => c,
248 Err(_) => return Vec::new(),
249 };
250 self.extract_entities(fp, &content)
251 })
252 .collect()
253 }
254}
255
256fn fix_entity_paths(entities: &mut [SemanticEntity], original: &str, resolved: &str) {
258 for entity in entities {
259 entity.file_path = original.to_string();
260 entity.id = entity.id.replace(resolved, original);
261 if let Some(ref mut pid) = entity.parent_id {
262 *pid = pid.replace(resolved, original);
263 }
264 }
265}
266
/// Lists every dotted suffix of the file name in `file_path`, lower-cased and
/// ordered from longest to shortest (`"a/b.svelte.ts"` yields
/// `[".svelte.ts", ".ts"]`).
///
/// Returns an empty vector when the path has no file name or the name
/// contains no dot.
fn get_extensions(file_path: &str) -> Vec<String> {
    let lowered = match Path::new(file_path).file_name().and_then(|n| n.to_str()) {
        Some(name) => name.to_lowercase(),
        None => return Vec::new(),
    };

    lowered
        .match_indices('.')
        .map(|(dot, _)| lowered[dot..].to_string())
        .collect()
}
286
/// Table of `(keyword, extension)` pairs relating a lower-case language or
/// interpreter keyword to the file extension used as the lookup key for that
/// language.
///
/// The table is consulted when resolving the language name passed to
/// `add_extension_mapping` and when inferring a language from file content.
/// Consumers scan it front to back and stop at the first hit, so the order of
/// the entries is significant. Several keywords may share one extension
/// (e.g. `"go"` and `"golang"`).
const LANG_MAPPING: &[(&str, &str)] = &[
    ("perl", ".pl"),
    ("python", ".py"),
    ("ruby", ".rb"),
    ("bash", ".sh"),
    ("shell", ".sh"),
    // Leading slash: written to match interpreter paths such as `/bin/sh`.
    ("/sh", ".sh"),
    ("node", ".js"),
    ("javascript", ".js"),
    ("typescript", ".ts"),
    ("tsx", ".tsx"),
    ("swift", ".swift"),
    ("elixir", ".ex"),
    ("rust", ".rs"),
    ("go", ".go"),
    ("golang", ".go"),
    ("kotlin", ".kt"),
    ("dart", ".dart"),
    ("php", ".php"),
    ("java", ".java"),
    ("c", ".c"),
    ("cpp", ".cpp"),
    ("c++", ".cpp"),
    ("cs", ".cs"),
    ("csharp", ".cs"),
    ("c#", ".cs"),
    ("fortran", ".f90"),
    ("terraform", ".tf"),
    ("hcl", ".hcl"),
    ("ocaml", ".ml"),
    ("scala", ".scala"),
    ("zig", ".zig"),
    ("xml", ".xml"),
    ("json", ".json"),
    ("yaml", ".yaml"),
    ("yml", ".yaml"),
    ("toml", ".toml"),
    ("markdown", ".md"),
    ("csv", ".csv"),
    ("eruby", ".erb"),
    ("erb", ".erb"),
    ("vue", ".vue"),
    ("svelte", ".svelte"),
];
331
332pub fn detect_ext_from_content(content: &str) -> Option<String> {
334 if let Some(first_line) = content.lines().next() {
336 if first_line.starts_with("#!") {
337 let shebang = first_line.to_lowercase();
338 for (keyword, ext) in LANG_MAPPING {
339 if shebang.contains(keyword) {
340 return Some(ext.to_string());
341 }
342 }
343 }
344 }
345
346 let lines: Vec<&str> = content.lines().collect();
349 let check_lines = lines.iter().take(5).chain(lines.iter().rev().take(5));
350 for line in check_lines {
351 if let Some(ft) = extract_vim_filetype(line) {
352 let ft_lower = ft.to_lowercase();
353 for (keyword, ext) in LANG_MAPPING {
354 if ft_lower == *keyword {
355 return Some(ext.to_string());
356 }
357 }
358 }
359 }
360
361 if let Some(ext) = detect_from_content_heuristics(content) {
363 return Some(ext);
364 }
365
366 None
367}
368
/// Guesses a file extension from characteristic lines near the top of `content`.
///
/// The first line is checked for PHP / XML / DOCTYPE openers; after that each
/// of the first 20 lines is classified in turn and the first recognised line
/// decides the result. Returns `None` when nothing is recognised.
fn detect_from_content_heuristics(content: &str) -> Option<String> {
    let opening = content.lines().next().map_or("", str::trim);

    if opening.starts_with("<?php") || opening.starts_with("<?PHP") {
        return Some(".php".to_string());
    }

    const MARKUP_OPENERS: [&str; 3] = ["<?xml", "<!DOCTYPE", "<!doctype"];
    if MARKUP_OPENERS.iter().any(|tag| opening.starts_with(*tag)) {
        return Some(".xml".to_string());
    }

    content
        .lines()
        .take(20)
        .find_map(|raw| heuristic_ext_for_line(raw, content))
        .map(str::to_string)
}

/// Classifies a single line; `content` is the whole file, needed by the checks
/// that look at neighbouring lines. The order of the checks is significant.
fn heuristic_ext_for_line(raw: &str, content: &str) -> Option<&'static str> {
    let trimmed = raw.trim();
    let starts_with_any = |prefixes: &[&str]| prefixes.iter().any(|p| trimmed.starts_with(*p));

    // PHP opening tags.
    if starts_with_any(&["<?php", "<?PHP"]) || trimmed == "<?=" {
        return Some(".php");
    }

    // C family: an #include means C unless a C++ marker shows up nearby.
    if starts_with_any(&["#include ", "#include\t"]) {
        const CPP_MARKERS: [&str; 6] = [
            "using namespace",
            "class ",
            "#include <iostream",
            "#include <vector",
            "#include <string>",
            "#include <memory",
        ];
        let looks_like_cpp = content.lines().take(30).any(|candidate| {
            let candidate = candidate.trim();
            CPP_MARKERS.iter().any(|m| candidate.starts_with(*m))
        });
        return Some(if looks_like_cpp { ".cpp" } else { ".c" });
    }

    // `package` declarations: Java (dotted, `;`), Go (plain), Kotlin (dotted, no `;`).
    if trimmed.starts_with("package ") {
        let dotted = trimmed.contains('.');
        if dotted && trimmed.ends_with(';') {
            return Some(".java");
        }
        if !dotted && !trimmed.contains(';') {
            return Some(".go");
        }
        if dotted {
            return Some(".kt");
        }
    }

    if starts_with_any(&["use std::", "use crate::"]) && trimmed.ends_with(';') {
        return Some(".rs");
    }

    if trimmed.starts_with("defmodule ") {
        return Some(".ex");
    }

    // C#.
    let using_system = trimmed.starts_with("using System") && trimmed.ends_with(';');
    let namespace_block = trimmed.starts_with("namespace ") && trimmed.ends_with('{');
    if using_system || namespace_block {
        return Some(".cs");
    }

    if matches!(
        trimmed,
        "import Foundation" | "import UIKit" | "import SwiftUI" | "import Combine"
    ) {
        return Some(".swift");
    }

    if starts_with_any(&["import 'dart:", "import \"dart:"]) {
        return Some(".dart");
    }

    if starts_with_any(&["object ", "trait "]) {
        return Some(".scala");
    }

    if trimmed.contains("@import(") {
        return Some(".zig");
    }

    if starts_with_any(&["resource \"", "variable \"", "terraform {", "provider \""]) {
        return Some(".tf");
    }

    // Fortran (case-insensitive): `program` / `implicit none` decide on their
    // own; `module` / `subroutine` need an `implicit none` among the first
    // 20 lines to count.
    let lower = trimmed.to_lowercase();
    let decisive = lower.starts_with("program ") || lower == "implicit none";
    let suggestive = lower.starts_with("module ") || lower.starts_with("subroutine ");
    if decisive
        || (suggestive
            && content
                .lines()
                .take(20)
                .any(|l| l.trim().to_lowercase() == "implicit none"))
    {
        return Some(".f90");
    }

    // Python: only an unindented `def` / `class` header counts.
    let unindented = !raw.starts_with(char::is_whitespace);
    if starts_with_any(&["def ", "class "]) && trimmed.ends_with(':') && unindented {
        return Some(".py");
    }

    if starts_with_any(&["require '", "require \"", "require_relative "]) {
        return Some(".rb");
    }

    if trimmed == "use strict;"
        || trimmed == "use warnings;"
        || starts_with_any(&["my $", "my @", "my %"])
    {
        return Some(".pl");
    }

    None
}
526
/// Extracts the filetype named by a Vim modeline on `line`, if there is one.
///
/// Everything after the first `vim:` marker is treated as the option list.
/// Options may be separated by whitespace or by `:`, so both
/// `vim: set ft=python ts=4:` and `vim:ft=python:ts=4` are understood. The
/// value of the first non-empty `ft=` or `filetype=` option is returned,
/// borrowed from `line`.
fn extract_vim_filetype(line: &str) -> Option<&str> {
    let (_, options) = line.trim().split_once("vim:")?;

    options
        .split(|c: char| c.is_whitespace() || c == ':')
        .filter_map(|option| {
            option
                .strip_prefix("ft=")
                .or_else(|| option.strip_prefix("filetype="))
        })
        .find(|filetype| !filetype.is_empty())
}
543
// Tests for plugin selection by file name, content-based language detection,
// and custom extension mappings, all run against the default registry.
#[cfg(test)]
mod tests {
    use crate::parser::plugins::create_default_registry;

    // --- Selection by file-name suffix ---

    /// A compound `.svelte.ts` name resolves to the svelte plugin.
    #[test]
    fn test_registry_matches_compound_svelte_typescript_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/routes/+page.svelte.ts")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    /// A compound `.svelte.js` name resolves to the svelte plugin.
    #[test]
    fn test_registry_matches_compound_svelte_javascript_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/routes/+layout.svelte.js")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    /// A `.svelte.test.js` name resolves to the svelte plugin.
    #[test]
    fn test_registry_matches_svelte_test_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/multiplier.svelte.test.js")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    /// A plain `.svelte` component resolves to the svelte plugin.
    #[test]
    fn test_registry_prefers_svelte_plugin_for_component_files() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/Component.svelte")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    /// `.mts` files resolve to the plugin with id "code".
    #[test]
    fn test_registry_matches_typescript_module_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/index.mts")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "code");
    }

    /// `.cts` files resolve to the plugin with id "code".
    #[test]
    fn test_registry_matches_typescript_commonjs_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/index.cts")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "code");
    }

    // --- Detection from content when the file name is not recognised ---

    /// A `<?php` opener selects a plugin that extracts the class.
    #[test]
    fn test_detect_php_from_opening_tag() {
        let registry = create_default_registry();
        let content = "<?php\nclass Vendor {\n function get_name() { return $this->name; }\n}\n";
        let plugin = registry
            .get_plugin_with_content("vendor.inc2", content)
            .expect("should detect PHP");
        let entities = plugin.extract_entities(content, "vendor.inc2");
        assert!(entities.iter().any(|e| e.entity_type == "class"));
    }

    /// An `#include` line selects a plugin that extracts `main`.
    #[test]
    fn test_detect_c_from_include() {
        let registry = create_default_registry();
        let content = "#include <stdio.h>\n\nint main() {\n printf(\"hello\");\n return 0;\n}\n";
        let plugin = registry
            .get_plugin_with_content("main.xyz", content)
            .expect("should detect C");
        let entities = plugin.extract_entities(content, "main.xyz");
        assert!(entities.iter().any(|e| e.name == "main"));
    }

    /// A dotted `package ...;` declaration selects a plugin that extracts the class.
    #[test]
    fn test_detect_java_from_package() {
        let registry = create_default_registry();
        let content = "package com.example.app;\n\npublic class Main {\n public static void main(String[] args) {}\n}\n";
        let plugin = registry
            .get_plugin_with_content("Main", content)
            .expect("should detect Java");
        let entities = plugin.extract_entities(content, "Main");
        assert!(entities.iter().any(|e| e.name == "Main"));
    }

    /// A plain `package main` declaration selects a plugin that extracts the function.
    #[test]
    fn test_detect_go_from_package() {
        let registry = create_default_registry();
        let content = "package main\n\nimport \"fmt\"\n\nfunc hello() {\n fmt.Println(\"hi\")\n}\n";
        let plugin = registry
            .get_plugin_with_content("main", content)
            .expect("should detect Go");
        let entities = plugin.extract_entities(content, "main");
        assert!(entities.iter().any(|e| e.name == "hello"));
    }

    /// A `use std::...;` line selects a plugin that extracts the function.
    #[test]
    fn test_detect_rust_from_use_std() {
        let registry = create_default_registry();
        let content = "use std::collections::HashMap;\n\nfn process() {\n let m = HashMap::new();\n}\n";
        let plugin = registry
            .get_plugin_with_content("lib", content)
            .expect("should detect Rust");
        let entities = plugin.extract_entities(content, "lib");
        assert!(entities.iter().any(|e| e.name == "process"));
    }

    /// A recognised extension wins even when the content looks like another language.
    #[test]
    fn test_extension_takes_priority_over_heuristics() {
        let registry = create_default_registry();
        let content = "<?php\nclass Foo {}\n";
        let plugin = registry
            .get_plugin_with_content("script.py", content)
            .expect("should use Python parser");
        assert_eq!(plugin.id(), "code");
    }

    // --- Custom extension mappings ---

    /// Entities extracted through a custom mapping keep the original file path in
    /// both `file_path` and `id`.
    #[test]
    fn test_custom_extension_mapping_extracts_entities() {
        let mut registry = create_default_registry();
        registry.add_extension_mapping(".mypy", "python");

        let content = "def hello():\n print(\"hello world\")\n\nclass Calculator:\n def multiply(self, a, b):\n return a * b\n";
        let entities = registry.extract_entities("utils.mypy", content);

        assert!(!entities.is_empty(), "Should extract entities via custom mapping");
        assert!(entities.iter().any(|e| e.name == "hello"), "Should find hello function");
        assert!(entities.iter().any(|e| e.name == "Calculator"), "Should find Calculator class");
        assert!(entities.iter().any(|e| e.name == "multiply"), "Should find multiply method");

        for entity in &entities {
            assert_eq!(entity.file_path, "utils.mypy", "Entity file_path should use original extension");
            assert!(entity.id.starts_with("utils.mypy::"), "Entity ID should use original file path");
        }
    }
}