Skip to main content

sem_core/parser/
registry.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use super::plugin::SemanticParserPlugin;
5
6pub struct ParserRegistry {
7    plugins: Vec<Box<dyn SemanticParserPlugin>>,
8    extension_map: HashMap<String, usize>, // ext → index into plugins
9    custom_ext_canonical: HashMap<String, String>, // ".mypy" → ".py" (custom → canonical)
10}
11
12impl ParserRegistry {
13    pub fn new() -> Self {
14        Self {
15            plugins: Vec::new(),
16            extension_map: HashMap::new(),
17            custom_ext_canonical: HashMap::new(),
18        }
19    }
20
21    pub fn register(&mut self, plugin: Box<dyn SemanticParserPlugin>) {
22        let idx = self.plugins.len();
23        for ext in plugin.extensions() {
24            self.extension_map.insert(ext.to_string(), idx);
25        }
26        self.plugins.push(plugin);
27    }
28
29    pub fn get_plugin(&self, file_path: &str) -> Option<&dyn SemanticParserPlugin> {
30        for ext in get_extensions(file_path) {
31            if let Some(&idx) = self.extension_map.get(&ext) {
32                return Some(self.plugins[idx].as_ref());
33            }
34        }
35        // Fallback plugin
36        self.get_plugin_by_id("fallback")
37    }
38
39    /// Try to detect language from shebang line when extension-based lookup fails.
40    /// Call this as a fallback when file content is available.
41    pub fn get_plugin_with_content(&self, file_path: &str, content: &str) -> Option<&dyn SemanticParserPlugin> {
42        // Try extension first
43        for ext in get_extensions(file_path) {
44            if let Some(&idx) = self.extension_map.get(&ext) {
45                return Some(self.plugins[idx].as_ref());
46            }
47        }
48        // Try shebang detection
49        if let Some(plugin) = self.detect_from_shebang(content) {
50            return Some(plugin);
51        }
52        // Fallback plugin
53        self.get_plugin_by_id("fallback")
54    }
55
56    fn detect_from_shebang(&self, content: &str) -> Option<&dyn SemanticParserPlugin> {
57        if let Some(ext) = detect_ext_from_content(content) {
58            if let Some(&idx) = self.extension_map.get(ext.as_str()) {
59                return Some(self.plugins[idx].as_ref());
60            }
61        }
62        None
63    }
64
65    pub fn get_plugin_by_id(&self, id: &str) -> Option<&dyn SemanticParserPlugin> {
66        self.plugins
67            .iter()
68            .find(|p| p.id() == id)
69            .map(|p| p.as_ref())
70    }
71
72    /// Register a custom extension mapping from a .semrc file.
73    /// Maps an extension (e.g. ".inc") to an existing plugin by language name.
74    pub fn add_extension_mapping(&mut self, ext: &str, language: &str) -> bool {
75        let ext = if ext.starts_with('.') {
76            ext.to_lowercase()
77        } else {
78            format!(".{}", ext.to_lowercase())
79        };
80
81        // Find plugin index by matching language name against known extensions
82        let target_ext = LANG_MAPPING
83            .iter()
84            .find(|(kw, _)| *kw == language.to_lowercase())
85            .map(|(_, e)| *e);
86
87        if let Some(target) = target_ext {
88            if let Some(&idx) = self.extension_map.get(target) {
89                self.custom_ext_canonical.insert(ext.clone(), target.to_string());
90                self.extension_map.insert(ext, idx);
91                return true;
92            }
93        }
94
95        // Also try matching directly against registered extensions
96        let direct_ext = format!(".{}", language.to_lowercase());
97        if let Some(&idx) = self.extension_map.get(&direct_ext) {
98            self.custom_ext_canonical.insert(ext.clone(), direct_ext);
99            self.extension_map.insert(ext, idx);
100            return true;
101        }
102
103        false
104    }
105
106    /// Load extension mappings from a .semrc file at the given root directory.
107    /// File format (one mapping per line): `.ext = language`
108    /// Example:
109    ///   .inc = php
110    ///   .j = json
111    ///   .xyz = cpp
112    pub fn load_semrc(&mut self, root: &Path) {
113        let semrc_path = root.join(".semrc");
114        if !semrc_path.exists() {
115            return;
116        }
117        let content = match std::fs::read_to_string(&semrc_path) {
118            Ok(c) => c,
119            Err(_) => return,
120        };
121        for line in content.lines() {
122            let line = line.trim();
123            if line.is_empty() || line.starts_with('#') {
124                continue;
125            }
126            if let Some((ext, lang)) = line.split_once('=') {
127                self.add_extension_mapping(ext.trim(), lang.trim());
128            }
129        }
130    }
131
132    /// Load extension mappings from `.gitattributes` at the given root directory.
133    /// Parses `*.ext diff=language` and `*.ext linguist-language=Language` patterns.
134    /// Only processes `*.ext` glob patterns (not path-based patterns).
135    pub fn load_gitattributes(&mut self, root: &Path) {
136        let ga_path = root.join(".gitattributes");
137        if !ga_path.exists() {
138            return;
139        }
140        let content = match std::fs::read_to_string(&ga_path) {
141            Ok(c) => c,
142            Err(_) => return,
143        };
144        for line in content.lines() {
145            let line = line.trim();
146            if line.is_empty() || line.starts_with('#') {
147                continue;
148            }
149            let mut parts = line.split_whitespace();
150            let pattern = match parts.next() {
151                Some(p) => p,
152                None => continue,
153            };
154            // Only handle *.ext patterns
155            let ext = match pattern.strip_prefix("*.") {
156                Some(e) => e,
157                None => continue,
158            };
159            // Already mapped (e.g. by .semrc which takes priority)
160            let ext_key = format!(".{}", ext.to_lowercase());
161            if self.custom_ext_canonical.contains_key(&ext_key) {
162                continue;
163            }
164            // Look for diff= or linguist-language= attributes
165            for attr in parts {
166                if let Some(lang) = attr.strip_prefix("diff=") {
167                    self.add_extension_mapping(ext, lang);
168                    break;
169                }
170                if let Some(lang) = attr.strip_prefix("linguist-language=") {
171                    self.add_extension_mapping(ext, lang);
172                    break;
173                }
174            }
175        }
176    }
177
178    /// Resolve custom extension mappings in a file path.
179    /// E.g. if `.mypy` is mapped to `python` (canonical `.py`),
180    /// `"utils.mypy"` becomes `"utils.py"`.
181    pub fn resolve_file_path(&self, file_path: &str) -> Option<String> {
182        let path = Path::new(file_path);
183        let ext = path
184            .extension()
185            .and_then(|e| e.to_str())
186            .map(|e| format!(".{}", e.to_lowercase()))?;
187
188        let canonical = self.custom_ext_canonical.get(&ext)?;
189        let stem = path.file_stem().and_then(|s| s.to_str())?;
190
191        if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
192            Some(format!("{}/{}{}", parent.display(), stem, canonical))
193        } else {
194            Some(format!("{}{}", stem, canonical))
195        }
196    }
197
198    /// Extract entities, transparently handling custom extension mappings.
199    /// Uses the resolved path for language detection but restores the original
200    /// file path in entity metadata (file_path, id, parent_id).
201    pub fn extract_entities(&self, file_path: &str, content: &str) -> Vec<crate::model::entity::SemanticEntity> {
202        let resolved = self.resolve_file_path(file_path);
203        let detection_path = resolved.as_deref().unwrap_or(file_path);
204
205        let plugin = match self.get_plugin_with_content(detection_path, content) {
206            Some(p) => p,
207            None => return Vec::new(),
208        };
209
210        let mut entities = plugin.extract_entities(content, detection_path);
211        if let Some(ref rp) = resolved {
212            fix_entity_paths(&mut entities, file_path, rp);
213        }
214        entities
215    }
216
217    /// Extract entities with tree, transparently handling custom extension mappings.
218    pub fn extract_entities_with_tree(
219        &self,
220        file_path: &str,
221        content: &str,
222    ) -> Option<(Vec<crate::model::entity::SemanticEntity>, Option<tree_sitter::Tree>)> {
223        let resolved = self.resolve_file_path(file_path);
224        let detection_path = resolved.as_deref().unwrap_or(file_path);
225
226        let plugin = self.get_plugin_with_content(detection_path, content)?;
227        let (mut entities, tree) = plugin.extract_entities_with_tree(content, detection_path);
228        if let Some(ref rp) = resolved {
229            fix_entity_paths(&mut entities, file_path, rp);
230        }
231        Some((entities, tree))
232    }
233}
234
235/// Restore original file path in entities when a custom extension mapping was used.
236fn fix_entity_paths(entities: &mut [crate::model::entity::SemanticEntity], original: &str, resolved: &str) {
237    for entity in entities {
238        entity.file_path = original.to_string();
239        entity.id = entity.id.replace(resolved, original);
240        if let Some(ref mut pid) = entity.parent_id {
241            *pid = pid.replace(resolved, original);
242        }
243    }
244}
245
/// Every dotted suffix of the lowercased file name, longest first.
///
/// `"src/+page.Svelte.TS"` yields `[".svelte.ts", ".ts"]`. A dotfile such as
/// `".bashrc"` yields itself. A name without a dot, or a path with no final
/// file-name component, yields an empty list.
fn get_extensions(file_path: &str) -> Vec<String> {
    match Path::new(file_path).file_name().and_then(|n| n.to_str()) {
        None => Vec::new(),
        Some(name) => {
            let lowered = name.to_lowercase();
            lowered
                .match_indices('.')
                .map(|(start, _)| lowered[start..].to_owned())
                .collect()
        }
    }
}
265
/// Language keyword → canonical file extension (with leading dot).
///
/// Used to resolve `.semrc` / `.gitattributes` language names, shebang
/// interpreters, and vim `ft=` modelines. Several keywords may alias one
/// extension. Lookups scan the table front to back and stop at the first
/// match, so entry order can matter (e.g. `"javascript"` precedes `"java"`).
const LANG_MAPPING: &[(&str, &str)] = &[
    // Scripting languages and shells.
    ("perl", ".pl"),
    ("python", ".py"),
    ("ruby", ".rb"),
    ("bash", ".sh"),
    ("shell", ".sh"),
    // The plain POSIX `sh` interpreter, as in `#!/bin/sh`.
    ("/sh", ".sh"),
    // JavaScript family.
    ("node", ".js"),
    ("javascript", ".js"),
    ("typescript", ".ts"),
    ("tsx", ".tsx"),
    // Compiled / application languages.
    ("swift", ".swift"),
    ("elixir", ".ex"),
    ("rust", ".rs"),
    ("go", ".go"),
    ("golang", ".go"),
    ("kotlin", ".kt"),
    ("dart", ".dart"),
    ("php", ".php"),
    ("java", ".java"),
    // C family.
    ("c", ".c"),
    ("cpp", ".cpp"),
    ("c++", ".cpp"),
    ("cs", ".cs"),
    ("csharp", ".cs"),
    ("c#", ".cs"),
    // Other programming languages.
    ("fortran", ".f90"),
    ("terraform", ".tf"),
    ("hcl", ".hcl"),
    ("ocaml", ".ml"),
    ("scala", ".scala"),
    ("zig", ".zig"),
    // Data and markup formats.
    ("xml", ".xml"),
    ("json", ".json"),
    ("yaml", ".yaml"),
    ("yml", ".yaml"),
    ("toml", ".toml"),
    ("markdown", ".md"),
    ("csv", ".csv"),
    // Templates and component formats.
    ("eruby", ".erb"),
    ("erb", ".erb"),
    ("vue", ".vue"),
    ("svelte", ".svelte"),
];
310
311/// Detect file extension from shebang line, vim modeline, or content heuristics.
312pub fn detect_ext_from_content(content: &str) -> Option<String> {
313    // Try shebang (first line)
314    if let Some(first_line) = content.lines().next() {
315        if first_line.starts_with("#!") {
316            let shebang = first_line.to_lowercase();
317            for (keyword, ext) in LANG_MAPPING {
318                if shebang.contains(keyword) {
319                    return Some(ext.to_string());
320                }
321            }
322        }
323    }
324
325    // Try vim modeline (first 5 or last 5 lines)
326    // Formats: `vim: ft=perl`, `vim: filetype=perl`, `vim: set ft=perl`
327    let lines: Vec<&str> = content.lines().collect();
328    let check_lines = lines.iter().take(5).chain(lines.iter().rev().take(5));
329    for line in check_lines {
330        if let Some(ft) = extract_vim_filetype(line) {
331            let ft_lower = ft.to_lowercase();
332            for (keyword, ext) in LANG_MAPPING {
333                if ft_lower == *keyword {
334                    return Some(ext.to_string());
335                }
336            }
337        }
338    }
339
340    // Try content heuristics (first-line markers and early declarations)
341    if let Some(ext) = detect_from_content_heuristics(content) {
342        return Some(ext);
343    }
344
345    None
346}
347
/// High-confidence content-based language detection.
/// Only uses markers with near-zero false-positive rates.
///
/// First checks the first line for explicit markers (`<?php`, `<?xml`,
/// `<!DOCTYPE`). Then scans the first 20 lines for language-specific
/// declarations. Checks run in a fixed order and the first hit wins, so the
/// order below is significant. A leading UTF-8 byte-order mark is ignored.
fn detect_from_content_heuristics(content: &str) -> Option<String> {
    // Windows editors often prepend a BOM. U+FEFF is not whitespace, so
    // `trim()` keeps it and it would defeat every `starts_with` below.
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);
    let first_line = content.lines().next().unwrap_or("").trim();

    // PHP: opening tag is unambiguous
    if first_line.starts_with("<?php") || first_line.starts_with("<?PHP") {
        return Some(".php".to_string());
    }

    // XML/SVG/HTML: XML declaration, or a doctype in any letter case
    // (`<!DOCTYPE html>`, `<!doctype html>`, `<!Doctype html>`).
    let has_doctype =
        matches!(first_line.get(..9), Some(prefix) if prefix.eq_ignore_ascii_case("<!doctype"));
    if first_line.starts_with("<?xml") || has_doctype {
        return Some(".xml".to_string());
    }

    // Scan first ~20 lines for language-specific patterns
    for line in content.lines().take(20) {
        let trimmed = line.trim();

        // PHP: opening tag, or short echo tag (`<?= $title ?>`), in early lines
        if trimmed.starts_with("<?php") || trimmed.starts_with("<?PHP") || trimmed.starts_with("<?=") {
            return Some(".php".to_string());
        }

        // C/C++: #include directive
        if trimmed.starts_with("#include ") || trimmed.starts_with("#include\t") {
            // Could be C or C++. Check for C++ indicators.
            if content.lines().take(30).any(|l| {
                let t = l.trim();
                t.starts_with("using namespace")
                    || t.starts_with("class ")
                    || t.starts_with("#include <iostream")
                    || t.starts_with("#include <vector")
                    || t.starts_with("#include <string>")
                    || t.starts_with("#include <memory")
            }) {
                return Some(".cpp".to_string());
            }
            return Some(".c".to_string());
        }

        // Java: package declaration with dots
        if trimmed.starts_with("package ") && trimmed.contains('.') && trimmed.ends_with(';') {
            return Some(".java".to_string());
        }

        // Go: package declaration without dots or semicolons
        if trimmed.starts_with("package ") && !trimmed.contains('.') && !trimmed.contains(';') {
            return Some(".go".to_string());
        }

        // Rust: common top-level declarations
        if (trimmed.starts_with("use std::") || trimmed.starts_with("use crate::"))
            && trimmed.ends_with(';')
        {
            return Some(".rs".to_string());
        }

        // Elixir: defmodule
        if trimmed.starts_with("defmodule ") {
            return Some(".ex".to_string());
        }

        // Kotlin: package with dots but no semicolon (Kotlin doesn't require semicolons)
        if trimmed.starts_with("package ") && trimmed.contains('.') && !trimmed.ends_with(';') {
            return Some(".kt".to_string());
        }

        // C#: using System or namespace with braces
        if trimmed.starts_with("using System") && trimmed.ends_with(';') {
            return Some(".cs".to_string());
        }
        if trimmed.starts_with("namespace ") && trimmed.ends_with('{') {
            // Could be C++ too, but C++ usually has #include before namespace
            // If we got here without matching #include, it's likely C#
            return Some(".cs".to_string());
        }

        // Swift: import Foundation/UIKit/SwiftUI/Combine
        if trimmed == "import Foundation"
            || trimmed == "import UIKit"
            || trimmed == "import SwiftUI"
            || trimmed == "import Combine"
        {
            return Some(".swift".to_string());
        }

        // Dart: import 'dart:
        if trimmed.starts_with("import 'dart:") || trimmed.starts_with("import \"dart:") {
            return Some(".dart".to_string());
        }

        // Scala: object/trait at top level
        if trimmed.starts_with("object ") || trimmed.starts_with("trait ") {
            return Some(".scala".to_string());
        }

        // Zig: const std = @import
        if trimmed.contains("@import(") {
            return Some(".zig".to_string());
        }

        // HCL/Terraform: resource/variable/terraform/provider blocks
        if trimmed.starts_with("resource \"")
            || trimmed.starts_with("variable \"")
            || trimmed.starts_with("terraform {")
            || trimmed.starts_with("provider \"")
        {
            return Some(".tf".to_string());
        }

        // Fortran: program/module/subroutine (case-insensitive)
        let lower = trimmed.to_lowercase();
        if lower.starts_with("program ") || lower.starts_with("module ")
            || lower.starts_with("subroutine ") || lower == "implicit none"
        {
            // "module " could be Ruby, but Ruby uses "module X" without "implicit none"
            // Check for Fortran-specific follow-up
            if lower.starts_with("program ") || lower == "implicit none" {
                return Some(".f90".to_string());
            }
            if content.lines().take(20).any(|l| l.trim().to_lowercase() == "implicit none") {
                return Some(".f90".to_string());
            }
        }

        // Python: def/class at indentation level 0 ending with a colon
        if (trimmed.starts_with("def ") || trimmed.starts_with("class "))
            && trimmed.ends_with(':')
            && !line.starts_with(char::is_whitespace)
        {
            return Some(".py".to_string());
        }

        // Ruby: require / require_relative
        if trimmed.starts_with("require '") || trimmed.starts_with("require \"")
            || trimmed.starts_with("require_relative ")
        {
            return Some(".rb".to_string());
        }

        // Perl: use strict/warnings, or variable declarations with sigils
        if trimmed == "use strict;"
            || trimmed == "use warnings;"
            || trimmed.starts_with("my $")
            || trimmed.starts_with("my @")
            || trimmed.starts_with("my %")
        {
            return Some(".pl".to_string());
        }
    }

    None
}
505
/// Extract the filetype value from a vim modeline, if `line` contains one.
///
/// Recognises `vim: ft=X`, `vim: filetype=X` and `vim: set ft=X:`. Vim lets
/// options be separated by whitespace or by `:` (e.g. `vim: ft=perl:ts=4`), so
/// both count as delimiters here. Empty values (`ft=`) are skipped. The value
/// is returned as written; callers handle case.
fn extract_vim_filetype(line: &str) -> Option<&str> {
    let (_, options) = line.split_once("vim:")?;
    options
        .split(|c: char| c.is_whitespace() || c == ':')
        .filter_map(|opt| opt.strip_prefix("ft=").or_else(|| opt.strip_prefix("filetype=")))
        .find(|value| !value.is_empty())
}
522
/// Registry routing tests: extension lookup, content-based detection, and
/// custom extension mappings, all against the default plugin set.
#[cfg(test)]
mod tests {
    use crate::parser::plugins::create_default_registry;

