1use std::collections::HashMap;
2use std::path::Path;
3#[cfg(feature = "parallel")]
4use rayon::prelude::*;
5
6use crate::model::entity::SemanticEntity;
7
/// Iterates a slice-like value by shared reference: `par_iter()` (rayon) when
/// the `parallel` feature is enabled, plain `iter()` otherwise.
#[cfg(feature = "parallel")]
macro_rules! maybe_par_iter {
    ($slice:expr) => {
        $slice.par_iter()
    };
}

/// Iterates a slice-like value by shared reference: `par_iter()` (rayon) when
/// the `parallel` feature is enabled, plain `iter()` otherwise.
#[cfg(not(feature = "parallel"))]
macro_rules! maybe_par_iter {
    ($slice:expr) => {
        $slice.iter()
    };
}
16use super::plugin::SemanticParserPlugin;
17
18pub struct ParserRegistry {
19 plugins: Vec<Box<dyn SemanticParserPlugin>>,
20 extension_map: HashMap<String, usize>, custom_ext_canonical: HashMap<String, String>, }
23
24impl ParserRegistry {
25 pub fn new() -> Self {
26 Self {
27 plugins: Vec::new(),
28 extension_map: HashMap::new(),
29 custom_ext_canonical: HashMap::new(),
30 }
31 }
32
33 pub fn register(&mut self, plugin: Box<dyn SemanticParserPlugin>) {
34 let idx = self.plugins.len();
35 for ext in plugin.extensions() {
36 self.extension_map.insert(ext.to_string(), idx);
37 }
38 self.plugins.push(plugin);
39 }
40
41 pub fn get_plugin(&self, file_path: &str) -> Option<&dyn SemanticParserPlugin> {
42 for ext in get_extensions(file_path) {
43 if let Some(&idx) = self.extension_map.get(&ext) {
44 return Some(self.plugins[idx].as_ref());
45 }
46 }
47 self.get_plugin_by_id("fallback")
49 }
50
51 pub fn get_plugin_with_content(&self, file_path: &str, content: &str) -> Option<&dyn SemanticParserPlugin> {
54 for ext in get_extensions(file_path) {
56 if let Some(&idx) = self.extension_map.get(&ext) {
57 return Some(self.plugins[idx].as_ref());
58 }
59 }
60 if let Some(plugin) = self.detect_from_shebang(content) {
62 return Some(plugin);
63 }
64 self.get_plugin_by_id("fallback")
66 }
67
68 fn detect_from_shebang(&self, content: &str) -> Option<&dyn SemanticParserPlugin> {
69 if let Some(ext) = detect_ext_from_content(content) {
70 if let Some(&idx) = self.extension_map.get(ext.as_str()) {
71 return Some(self.plugins[idx].as_ref());
72 }
73 }
74 None
75 }
76
77 pub fn get_plugin_by_id(&self, id: &str) -> Option<&dyn SemanticParserPlugin> {
78 self.plugins
79 .iter()
80 .find(|p| p.id() == id)
81 .map(|p| p.as_ref())
82 }
83
84 pub fn add_extension_mapping(&mut self, ext: &str, language: &str) -> bool {
87 let ext = if ext.starts_with('.') {
88 ext.to_lowercase()
89 } else {
90 format!(".{}", ext.to_lowercase())
91 };
92
93 let target_ext = LANG_MAPPING
95 .iter()
96 .find(|(kw, _)| *kw == language.to_lowercase())
97 .map(|(_, e)| *e);
98
99 if let Some(target) = target_ext {
100 if let Some(&idx) = self.extension_map.get(target) {
101 self.custom_ext_canonical.insert(ext.clone(), target.to_string());
102 self.extension_map.insert(ext, idx);
103 return true;
104 }
105 }
106
107 let direct_ext = format!(".{}", language.to_lowercase());
109 if let Some(&idx) = self.extension_map.get(&direct_ext) {
110 self.custom_ext_canonical.insert(ext.clone(), direct_ext);
111 self.extension_map.insert(ext, idx);
112 return true;
113 }
114
115 false
116 }
117
118 pub fn load_semrc(&mut self, root: &Path) {
125 let semrc_path = root.join(".semrc");
126 if !semrc_path.exists() {
127 return;
128 }
129 let content = match std::fs::read_to_string(&semrc_path) {
130 Ok(c) => c,
131 Err(_) => return,
132 };
133 for line in content.lines() {
134 let line = line.trim();
135 if line.is_empty() || line.starts_with('#') {
136 continue;
137 }
138 if let Some((ext, lang)) = line.split_once('=') {
139 self.add_extension_mapping(ext.trim(), lang.trim());
140 }
141 }
142 }
143
144 pub fn load_gitattributes(&mut self, root: &Path) {
148 let ga_path = root.join(".gitattributes");
149 if !ga_path.exists() {
150 return;
151 }
152 let content = match std::fs::read_to_string(&ga_path) {
153 Ok(c) => c,
154 Err(_) => return,
155 };
156 for line in content.lines() {
157 let line = line.trim();
158 if line.is_empty() || line.starts_with('#') {
159 continue;
160 }
161 let mut parts = line.split_whitespace();
162 let pattern = match parts.next() {
163 Some(p) => p,
164 None => continue,
165 };
166 let ext = match pattern.strip_prefix("*.") {
168 Some(e) => e,
169 None => continue,
170 };
171 let ext_key = format!(".{}", ext.to_lowercase());
173 if self.custom_ext_canonical.contains_key(&ext_key) {
174 continue;
175 }
176 for attr in parts {
178 if let Some(lang) = attr.strip_prefix("diff=") {
179 self.add_extension_mapping(ext, lang);
180 break;
181 }
182 if let Some(lang) = attr.strip_prefix("linguist-language=") {
183 self.add_extension_mapping(ext, lang);
184 break;
185 }
186 }
187 }
188 }
189
190 pub fn resolve_file_path(&self, file_path: &str) -> Option<String> {
194 let path = Path::new(file_path);
195 let ext = path
196 .extension()
197 .and_then(|e| e.to_str())
198 .map(|e| format!(".{}", e.to_lowercase()))?;
199
200 let canonical = self.custom_ext_canonical.get(&ext)?;
201 let stem = path.file_stem().and_then(|s| s.to_str())?;
202
203 if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
204 Some(format!("{}/{}{}", parent.display(), stem, canonical))
205 } else {
206 Some(format!("{}{}", stem, canonical))
207 }
208 }
209
210 pub fn extract_entities(&self, file_path: &str, content: &str) -> Vec<SemanticEntity> {
214 let resolved = self.resolve_file_path(file_path);
215 let detection_path = resolved.as_deref().unwrap_or(file_path);
216
217 let plugin = match self.get_plugin_with_content(detection_path, content) {
218 Some(p) => p,
219 None => return Vec::new(),
220 };
221
222 let mut entities = plugin.extract_entities(content, detection_path);
223 if let Some(ref rp) = resolved {
224 fix_entity_paths(&mut entities, file_path, rp);
225 }
226 entities
227 }
228
229 pub fn extract_entities_with_tree(
231 &self,
232 file_path: &str,
233 content: &str,
234 ) -> Option<(Vec<SemanticEntity>, Option<tree_sitter::Tree>)> {
235 let resolved = self.resolve_file_path(file_path);
236 let detection_path = resolved.as_deref().unwrap_or(file_path);
237
238 let plugin = self.get_plugin_with_content(detection_path, content)?;
239 let (mut entities, tree) = plugin.extract_entities_with_tree(content, detection_path);
240 if let Some(ref rp) = resolved {
241 fix_entity_paths(&mut entities, file_path, rp);
242 }
243 Some((entities, tree))
244 }
245
246 pub fn extract_all_entities(
248 &self,
249 root: &Path,
250 file_paths: &[String],
251 ) -> Vec<SemanticEntity> {
252 maybe_par_iter!(file_paths)
253 .flat_map(|fp| {
254 let full = root.join(fp);
255 let content = match std::fs::read_to_string(&full) {
256 Ok(c) => c,
257 Err(_) => return Vec::new(),
258 };
259 self.extract_entities(fp, &content)
260 })
261 .collect()
262 }
263}
264
265fn fix_entity_paths(entities: &mut [SemanticEntity], original: &str, resolved: &str) {
267 for entity in entities {
268 entity.file_path = original.to_string();
269 entity.id = entity.id.replace(resolved, original);
270 if let Some(ref mut pid) = entity.parent_id {
271 *pid = pid.replace(resolved, original);
272 }
273 }
274}
275
/// Lists every dot-suffix of the file name, longest first and lower-cased.
///
/// For example, `src/+page.Svelte.TS` yields `[".svelte.ts", ".ts"]`. Only the
/// final path component is examined. Names without a `.`, and paths with no
/// file name, yield an empty list. Callers that take the first registered
/// suffix therefore prefer compound extensions over their tails.
fn get_extensions(file_path: &str) -> Vec<String> {
    let name = match Path::new(file_path).file_name().and_then(|n| n.to_str()) {
        Some(name) => name.to_lowercase(),
        None => return Vec::new(),
    };

    name.match_indices('.')
        .map(|(start, _)| name[start..].to_owned())
        .collect()
}
295
/// Language keyword → canonical extension, used to turn a language name into
/// the extension of the plugin that handles it.
///
/// Consulted by `add_extension_mapping` (configured language names) and by
/// `detect_ext_from_content` (shebang and vim-modeline detection). Lookups
/// scan the table in order and stop at the first hit, so entry order matters
/// wherever keywords can overlap. Several keywords may share a target (e.g.
/// `bash`/`shell`/`/sh` → `.sh`). `/sh` carries a leading slash so that it
/// targets a `sh` interpreter path rather than every name merely containing
/// `sh`. In `add_extension_mapping` a target only takes effect if some
/// registered plugin claims that extension.
const LANG_MAPPING: &[(&str, &str)] = &[
    ("perl", ".pl"),
    ("python", ".py"),
    ("ruby", ".rb"),
    ("bash", ".sh"),
    ("shell", ".sh"),
    ("/sh", ".sh"),
    ("node", ".js"),
    ("javascript", ".js"),
    ("typescript", ".ts"),
    ("tsx", ".tsx"),
    ("swift", ".swift"),
    ("elixir", ".ex"),
    ("rust", ".rs"),
    ("go", ".go"),
    ("golang", ".go"),
    ("kotlin", ".kt"),
    ("dart", ".dart"),
    ("php", ".php"),
    ("java", ".java"),
    ("c", ".c"),
    ("cpp", ".cpp"),
    ("c++", ".cpp"),
    ("cs", ".cs"),
    ("csharp", ".cs"),
    ("c#", ".cs"),
    ("fortran", ".f90"),
    ("terraform", ".tf"),
    ("hcl", ".hcl"),
    ("ocaml", ".ml"),
    ("scala", ".scala"),
    ("zig", ".zig"),
    ("xml", ".xml"),
    ("json", ".json"),
    ("yaml", ".yaml"),
    ("yml", ".yaml"),
    ("toml", ".toml"),
    ("markdown", ".md"),
    ("csv", ".csv"),
    ("eruby", ".erb"),
    ("erb", ".erb"),
    ("vue", ".vue"),
    ("svelte", ".svelte"),
];
340
341pub fn detect_ext_from_content(content: &str) -> Option<String> {
343 if let Some(first_line) = content.lines().next() {
345 if first_line.starts_with("#!") {
346 let shebang = first_line.to_lowercase();
347 for (keyword, ext) in LANG_MAPPING {
348 if shebang.contains(keyword) {
349 return Some(ext.to_string());
350 }
351 }
352 }
353 }
354
355 let lines: Vec<&str> = content.lines().collect();
358 let check_lines = lines.iter().take(5).chain(lines.iter().rev().take(5));
359 for line in check_lines {
360 if let Some(ft) = extract_vim_filetype(line) {
361 let ft_lower = ft.to_lowercase();
362 for (keyword, ext) in LANG_MAPPING {
363 if ft_lower == *keyword {
364 return Some(ext.to_string());
365 }
366 }
367 }
368 }
369
370 if let Some(ext) = detect_from_content_heuristics(content) {
372 return Some(ext);
373 }
374
375 None
376}
377
/// Guesses a canonical extension from tell-tale content patterns.
///
/// First checks the very first line for PHP/XML openers. It then scans up to
/// the first 20 lines, and for each line runs the checks below in order,
/// returning on the first match. Earlier checks therefore shadow later,
/// looser ones. Returns `None` if nothing matches.
fn detect_from_content_heuristics(content: &str) -> Option<String> {
    let first_line = content.lines().next().unwrap_or("").trim();

    // PHP opening tag on the first line.
    if first_line.starts_with("<?php") || first_line.starts_with("<?PHP") {
        return Some(".php".to_string());
    }

    // XML declaration or DOCTYPE on the first line; both go to the XML plugin.
    // NOTE(review): this also routes `<!DOCTYPE html>` documents to XML.
    if first_line.starts_with("<?xml") {
        return Some(".xml".to_string());
    }
    if first_line.starts_with("<!DOCTYPE") || first_line.starts_with("<!doctype") {
        return Some(".xml".to_string());
    }

    for line in content.lines().take(20) {
        let trimmed = line.trim();

        // PHP tag after some leading lines, or a bare short-echo tag.
        if trimmed.starts_with("<?php") || trimmed.starts_with("<?PHP") || trimmed == "<?=" {
            return Some(".php".to_string());
        }

        // `#include` means C or C++. Treat it as C++ if any of the first 30
        // lines contains a C++-only construct or a standard C++ header.
        if trimmed.starts_with("#include ") || trimmed.starts_with("#include\t") {
            if content.lines().take(30).any(|l| {
                let t = l.trim();
                t.starts_with("using namespace")
                    || t.starts_with("class ")
                    || t.starts_with("#include <iostream")
                    || t.starts_with("#include <vector")
                    || t.starts_with("#include <string>")
                    || t.starts_with("#include <memory")
            }) {
                return Some(".cpp".to_string());
            }
            return Some(".c".to_string());
        }

        // Java: dotted package declaration terminated by `;`.
        if trimmed.starts_with("package ") && trimmed.contains('.') && trimmed.ends_with(';') {
            return Some(".java".to_string());
        }

        // Go: bare `package name` with no dots and no semicolon.
        if trimmed.starts_with("package ") && !trimmed.contains('.') && !trimmed.contains(';') {
            return Some(".go".to_string());
        }

        // Rust: `use std::...;` or `use crate::...;`.
        if (trimmed.starts_with("use std::") || trimmed.starts_with("use crate::"))
            && trimmed.ends_with(';')
        {
            return Some(".rs".to_string());
        }

        // Elixir module definition.
        if trimmed.starts_with("defmodule ") {
            return Some(".ex".to_string());
        }

        // Kotlin: dotted package declaration without a trailing `;` (the
        // `;`-terminated form was already claimed by Java above).
        if trimmed.starts_with("package ") && trimmed.contains('.') && !trimmed.ends_with(';') {
            return Some(".kt".to_string());
        }

        // C#: `using System...;` or a block-style `namespace X {`.
        if trimmed.starts_with("using System") && trimmed.ends_with(';') {
            return Some(".cs".to_string());
        }
        if trimmed.starts_with("namespace ") && trimmed.ends_with('{') {
            return Some(".cs".to_string());
        }

        // Swift: common Apple framework imports.
        if trimmed == "import Foundation"
            || trimmed == "import UIKit"
            || trimmed == "import SwiftUI"
            || trimmed == "import Combine"
        {
            return Some(".swift".to_string());
        }

        // Dart: `dart:` core-library imports.
        if trimmed.starts_with("import 'dart:") || trimmed.starts_with("import \"dart:") {
            return Some(".dart".to_string());
        }

        // Scala: `object` / `trait` definitions.
        // NOTE(review): a Rust `trait Foo` line also matches here.
        if trimmed.starts_with("object ") || trimmed.starts_with("trait ") {
            return Some(".scala".to_string());
        }

        // Zig: the `@import(` builtin anywhere on the line.
        if trimmed.contains("@import(") {
            return Some(".zig".to_string());
        }

        // Terraform top-level blocks.
        if trimmed.starts_with("resource \"")
            || trimmed.starts_with("variable \"")
            || trimmed.starts_with("terraform {")
            || trimmed.starts_with("provider \"")
        {
            return Some(".tf".to_string());
        }

        // Fortran keywords, case-insensitive. `program` and `implicit none`
        // are conclusive; `module`/`subroutine` only count if an
        // `implicit none` line appears within the first 20 lines.
        let lower = trimmed.to_lowercase();
        if lower.starts_with("program ") || lower.starts_with("module ")
            || lower.starts_with("subroutine ") || lower == "implicit none"
        {
            if lower.starts_with("program ") || lower == "implicit none" {
                return Some(".f90".to_string());
            }
            if content.lines().take(20).any(|l| l.trim().to_lowercase() == "implicit none") {
                return Some(".f90".to_string());
            }
        }

        // Python: unindented `def ...:` / `class ...:`. The `line.starts_with`
        // test rejects indented lines, because those begin with whitespace
        // rather than the first character of `trimmed`.
        if (trimmed.starts_with("def ") || trimmed.starts_with("class "))
            && trimmed.ends_with(':')
            && line.starts_with(trimmed.chars().next().unwrap_or(' '))
        {
            return Some(".py".to_string());
        }

        // Ruby: `require '...'`, `require "..."`, or `require_relative`.
        if trimmed.starts_with("require '") || trimmed.starts_with("require \"")
            || trimmed.starts_with("require_relative ")
        {
            return Some(".rb".to_string());
        }

        // Perl: pragmas, or `my` declarations with a sigil.
        if trimmed == "use strict;"
            || trimmed == "use warnings;"
            || trimmed.starts_with("my $")
            || trimmed.starts_with("my @")
            || trimmed.starts_with("my %")
        {
            return Some(".pl".to_string());
        }
    }

    None
}
535
/// Returns the filetype named by a vim modeline in `line`, if any.
///
/// Looks for the first `vim:` marker and scans the options after it. Vim
/// allows those options to be separated by whitespace or `:`, which covers
/// both `vim: set ft=python :` and `vim:ft=python:ts=4`. Returns the value of
/// the first non-empty `ft=` or `filetype=` option, or `None` if there is no
/// such option.
fn extract_vim_filetype(line: &str) -> Option<&str> {
    let (_, options) = line.trim().split_once("vim:")?;
    options
        .split(|c: char| c.is_whitespace() || c == ':')
        .filter_map(|opt| opt.strip_prefix("ft=").or_else(|| opt.strip_prefix("filetype=")))
        .find(|value| !value.is_empty())
}
552
#[cfg(test)]
mod tests {
    use crate::parser::plugins::create_default_registry;

