Skip to main content

sem_core/parser/
registry.rs

1use std::collections::HashMap;
2use std::path::Path;
3use rayon::prelude::*;
4
5use crate::model::entity::SemanticEntity;
6use super::plugin::SemanticParserPlugin;
7
8pub struct ParserRegistry {
9    plugins: Vec<Box<dyn SemanticParserPlugin>>,
10    extension_map: HashMap<String, usize>, // ext → index into plugins
11    custom_ext_canonical: HashMap<String, String>, // ".mypy" → ".py" (custom → canonical)
12}
13
14impl ParserRegistry {
15    pub fn new() -> Self {
16        Self {
17            plugins: Vec::new(),
18            extension_map: HashMap::new(),
19            custom_ext_canonical: HashMap::new(),
20        }
21    }
22
23    pub fn register(&mut self, plugin: Box<dyn SemanticParserPlugin>) {
24        let idx = self.plugins.len();
25        for ext in plugin.extensions() {
26            self.extension_map.insert(ext.to_string(), idx);
27        }
28        self.plugins.push(plugin);
29    }
30
31    pub fn get_plugin(&self, file_path: &str) -> Option<&dyn SemanticParserPlugin> {
32        for ext in get_extensions(file_path) {
33            if let Some(&idx) = self.extension_map.get(&ext) {
34                return Some(self.plugins[idx].as_ref());
35            }
36        }
37        // Fallback plugin
38        self.get_plugin_by_id("fallback")
39    }
40
41    /// Try to detect language from shebang line when extension-based lookup fails.
42    /// Call this as a fallback when file content is available.
43    pub fn get_plugin_with_content(&self, file_path: &str, content: &str) -> Option<&dyn SemanticParserPlugin> {
44        // Try extension first
45        for ext in get_extensions(file_path) {
46            if let Some(&idx) = self.extension_map.get(&ext) {
47                return Some(self.plugins[idx].as_ref());
48            }
49        }
50        // Try shebang detection
51        if let Some(plugin) = self.detect_from_shebang(content) {
52            return Some(plugin);
53        }
54        // Fallback plugin
55        self.get_plugin_by_id("fallback")
56    }
57
58    fn detect_from_shebang(&self, content: &str) -> Option<&dyn SemanticParserPlugin> {
59        if let Some(ext) = detect_ext_from_content(content) {
60            if let Some(&idx) = self.extension_map.get(ext.as_str()) {
61                return Some(self.plugins[idx].as_ref());
62            }
63        }
64        None
65    }
66
67    pub fn get_plugin_by_id(&self, id: &str) -> Option<&dyn SemanticParserPlugin> {
68        self.plugins
69            .iter()
70            .find(|p| p.id() == id)
71            .map(|p| p.as_ref())
72    }
73
74    /// Register a custom extension mapping from a .semrc file.
75    /// Maps an extension (e.g. ".inc") to an existing plugin by language name.
76    pub fn add_extension_mapping(&mut self, ext: &str, language: &str) -> bool {
77        let ext = if ext.starts_with('.') {
78            ext.to_lowercase()
79        } else {
80            format!(".{}", ext.to_lowercase())
81        };
82
83        // Find plugin index by matching language name against known extensions
84        let target_ext = LANG_MAPPING
85            .iter()
86            .find(|(kw, _)| *kw == language.to_lowercase())
87            .map(|(_, e)| *e);
88
89        if let Some(target) = target_ext {
90            if let Some(&idx) = self.extension_map.get(target) {
91                self.custom_ext_canonical.insert(ext.clone(), target.to_string());
92                self.extension_map.insert(ext, idx);
93                return true;
94            }
95        }
96
97        // Also try matching directly against registered extensions
98        let direct_ext = format!(".{}", language.to_lowercase());
99        if let Some(&idx) = self.extension_map.get(&direct_ext) {
100            self.custom_ext_canonical.insert(ext.clone(), direct_ext);
101            self.extension_map.insert(ext, idx);
102            return true;
103        }
104
105        false
106    }
107
108    /// Load extension mappings from a .semrc file at the given root directory.
109    /// File format (one mapping per line): `.ext = language`
110    /// Example:
111    ///   .inc = php
112    ///   .j = json
113    ///   .xyz = cpp
114    pub fn load_semrc(&mut self, root: &Path) {
115        let semrc_path = root.join(".semrc");
116        if !semrc_path.exists() {
117            return;
118        }
119        let content = match std::fs::read_to_string(&semrc_path) {
120            Ok(c) => c,
121            Err(_) => return,
122        };
123        for line in content.lines() {
124            let line = line.trim();
125            if line.is_empty() || line.starts_with('#') {
126                continue;
127            }
128            if let Some((ext, lang)) = line.split_once('=') {
129                self.add_extension_mapping(ext.trim(), lang.trim());
130            }
131        }
132    }
133
134    /// Load extension mappings from `.gitattributes` at the given root directory.
135    /// Parses `*.ext diff=language` and `*.ext linguist-language=Language` patterns.
136    /// Only processes `*.ext` glob patterns (not path-based patterns).
137    pub fn load_gitattributes(&mut self, root: &Path) {
138        let ga_path = root.join(".gitattributes");
139        if !ga_path.exists() {
140            return;
141        }
142        let content = match std::fs::read_to_string(&ga_path) {
143            Ok(c) => c,
144            Err(_) => return,
145        };
146        for line in content.lines() {
147            let line = line.trim();
148            if line.is_empty() || line.starts_with('#') {
149                continue;
150            }
151            let mut parts = line.split_whitespace();
152            let pattern = match parts.next() {
153                Some(p) => p,
154                None => continue,
155            };
156            // Only handle *.ext patterns
157            let ext = match pattern.strip_prefix("*.") {
158                Some(e) => e,
159                None => continue,
160            };
161            // Already mapped (e.g. by .semrc which takes priority)
162            let ext_key = format!(".{}", ext.to_lowercase());
163            if self.custom_ext_canonical.contains_key(&ext_key) {
164                continue;
165            }
166            // Look for diff= or linguist-language= attributes
167            for attr in parts {
168                if let Some(lang) = attr.strip_prefix("diff=") {
169                    self.add_extension_mapping(ext, lang);
170                    break;
171                }
172                if let Some(lang) = attr.strip_prefix("linguist-language=") {
173                    self.add_extension_mapping(ext, lang);
174                    break;
175                }
176            }
177        }
178    }
179
180    /// Resolve custom extension mappings in a file path.
181    /// E.g. if `.mypy` is mapped to `python` (canonical `.py`),
182    /// `"utils.mypy"` becomes `"utils.py"`.
183    pub fn resolve_file_path(&self, file_path: &str) -> Option<String> {
184        let path = Path::new(file_path);
185        let ext = path
186            .extension()
187            .and_then(|e| e.to_str())
188            .map(|e| format!(".{}", e.to_lowercase()))?;
189
190        let canonical = self.custom_ext_canonical.get(&ext)?;
191        let stem = path.file_stem().and_then(|s| s.to_str())?;
192
193        if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
194            Some(format!("{}/{}{}", parent.display(), stem, canonical))
195        } else {
196            Some(format!("{}{}", stem, canonical))
197        }
198    }
199
200    /// Extract entities, transparently handling custom extension mappings.
201    /// Uses the resolved path for language detection but restores the original
202    /// file path in entity metadata (file_path, id, parent_id).
203    pub fn extract_entities(&self, file_path: &str, content: &str) -> Vec<SemanticEntity> {
204        let resolved = self.resolve_file_path(file_path);
205        let detection_path = resolved.as_deref().unwrap_or(file_path);
206
207        let plugin = match self.get_plugin_with_content(detection_path, content) {
208            Some(p) => p,
209            None => return Vec::new(),
210        };
211
212        let mut entities = plugin.extract_entities(content, detection_path);
213        if let Some(ref rp) = resolved {
214            fix_entity_paths(&mut entities, file_path, rp);
215        }
216        entities
217    }
218
219    /// Extract entities with tree, transparently handling custom extension mappings.
220    pub fn extract_entities_with_tree(
221        &self,
222        file_path: &str,
223        content: &str,
224    ) -> Option<(Vec<SemanticEntity>, Option<tree_sitter::Tree>)> {
225        let resolved = self.resolve_file_path(file_path);
226        let detection_path = resolved.as_deref().unwrap_or(file_path);
227
228        let plugin = self.get_plugin_with_content(detection_path, content)?;
229        let (mut entities, tree) = plugin.extract_entities_with_tree(content, detection_path);
230        if let Some(ref rp) = resolved {
231            fix_entity_paths(&mut entities, file_path, rp);
232        }
233        Some((entities, tree))
234    }
235
236    /// Extract entities from multiple files in parallel.
237    pub fn extract_all_entities(
238        &self,
239        root: &Path,
240        file_paths: &[String],
241    ) -> Vec<SemanticEntity> {
242        file_paths
243            .par_iter()
244            .flat_map(|fp| {
245                let full = root.join(fp);
246                let content = match std::fs::read_to_string(&full) {
247                    Ok(c) => c,
248                    Err(_) => return Vec::new(),
249                };
250                self.extract_entities(fp, &content)
251            })
252            .collect()
253    }
254}
255
256/// Restore original file path in entities when a custom extension mapping was used.
257fn fix_entity_paths(entities: &mut [SemanticEntity], original: &str, resolved: &str) {
258    for entity in entities {
259        entity.file_path = original.to_string();
260        entity.id = entity.id.replace(resolved, original);
261        if let Some(ref mut pid) = entity.parent_id {
262            *pid = pid.replace(resolved, original);
263        }
264    }
265}
266
/// Lists every dot-suffix of the file name in `file_path`, longest first, lower-cased.
///
/// `"src/+page.svelte.ts"` yields `[".svelte.ts", ".ts"]`. A path without a UTF-8 file
/// name, or a file name without any dot, yields an empty vector.
fn get_extensions(file_path: &str) -> Vec<String> {
    let lowered = match Path::new(file_path).file_name().and_then(|name| name.to_str()) {
        Some(name) => name.to_lowercase(),
        None => return Vec::new(),
    };

    // Every '.' starts a candidate suffix that runs to the end of the name.
    lowered
        .match_indices('.')
        .map(|(dot, _)| lowered[dot..].to_string())
        .collect()
}
286
/// Language keyword → canonical file extension.
///
/// Consulted for language names (custom mappings, vim filetypes) and for shebang lines.
/// Entry order matters: lookups take the first entry that matches.
const LANG_MAPPING: &[(&str, &str)] = &[
    // Scripting languages and shells
    ("perl", ".pl"),
    ("python", ".py"),
    ("ruby", ".rb"),
    ("bash", ".sh"),
    ("shell", ".sh"),
    ("/sh", ".sh"),
    // JavaScript family
    ("node", ".js"),
    ("javascript", ".js"),
    ("typescript", ".ts"),
    ("tsx", ".tsx"),
    // Compiled and application languages
    ("swift", ".swift"),
    ("elixir", ".ex"),
    ("rust", ".rs"),
    ("go", ".go"),
    ("golang", ".go"),
    ("kotlin", ".kt"),
    ("dart", ".dart"),
    ("php", ".php"),
    ("java", ".java"),
    ("c", ".c"),
    ("cpp", ".cpp"),
    ("c++", ".cpp"),
    ("cs", ".cs"),
    ("csharp", ".cs"),
    ("c#", ".cs"),
    ("fortran", ".f90"),
    ("terraform", ".tf"),
    ("hcl", ".hcl"),
    ("ocaml", ".ml"),
    ("scala", ".scala"),
    ("zig", ".zig"),
    // Data and markup formats
    ("xml", ".xml"),
    ("json", ".json"),
    ("yaml", ".yaml"),
    ("yml", ".yaml"),
    ("toml", ".toml"),
    ("markdown", ".md"),
    ("csv", ".csv"),
    // Templates and component formats
    ("eruby", ".erb"),
    ("erb", ".erb"),
    ("vue", ".vue"),
    ("svelte", ".svelte"),
];
331
332/// Detect file extension from shebang line, vim modeline, or content heuristics.
333pub fn detect_ext_from_content(content: &str) -> Option<String> {
334    // Try shebang (first line)
335    if let Some(first_line) = content.lines().next() {
336        if first_line.starts_with("#!") {
337            let shebang = first_line.to_lowercase();
338            for (keyword, ext) in LANG_MAPPING {
339                if shebang.contains(keyword) {
340                    return Some(ext.to_string());
341                }
342            }
343        }
344    }
345
346    // Try vim modeline (first 5 or last 5 lines)
347    // Formats: `vim: ft=perl`, `vim: filetype=perl`, `vim: set ft=perl`
348    let lines: Vec<&str> = content.lines().collect();
349    let check_lines = lines.iter().take(5).chain(lines.iter().rev().take(5));
350    for line in check_lines {
351        if let Some(ft) = extract_vim_filetype(line) {
352            let ft_lower = ft.to_lowercase();
353            for (keyword, ext) in LANG_MAPPING {
354                if ft_lower == *keyword {
355                    return Some(ext.to_string());
356                }
357            }
358        }
359    }
360
361    // Try content heuristics (first-line markers and early declarations)
362    if let Some(ext) = detect_from_content_heuristics(content) {
363        return Some(ext);
364    }
365
366    None
367}
368
/// Guesses a canonical extension from distinctive markers in the source text.
///
/// The first line is checked for PHP / XML / doctype openers. After that each of the first
/// 20 lines is trimmed and tested against per-language patterns in a fixed order; the
/// first pattern that matches decides the result. Returns `None` when nothing matches.
fn detect_from_content_heuristics(content: &str) -> Option<String> {
    /// True when `text` begins with at least one of `prefixes`.
    fn starts_with_any(text: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|prefix| text.starts_with(*prefix))
    }

    const PHP_OPENERS: &[&str] = &["<?php", "<?PHP"];
    const CPP_MARKERS: &[&str] = &[
        "using namespace",
        "class ",
        "#include <iostream",
        "#include <vector",
        "#include <string>",
        "#include <memory",
    ];

    let opening = content.lines().next().unwrap_or("").trim();

    // PHP opening tag, XML declaration or doctype on the very first line.
    if starts_with_any(opening, PHP_OPENERS) {
        return Some(".php".to_string());
    }
    if starts_with_any(opening, &["<?xml", "<!DOCTYPE", "<!doctype"]) {
        return Some(".xml".to_string());
    }

    for line in content.lines().take(20) {
        let text = line.trim();

        // PHP: opening tag anywhere in the early lines.
        if starts_with_any(text, PHP_OPENERS) || text == "<?=" {
            return Some(".php".to_string());
        }

        // C / C++: an #include decides between the two by looking for C++ markers
        // in the first 30 lines.
        if starts_with_any(text, &["#include ", "#include\t"]) {
            let looks_like_cpp = content
                .lines()
                .take(30)
                .map(str::trim)
                .any(|candidate| starts_with_any(candidate, CPP_MARKERS));
            let ext = if looks_like_cpp { ".cpp" } else { ".c" };
            return Some(ext.to_string());
        }

        // `package` declarations: Java (dotted, `;`), Go (no dot, no `;`), Kotlin (dotted,
        // no trailing `;`). Anything else falls through to the remaining checks.
        if text.starts_with("package ") {
            let dotted = text.contains('.');
            if dotted && text.ends_with(';') {
                return Some(".java".to_string());
            }
            if !dotted && !text.contains(';') {
                return Some(".go".to_string());
            }
            if dotted {
                return Some(".kt".to_string());
            }
        }

        // Rust: `use std::…;` / `use crate::…;`
        if starts_with_any(text, &["use std::", "use crate::"]) && text.ends_with(';') {
            return Some(".rs".to_string());
        }

        // Elixir: defmodule
        if text.starts_with("defmodule ") {
            return Some(".ex".to_string());
        }

        // C#: `using System…;` or a namespace opening a brace (no #include was seen first).
        let csharp_using = text.starts_with("using System") && text.ends_with(';');
        let csharp_namespace = text.starts_with("namespace ") && text.ends_with('{');
        if csharp_using || csharp_namespace {
            return Some(".cs".to_string());
        }

        // Swift: well-known framework imports.
        if matches!(
            text,
            "import Foundation" | "import UIKit" | "import SwiftUI" | "import Combine"
        ) {
            return Some(".swift".to_string());
        }

        // Dart: import of a `dart:` library.
        if starts_with_any(text, &["import 'dart:", "import \"dart:"]) {
            return Some(".dart".to_string());
        }

        // Scala: object/trait declarations.
        if starts_with_any(text, &["object ", "trait "]) {
            return Some(".scala".to_string());
        }

        // Zig: `@import(` builtin.
        if text.contains("@import(") {
            return Some(".zig".to_string());
        }

        // HCL / Terraform blocks.
        if starts_with_any(
            text,
            &["resource \"", "variable \"", "terraform {", "provider \""],
        ) {
            return Some(".tf".to_string());
        }

        // Fortran (case-insensitive): `program` and `implicit none` decide on their own;
        // `module` / `subroutine` need an `implicit none` within the first 20 lines.
        let lowered = text.to_lowercase();
        if lowered.starts_with("program ") || lowered == "implicit none" {
            return Some(".f90".to_string());
        }
        if starts_with_any(&lowered, &["module ", "subroutine "])
            && content
                .lines()
                .take(20)
                .any(|candidate| candidate.trim().to_lowercase() == "implicit none")
        {
            return Some(".f90".to_string());
        }

        // Python: unindented def/class header ending in a colon.
        if starts_with_any(text, &["def ", "class "])
            && text.ends_with(':')
            && !line.starts_with(char::is_whitespace)
        {
            return Some(".py".to_string());
        }

        // Ruby: require / require_relative.
        if starts_with_any(text, &["require '", "require \"", "require_relative "]) {
            return Some(".rb".to_string());
        }

        // Perl: strict/warnings pragmas or `my` declarations with a sigil.
        if matches!(text, "use strict;" | "use warnings;")
            || starts_with_any(text, &["my $", "my @", "my %"])
        {
            return Some(".pl".to_string());
        }
    }

    None
}
526
/// Extracts the filetype named by a vim modeline in `line`, if there is one.
///
/// Everything after the first `vim:` is treated as the option list. Options may be
/// separated by whitespace or by colons, so all of these are recognised:
/// `vim: ft=perl`, `vim: filetype=perl`, `vim: set ft=perl:`, `vim: ts=4:sw=4:ft=perl:`.
///
/// Returns the value of the first `ft=` / `filetype=` option, borrowed from `line`, or
/// `None` when the line has no `vim:` marker or no filetype option.
fn extract_vim_filetype(line: &str) -> Option<&str> {
    let (_, options) = line.trim().split_once("vim:")?;

    options
        .split(|c: char| c.is_whitespace() || c == ':')
        .find_map(|token| {
            token
                .strip_prefix("ft=")
                .or_else(|| token.strip_prefix("filetype="))
        })
}
543
#[cfg(test)]
mod tests {
    use crate::parser::plugins::create_default_registry;