    /// Assert that `path` routes, by extension alone, to the plugin `expected_id`.
    fn assert_plugin_id(path: &str, expected_id: &str) {
        let registry = create_default_registry();
        let plugin = registry.get_plugin(path).expect("plugin should exist");
        assert_eq!(plugin.id(), expected_id);
    }

    /// Pick a plugin for `path` using `content` (the extension is unknown),
    /// then extract entities with it. `expectation` is the panic message used
    /// if no plugin is found.
    fn entities_detected_from_content(
        path: &str,
        content: &str,
        expectation: &str,
    ) -> Vec<crate::model::entity::SemanticEntity> {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin_with_content(path, content)
            .expect(expectation);
        plugin.extract_entities(content, path)
    }

    #[test]
    fn test_registry_matches_compound_svelte_typescript_suffix() {
        assert_plugin_id("src/routes/+page.svelte.ts", "svelte");
    }

    #[test]
    fn test_registry_matches_compound_svelte_javascript_suffix() {
        assert_plugin_id("src/routes/+layout.svelte.js", "svelte");
    }

    #[test]
    fn test_registry_matches_svelte_test_suffix() {
        assert_plugin_id("src/lib/multiplier.svelte.test.js", "svelte");
    }

    #[test]
    fn test_registry_prefers_svelte_plugin_for_component_files() {
        assert_plugin_id("src/lib/Component.svelte", "svelte");
    }