    /// Resolves `path` by file name alone through a fresh default registry
    /// and asserts the id of the plugin it picks.
    fn assert_routes_to(path: &str, expected_id: &str) {
        let registry = create_default_registry();
        let plugin = registry.get_plugin(path).expect("plugin should exist");
        assert_eq!(plugin.id(), expected_id);
    }

    #[test]
    fn test_registry_matches_compound_svelte_typescript_suffix() {
        assert_routes_to("src/routes/+page.svelte.ts", "svelte");
    }

    #[test]
    fn test_registry_matches_compound_svelte_javascript_suffix() {
        assert_routes_to("src/routes/+layout.svelte.js", "svelte");
    }

    #[test]
    fn test_registry_matches_svelte_test_suffix() {
        assert_routes_to("src/lib/multiplier.svelte.test.js", "svelte");
    }

    #[test]
    fn test_registry_prefers_svelte_plugin_for_component_files() {
        assert_routes_to("src/lib/Component.svelte", "svelte");
    }

    #[test]
    fn test_registry_matches_typescript_module_suffix() {
        assert_routes_to("src/lib/index.mts", "code");
    }

    #[test]
    fn test_registry_matches_typescript_commonjs_suffix() {
        assert_routes_to("src/lib/index.cts", "code");
    }

    // Content-based detection: these file names are not expected to carry an
    // extension that any default plugin claims, so the content must decide.