    /// Resolves `path` by file name in a default registry and returns the chosen plugin's id.
    fn plugin_id_for(path: &str) -> String {
        let registry = create_default_registry();
        let plugin = registry.get_plugin(path).expect("plugin should exist");
        plugin.id().to_string()
    }

    // --- extension-based lookup -------------------------------------------------------

    #[test]
    fn test_registry_matches_compound_svelte_typescript_suffix() {
        assert_eq!(plugin_id_for("src/routes/+page.svelte.ts"), "svelte");
    }

    #[test]
    fn test_registry_matches_compound_svelte_javascript_suffix() {
        assert_eq!(plugin_id_for("src/routes/+layout.svelte.js"), "svelte");
    }

    #[test]
    fn test_registry_matches_svelte_test_suffix() {
        assert_eq!(plugin_id_for("src/lib/multiplier.svelte.test.js"), "svelte");
    }

    #[test]
    fn test_registry_prefers_svelte_plugin_for_component_files() {
        assert_eq!(plugin_id_for("src/lib/Component.svelte"), "svelte");
    }

    #[test]
    fn test_registry_matches_typescript_module_suffix() {
        assert_eq!(plugin_id_for("src/lib/index.mts"), "code");
    }

    #[test]
    fn test_registry_matches_typescript_commonjs_suffix() {
        assert_eq!(plugin_id_for("src/lib/index.cts"), "code");
    }