    #[test]
    fn test_registry_matches_typescript_module_suffix() {
        assert_plugin_id("src/lib/index.mts", "code");
    }

    #[test]
    fn test_registry_matches_typescript_commonjs_suffix() {
        assert_plugin_id("src/lib/index.cts", "code");
    }

    #[test]
    fn test_detect_php_from_opening_tag() {
        let content = "<?php\nclass Vendor {\n    function get_name() { return $this->name; }\n}\n";
        let entities = entities_detected_from_content("vendor.inc2", content, "should detect PHP");
        assert!(entities.iter().any(|e| e.entity_type == "class"));
    }

    #[test]
    fn test_detect_c_from_include() {
        let content = "#include <stdio.h>\n\nint main() {\n    printf(\"hello\");\n    return 0;\n}\n";
        let entities = entities_detected_from_content("main.xyz", content, "should detect C");
        assert!(entities.iter().any(|e| e.name == "main"));
    }

    #[test]
    fn test_detect_java_from_package() {
        let content = "package com.example.app;\n\npublic class Main {\n    public static void main(String[] args) {}\n}\n";
        let entities = entities_detected_from_content("Main", content, "should detect Java");
        assert!(entities.iter().any(|e| e.name == "Main"));
    }

    #[test]
    fn test_detect_go_from_package() {
        let content = "package main\n\nimport \"fmt\"\n\nfunc hello() {\n    fmt.Println(\"hi\")\n}\n";
        let entities = entities_detected_from_content("main", content, "should detect Go");
        assert!(entities.iter().any(|e| e.name == "hello"));
    }