    #[test]
    fn test_detect_php_from_opening_tag() {
        let content = "<?php\nclass Vendor {\n function get_name() { return $this->name; }\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("vendor.inc2", content)
            .expect("should detect PHP")
            .extract_entities(content, "vendor.inc2");
        assert!(entities.iter().any(|e| e.entity_type == "class"));
    }

    #[test]
    fn test_detect_c_from_include() {
        let content = "#include <stdio.h>\n\nint main() {\n printf(\"hello\");\n return 0;\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("main.xyz", content)
            .expect("should detect C")
            .extract_entities(content, "main.xyz");
        assert!(entities.iter().any(|e| e.name == "main"));
    }

    #[test]
    fn test_detect_java_from_package() {
        let content = "package com.example.app;\n\npublic class Main {\n public static void main(String[] args) {}\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("Main", content)
            .expect("should detect Java")
            .extract_entities(content, "Main");
        assert!(entities.iter().any(|e| e.name == "Main"));
    }

    #[test]
    fn test_detect_go_from_package() {
        let content = "package main\n\nimport \"fmt\"\n\nfunc hello() {\n fmt.Println(\"hi\")\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("main", content)
            .expect("should detect Go")
            .extract_entities(content, "main");
        assert!(entities.iter().any(|e| e.name == "hello"));
    }