    // --- content-based detection ------------------------------------------------------

    #[test]
    fn test_detect_php_from_opening_tag() {
        let content = "<?php\nclass Vendor {\n    function get_name() { return $this->name; }\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("vendor.inc2", content)
            .expect("should detect PHP")
            .extract_entities(content, "vendor.inc2");
        assert!(entities.iter().any(|e| e.entity_type == "class"));
    }

    #[test]
    fn test_detect_c_from_include() {
        let content = "#include <stdio.h>\n\nint main() {\n    printf(\"hello\");\n    return 0;\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("main.xyz", content)
            .expect("should detect C")
            .extract_entities(content, "main.xyz");
        assert!(entities.iter().any(|e| e.name == "main"));
    }

    #[test]
    fn test_detect_java_from_package() {
        let content = "package com.example.app;\n\npublic class Main {\n    public static void main(String[] args) {}\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("Main", content)
            .expect("should detect Java")
            .extract_entities(content, "Main");
        assert!(entities.iter().any(|e| e.name == "Main"));
    }

    #[test]
    fn test_detect_go_from_package() {
        let content = "package main\n\nimport \"fmt\"\n\nfunc hello() {\n    fmt.Println(\"hi\")\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("main", content)
            .expect("should detect Go")
            .extract_entities(content, "main");
        assert!(entities.iter().any(|e| e.name == "hello"));
    }