    #[test]
    fn test_detect_rust_from_use_std() {
        let content = "use std::collections::HashMap;\n\nfn process() {\n    let m = HashMap::new();\n}\n";
        let entities = entities_detected_from_content("lib", content, "should detect Rust");
        assert!(entities.iter().any(|e| e.name == "process"));
    }

    #[test]
    fn test_extension_takes_priority_over_heuristics() {
        // Content looks like PHP but file has .py extension
        let content = "<?php\nclass Foo {}\n";
        let registry = create_default_registry();
        let plugin_id = registry
            .get_plugin_with_content("script.py", content)
            .expect("should use Python parser")
            .id();
        assert_eq!(plugin_id, "code"); // Python uses code plugin, not PHP
    }

    #[test]
    fn test_custom_extension_mapping_extracts_entities() {
        let mut registry = create_default_registry();
        registry.add_extension_mapping(".mypy", "python");

        let content = "def hello():\n    print(\"hello world\")\n\nclass Calculator:\n    def multiply(self, a, b):\n        return a * b\n";
        let entities = registry.extract_entities("utils.mypy", content);
        let has_name = |name: &str| entities.iter().any(|e| e.name == name);

        assert!(!entities.is_empty(), "Should extract entities via custom mapping");
        assert!(has_name("hello"), "Should find hello function");
        assert!(has_name("Calculator"), "Should find Calculator class");
        assert!(has_name("multiply"), "Should find multiply method");

        // File path should preserve the original extension
        entities.iter().for_each(|entity| {
            assert_eq!(entity.file_path, "utils.mypy", "Entity file_path should use original extension");
            assert!(entity.id.starts_with("utils.mypy::"), "Entity ID should use original file path");
        });
    }
}