    #[test]
    fn test_detect_rust_from_use_std() {
        let content = "use std::collections::HashMap;\n\nfn process() {\n let m = HashMap::new();\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("lib", content)
            .expect("should detect Rust")
            .extract_entities(content, "lib");
        assert!(entities.iter().any(|e| e.name == "process"));
    }

    #[test]
    fn test_extension_takes_priority_over_heuristics() {
        // PHP-looking content in a `.py` file must still go to the `.py` plugin.
        let content = "<?php\nclass Foo {}\n";
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin_with_content("script.py", content)
            .expect("should use Python parser");
        assert_eq!(plugin.id(), "code");
    }

    #[test]
    fn test_custom_extension_mapping_extracts_entities() {
        let mut registry = create_default_registry();
        registry.add_extension_mapping(".mypy", "python");

        let content = "def hello():\n print(\"hello world\")\n\nclass Calculator:\n def multiply(self, a, b):\n return a * b\n";
        let entities = registry.extract_entities("utils.mypy", content);
        let found = |name: &str| entities.iter().any(|e| e.name == name);

        assert!(!entities.is_empty(), "Should extract entities via custom mapping");
        assert!(found("hello"), "Should find hello function");
        assert!(found("Calculator"), "Should find Calculator class");
        assert!(found("multiply"), "Should find multiply method");

        // Entities must report the user's path, not the canonical `.py` one.
        entities.iter().for_each(|entity| {
            assert_eq!(entity.file_path, "utils.mypy", "Entity file_path should use original extension");
            assert!(entity.id.starts_with("utils.mypy::"), "Entity ID should use original file path");
        });
    }
}
702}