    #[test]
    fn test_detect_rust_from_use_std() {
        let content = "use std::collections::HashMap;\n\nfn process() {\n    let m = HashMap::new();\n}\n";
        let registry = create_default_registry();
        let entities = registry
            .get_plugin_with_content("lib", content)
            .expect("should detect Rust")
            .extract_entities(content, "lib");
        assert!(entities.iter().any(|e| e.name == "process"));
    }

    #[test]
    fn test_extension_takes_priority_over_heuristics() {
        // The text opens like PHP, yet the `.py` extension must decide.
        let content = "<?php\nclass Foo {}\n";
        let registry = create_default_registry();
        let chosen = registry
            .get_plugin_with_content("script.py", content)
            .expect("should use Python parser");
        // Python is served by the "code" plugin rather than a PHP-specific one.
        assert_eq!(chosen.id(), "code");
    }

    // --- custom extension mappings ----------------------------------------------------

    #[test]
    fn test_custom_extension_mapping_extracts_entities() {
        let content = "def hello():\n    print(\"hello world\")\n\nclass Calculator:\n    def multiply(self, a, b):\n        return a * b\n";

        let mut registry = create_default_registry();
        registry.add_extension_mapping(".mypy", "python");
        let entities = registry.extract_entities("utils.mypy", content);

        assert!(!entities.is_empty(), "Should extract entities via custom mapping");
        for (name, message) in [
            ("hello", "Should find hello function"),
            ("Calculator", "Should find Calculator class"),
            ("multiply", "Should find multiply method"),
        ] {
            assert!(entities.iter().any(|e| e.name == name), "{}", message);
        }

        // Paths and ids must carry the original (custom) extension.
        for entity in &entities {
            assert_eq!(entity.file_path, "utils.mypy", "Entity file_path should use original extension");
            assert!(entity.id.starts_with("utils.mypy::"), "Entity ID should use original file path");
        }
    }
}
693}