1use std::collections::HashMap;
2use std::path::Path;
3#[cfg(feature = "parallel")]
4use rayon::prelude::*;
5
6use crate::model::entity::{build_entity_id, SemanticEntity};
7
// Yields rayon's parallel iterator over a slice when the `parallel` feature is
// enabled, and a plain sequential `iter()` otherwise. Call sites can therefore
// chain the same adaptors (`flat_map`, `collect`, ...) under either build.
#[cfg(feature = "parallel")]
macro_rules! maybe_par_iter {
    ($slice:expr) => {
        $slice.par_iter()
    };
}

#[cfg(not(feature = "parallel"))]
macro_rules! maybe_par_iter {
    ($slice:expr) => {
        $slice.iter()
    };
}
16use super::plugin::SemanticParserPlugin;
17
18pub struct ParserRegistry {
19 plugins: Vec<Box<dyn SemanticParserPlugin>>,
20 extension_map: HashMap<String, usize>, custom_ext_canonical: HashMap<String, String>, }
23
24impl ParserRegistry {
25 pub fn new() -> Self {
26 Self {
27 plugins: Vec::new(),
28 extension_map: HashMap::new(),
29 custom_ext_canonical: HashMap::new(),
30 }
31 }
32
33 pub fn register(&mut self, plugin: Box<dyn SemanticParserPlugin>) {
34 let idx = self.plugins.len();
35 for ext in plugin.extensions() {
36 self.extension_map.insert(ext.to_string(), idx);
37 }
38 self.plugins.push(plugin);
39 }
40
41 pub fn get_plugin(&self, file_path: &str) -> Option<&dyn SemanticParserPlugin> {
42 for ext in get_extensions(file_path) {
43 if let Some(&idx) = self.extension_map.get(&ext) {
44 return Some(self.plugins[idx].as_ref());
45 }
46 }
47 self.get_plugin_by_id("fallback")
49 }
50
51 pub fn get_plugin_with_content(&self, file_path: &str, content: &str) -> Option<&dyn SemanticParserPlugin> {
54 for ext in get_extensions(file_path) {
56 if let Some(&idx) = self.extension_map.get(&ext) {
57 return Some(self.plugins[idx].as_ref());
58 }
59 }
60 if let Some(plugin) = self.detect_from_shebang(content) {
62 return Some(plugin);
63 }
64 self.get_plugin_by_id("fallback")
66 }
67
68 fn detect_from_shebang(&self, content: &str) -> Option<&dyn SemanticParserPlugin> {
69 if let Some(ext) = detect_ext_from_content(content) {
70 if let Some(&idx) = self.extension_map.get(ext.as_str()) {
71 return Some(self.plugins[idx].as_ref());
72 }
73 }
74 None
75 }
76
77 pub fn get_plugin_by_id(&self, id: &str) -> Option<&dyn SemanticParserPlugin> {
78 self.plugins
79 .iter()
80 .find(|p| p.id() == id)
81 .map(|p| p.as_ref())
82 }
83
84 pub fn add_extension_mapping(&mut self, ext: &str, language: &str) -> bool {
87 let ext = if ext.starts_with('.') {
88 ext.to_lowercase()
89 } else {
90 format!(".{}", ext.to_lowercase())
91 };
92
93 let target_ext = LANG_MAPPING
95 .iter()
96 .find(|(kw, _)| *kw == language.to_lowercase())
97 .map(|(_, e)| *e);
98
99 if let Some(target) = target_ext {
100 if let Some(&idx) = self.extension_map.get(target) {
101 self.custom_ext_canonical.insert(ext.clone(), target.to_string());
102 self.extension_map.insert(ext, idx);
103 return true;
104 }
105 }
106
107 let direct_ext = format!(".{}", language.to_lowercase());
109 if let Some(&idx) = self.extension_map.get(&direct_ext) {
110 self.custom_ext_canonical.insert(ext.clone(), direct_ext);
111 self.extension_map.insert(ext, idx);
112 return true;
113 }
114
115 false
116 }
117
118 pub fn load_semrc(&mut self, root: &Path) {
125 let semrc_path = root.join(".semrc");
126 if !semrc_path.exists() {
127 return;
128 }
129 let content = match std::fs::read_to_string(&semrc_path) {
130 Ok(c) => c,
131 Err(_) => return,
132 };
133 for line in content.lines() {
134 let line = line.trim();
135 if line.is_empty() || line.starts_with('#') {
136 continue;
137 }
138 if let Some((ext, lang)) = line.split_once('=') {
139 self.add_extension_mapping(ext.trim(), lang.trim());
140 }
141 }
142 }
143
144 pub fn load_gitattributes(&mut self, root: &Path) {
148 let ga_path = root.join(".gitattributes");
149 if !ga_path.exists() {
150 return;
151 }
152 let content = match std::fs::read_to_string(&ga_path) {
153 Ok(c) => c,
154 Err(_) => return,
155 };
156 for line in content.lines() {
157 let line = line.trim();
158 if line.is_empty() || line.starts_with('#') {
159 continue;
160 }
161 let mut parts = line.split_whitespace();
162 let pattern = match parts.next() {
163 Some(p) => p,
164 None => continue,
165 };
166 let ext = match pattern.strip_prefix("*.") {
168 Some(e) => e,
169 None => continue,
170 };
171 let ext_key = format!(".{}", ext.to_lowercase());
173 if self.custom_ext_canonical.contains_key(&ext_key) {
174 continue;
175 }
176 for attr in parts {
178 if let Some(lang) = attr.strip_prefix("diff=") {
179 self.add_extension_mapping(ext, lang);
180 break;
181 }
182 if let Some(lang) = attr.strip_prefix("linguist-language=") {
183 self.add_extension_mapping(ext, lang);
184 break;
185 }
186 }
187 }
188 }
189
190 pub fn resolve_file_path(&self, file_path: &str) -> Option<String> {
194 let path = Path::new(file_path);
195 let ext = path
196 .extension()
197 .and_then(|e| e.to_str())
198 .map(|e| format!(".{}", e.to_lowercase()))?;
199
200 let canonical = self.custom_ext_canonical.get(&ext)?;
201 let stem = path.file_stem().and_then(|s| s.to_str())?;
202
203 if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
204 Some(format!("{}/{}{}", parent.display(), stem, canonical))
205 } else {
206 Some(format!("{}{}", stem, canonical))
207 }
208 }
209
210 pub fn extract_entities(&self, file_path: &str, content: &str) -> Vec<SemanticEntity> {
214 let resolved = self.resolve_file_path(file_path);
215 let detection_path = resolved.as_deref().unwrap_or(file_path);
216
217 let plugin = match self.get_plugin_with_content(detection_path, content) {
218 Some(p) => p,
219 None => return Vec::new(),
220 };
221
222 let mut entities = plugin.extract_entities(content, detection_path);
223 if let Some(ref rp) = resolved {
224 fix_entity_paths(&mut entities, file_path, rp);
225 }
226 entities
227 }
228
229 pub fn extract_entities_with_tree(
231 &self,
232 file_path: &str,
233 content: &str,
234 ) -> Option<(Vec<SemanticEntity>, Option<tree_sitter::Tree>)> {
235 let resolved = self.resolve_file_path(file_path);
236 let detection_path = resolved.as_deref().unwrap_or(file_path);
237
238 let plugin = self.get_plugin_with_content(detection_path, content)?;
239 let (mut entities, tree) = plugin.extract_entities_with_tree(content, detection_path);
240 if let Some(ref rp) = resolved {
241 fix_entity_paths(&mut entities, file_path, rp);
242 }
243 Some((entities, tree))
244 }
245
246 pub fn extract_all_entities(
248 &self,
249 root: &Path,
250 file_paths: &[String],
251 ) -> Vec<SemanticEntity> {
252 let mut entities: Vec<SemanticEntity> = maybe_par_iter!(file_paths)
253 .flat_map(|fp| {
254 let full = root.join(fp);
255 let content = match std::fs::read_to_string(&full) {
256 Ok(c) => c,
257 Err(_) => return Vec::new(),
258 };
259 self.extract_entities(fp, &content)
260 })
261 .collect();
262 resolve_go_method_parent_ids(&mut entities);
263 entities
264 }
265}
266
267pub fn resolve_go_method_parent_ids(entities: &mut [SemanticEntity]) {
268 let mut types_by_package: HashMap<(String, String, String), String> = HashMap::new();
269
270 for entity in entities.iter() {
271 if !is_go_file(&entity.file_path) || !is_go_receiver_type_entity(entity) {
272 continue;
273 }
274
275 let package_name = go_package_name(entity).unwrap_or("");
276
277 types_by_package
278 .entry((
279 go_package_dir(&entity.file_path).to_string(),
280 package_name.to_string(),
281 entity.name.clone(),
282 ))
283 .or_insert_with(|| entity.id.clone());
284 }
285
286 for entity in entities.iter_mut() {
287 if !is_go_file(&entity.file_path) || entity.entity_type != "method" {
288 continue;
289 }
290
291 let package_name = go_package_name(entity).unwrap_or("");
292 let Some(receiver_name) = extract_go_receiver_type_name(&entity.content) else {
293 continue;
294 };
295
296 let key = (
297 go_package_dir(&entity.file_path).to_string(),
298 package_name.to_string(),
299 receiver_name,
300 );
301
302 let Some(parent_id) = types_by_package.get(&key) else {
303 continue;
304 };
305
306 if entity.parent_id.as_deref() == Some(parent_id.as_str()) {
307 continue;
308 }
309
310 entity.parent_id = Some(parent_id.clone());
311 entity.id = build_entity_id(
312 &entity.file_path,
313 &entity.entity_type,
314 &entity.name,
315 Some(parent_id),
316 );
317 }
318}
319
/// True when `file_path` ends in the `.go` suffix (case-sensitive).
fn is_go_file(file_path: &str) -> bool {
    file_path.strip_suffix(".go").is_some()
}
323
324fn is_go_receiver_type_entity(entity: &SemanticEntity) -> bool {
325 matches!(
326 entity.entity_type.as_str(),
327 "type" | "struct" | "class" | "interface"
328 )
329}
330
331fn go_package_name(entity: &SemanticEntity) -> Option<&str> {
332 entity
333 .metadata
334 .as_ref()
335 .and_then(|metadata| metadata.get("go.package"))
336 .map(String::as_str)
337}
338
/// Everything before the last `/` in `file_path`, or `""` for a bare file name.
fn go_package_dir(file_path: &str) -> &str {
    match file_path.rfind('/') {
        Some(slash) => &file_path[..slash],
        None => "",
    }
}
342
/// Extracts the receiver's base type name from a Go method declaration.
///
/// For example, `func (s *Service) Run()` yields `Service`, and
/// `func (p *Pair[K, V]) Get()` yields `Pair`. A package qualifier before a
/// `.` is dropped.
///
/// Returns `None` when `content` does not start with `func (...)` or the
/// receiver is empty.
fn extract_go_receiver_type_name(content: &str) -> Option<String> {
    let after_func = content.trim_start().strip_prefix("func")?.trim_start();
    let receiver = after_func.strip_prefix('(')?;
    let receiver = &receiver[..receiver.find(')')?];

    // Drop type parameters first: `Pair[K, V]` contains whitespace that would
    // otherwise be mistaken for the name/type separator.
    let receiver = receiver
        .split_once('[')
        .map_or(receiver, |(base, _)| base)
        .trim();

    // The type is whatever follows the last space or pointer star, so
    // `s *Service`, `s*Service` and `Service` all yield `Service`.
    let receiver_type = receiver
        .rsplit(|c: char| c.is_whitespace() || c == '*')
        .next()
        .unwrap_or("")
        .trim();
    let receiver_type = receiver_type
        .rsplit_once('.')
        .map_or(receiver_type, |(_, name)| name)
        .trim();

    (!receiver_type.is_empty()).then(|| receiver_type.to_string())
}
366
367fn fix_entity_paths(entities: &mut [SemanticEntity], original: &str, resolved: &str) {
369 for entity in entities {
370 entity.file_path = original.to_string();
371 entity.id = entity.id.replace(resolved, original);
372 if let Some(ref mut pid) = entity.parent_id {
373 *pid = pid.replace(resolved, original);
374 }
375 }
376}
377
/// Lists every dotted suffix of `file_path`'s file name, lowercased, longest first.
///
/// For example `src/+page.svelte.ts` gives `[".svelte.ts", ".ts"]` and
/// `a/.bashrc` gives `[".bashrc"]`. Returns an empty list when there is no
/// (UTF-8) file name or it contains no dot.
fn get_extensions(file_path: &str) -> Vec<String> {
    match Path::new(file_path).file_name().and_then(|name| name.to_str()) {
        Some(name) => {
            let lowered = name.to_lowercase();
            lowered
                .match_indices('.')
                .map(|(dot, _)| lowered[dot..].to_string())
                .collect()
        }
        None => Vec::new(),
    }
}
397
/// Lookup table from lowercase language keyword to canonical extension
/// (with leading dot).
///
/// Lookups scan this table front to back and stop at the first match, so entry
/// order is significant. Several keywords can share one extension as aliases
/// (`go`/`golang`, `cpp`/`c++`, `cs`/`csharp`/`c#`).
///
/// The `/sh` entry is shaped for shebang lines such as `#!/bin/sh` rather
/// than for user-supplied language names. An extension here only takes effect
/// if some registered plugin claims it.
const LANG_MAPPING: &[(&str, &str)] = &[
    ("perl", ".pl"),
    ("python", ".py"),
    ("ruby", ".rb"),
    ("bash", ".sh"),
    ("shell", ".sh"),
    ("/sh", ".sh"),
    ("node", ".js"),
    ("javascript", ".js"),
    ("typescript", ".ts"),
    ("tsx", ".tsx"),
    ("swift", ".swift"),
    ("elixir", ".ex"),
    ("rust", ".rs"),
    ("go", ".go"),
    ("golang", ".go"),
    ("kotlin", ".kt"),
    ("dart", ".dart"),
    ("php", ".php"),
    ("java", ".java"),
    ("c", ".c"),
    ("cpp", ".cpp"),
    ("c++", ".cpp"),
    ("cs", ".cs"),
    ("csharp", ".cs"),
    ("c#", ".cs"),
    ("fortran", ".f90"),
    ("terraform", ".tf"),
    ("hcl", ".hcl"),
    ("ocaml", ".ml"),
    ("scala", ".scala"),
    ("zig", ".zig"),
    ("xml", ".xml"),
    ("json", ".json"),
    ("yaml", ".yaml"),
    ("yml", ".yaml"),
    ("toml", ".toml"),
    ("markdown", ".md"),
    ("csv", ".csv"),
    ("eruby", ".erb"),
    ("erb", ".erb"),
    ("vue", ".vue"),
    ("svelte", ".svelte"),
];
442
443pub fn detect_ext_from_content(content: &str) -> Option<String> {
445 if let Some(first_line) = content.lines().next() {
447 if first_line.starts_with("#!") {
448 let shebang = first_line.to_lowercase();
449 for (keyword, ext) in LANG_MAPPING {
450 if shebang.contains(keyword) {
451 return Some(ext.to_string());
452 }
453 }
454 }
455 }
456
457 let lines: Vec<&str> = content.lines().collect();
460 let check_lines = lines.iter().take(5).chain(lines.iter().rev().take(5));
461 for line in check_lines {
462 if let Some(ft) = extract_vim_filetype(line) {
463 let ft_lower = ft.to_lowercase();
464 for (keyword, ext) in LANG_MAPPING {
465 if ft_lower == *keyword {
466 return Some(ext.to_string());
467 }
468 }
469 }
470 }
471
472 if let Some(ext) = detect_from_content_heuristics(content) {
474 return Some(ext);
475 }
476
477 None
478}
479
/// Guesses a canonical extension from characteristic source lines.
///
/// First-line markers are checked first (PHP open tag, XML declaration,
/// doctype). Then each of the first 20 lines is scanned for language-specific
/// constructs, and the first hit wins.
///
/// Several checks overlap (three of them key on `package `), so their order
/// determines the result. These are heuristics and can misfire. Returns
/// `None` when nothing matches.
fn detect_from_content_heuristics(content: &str) -> Option<String> {
    let first_line = content.lines().next().unwrap_or("").trim();

    if first_line.starts_with("<?php") || first_line.starts_with("<?PHP") {
        return Some(".php".to_string());
    }

    if first_line.starts_with("<?xml") {
        return Some(".xml".to_string());
    }
    // HTML doctypes are routed to the XML parser as well.
    if first_line.starts_with("<!DOCTYPE") || first_line.starts_with("<!doctype") {
        return Some(".xml".to_string());
    }

    for line in content.lines().take(20) {
        let trimmed = line.trim();

        if trimmed.starts_with("<?php") || trimmed.starts_with("<?PHP") || trimmed == "<?=" {
            return Some(".php".to_string());
        }

        // A C-family include: upgrade to C++ if any of the first 30 lines shows
        // a C++-only marker, otherwise treat it as C.
        if trimmed.starts_with("#include ") || trimmed.starts_with("#include\t") {
            if content.lines().take(30).any(|l| {
                let t = l.trim();
                t.starts_with("using namespace")
                    || t.starts_with("class ")
                    || t.starts_with("#include <iostream")
                    || t.starts_with("#include <vector")
                    || t.starts_with("#include <string>")
                    || t.starts_with("#include <memory")
            }) {
                return Some(".cpp".to_string());
            }
            return Some(".c".to_string());
        }

        // Java: dotted package name terminated by `;`.
        if trimmed.starts_with("package ") && trimmed.contains('.') && trimmed.ends_with(';') {
            return Some(".java".to_string());
        }

        // Go: single-identifier package clause, no dot and no semicolon.
        if trimmed.starts_with("package ") && !trimmed.contains('.') && !trimmed.contains(';') {
            return Some(".go".to_string());
        }

        if (trimmed.starts_with("use std::") || trimmed.starts_with("use crate::"))
            && trimmed.ends_with(';')
        {
            return Some(".rs".to_string());
        }

        if trimmed.starts_with("defmodule ") {
            return Some(".ex".to_string());
        }

        // Kotlin: dotted package name without a trailing semicolon (a Scala
        // package clause would also land here).
        if trimmed.starts_with("package ") && trimmed.contains('.') && !trimmed.ends_with(';') {
            return Some(".kt".to_string());
        }

        if trimmed.starts_with("using System") && trimmed.ends_with(';') {
            return Some(".cs".to_string());
        }
        if trimmed.starts_with("namespace ") && trimmed.ends_with('{') {
            return Some(".cs".to_string());
        }

        if trimmed == "import Foundation"
            || trimmed == "import UIKit"
            || trimmed == "import SwiftUI"
            || trimmed == "import Combine"
        {
            return Some(".swift".to_string());
        }

        if trimmed.starts_with("import 'dart:") || trimmed.starts_with("import \"dart:") {
            return Some(".dart".to_string());
        }

        if trimmed.starts_with("object ") || trimmed.starts_with("trait ") {
            return Some(".scala".to_string());
        }

        if trimmed.contains("@import(") {
            return Some(".zig".to_string());
        }

        if trimmed.starts_with("resource \"")
            || trimmed.starts_with("variable \"")
            || trimmed.starts_with("terraform {")
            || trimmed.starts_with("provider \"")
        {
            return Some(".tf".to_string());
        }

        // Fortran (case-insensitive): `program` or `implicit none` is decisive.
        // `module` and `subroutine` also need an `implicit none` within the
        // first 20 lines.
        let lower = trimmed.to_lowercase();
        if lower.starts_with("program ") || lower.starts_with("module ")
            || lower.starts_with("subroutine ") || lower == "implicit none"
        {
            if lower.starts_with("program ") || lower == "implicit none" {
                return Some(".f90".to_string());
            }
            if content.lines().take(20).any(|l| l.trim().to_lowercase() == "implicit none") {
                return Some(".f90".to_string());
            }
        }

        // Python: a `def`/`class` header ending in `:` on a line with no
        // leading whitespace. The raw line must start with the first character
        // of its trimmed form.
        if (trimmed.starts_with("def ") || trimmed.starts_with("class "))
            && trimmed.ends_with(':')
            && line.starts_with(trimmed.chars().next().unwrap_or(' '))
        {
            return Some(".py".to_string());
        }

        if trimmed.starts_with("require '") || trimmed.starts_with("require \"")
            || trimmed.starts_with("require_relative ")
        {
            return Some(".rb".to_string());
        }

        if trimmed == "use strict;"
            || trimmed == "use warnings;"
            || trimmed.starts_with("my $")
            || trimmed.starts_with("my @")
            || trimmed.starts_with("my %")
        {
            return Some(".pl".to_string());
        }
    }

    None
}
637
/// Returns the filetype named in a vim modeline on `line`, if any.
///
/// Looks at the text after the first `vim:` and returns the value of the
/// first whitespace-separated `ft=` or `filetype=` setting, with any trailing
/// `:` removed. For example `# vim: set ft=python :` gives `python`.
fn extract_vim_filetype(line: &str) -> Option<&str> {
    let (_, settings) = line.trim().split_once("vim:")?;
    settings.split_whitespace().find_map(|setting| {
        setting
            .strip_prefix("ft=")
            .or_else(|| setting.strip_prefix("filetype="))
            .map(|value| value.trim_end_matches(':'))
    })
}
654
// Tests for plugin selection (by extension and by content), Go method
// re-parenting across files, and custom extension mappings. They run against
// the default registry from `create_default_registry`.
#[cfg(test)]
mod tests {
    use crate::parser::plugins::create_default_registry;
    use tempfile::TempDir;

    // Writes `content` to `dir/name`, creating intermediate directories.
    fn write_file(dir: &TempDir, name: &str, content: &str) {
        let path = dir.path().join(name);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        std::fs::write(path, content).unwrap();
    }

    // The compound `.svelte.ts` suffix must win over the plain `.ts` suffix.
    #[test]
    fn test_registry_matches_compound_svelte_typescript_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/routes/+page.svelte.ts")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    #[test]
    fn test_registry_matches_compound_svelte_javascript_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/routes/+layout.svelte.js")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    #[test]
    fn test_registry_matches_svelte_test_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/multiplier.svelte.test.js")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    #[test]
    fn test_registry_prefers_svelte_plugin_for_component_files() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/Component.svelte")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "svelte");
    }

    #[test]
    fn test_registry_matches_typescript_module_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/index.mts")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "code");
    }

    #[test]
    fn test_registry_matches_typescript_commonjs_suffix() {
        let registry = create_default_registry();
        let plugin = registry
            .get_plugin("src/lib/index.cts")
            .expect("plugin should exist");

        assert_eq!(plugin.id(), "code");
    }

    // The following tests use unmapped extensions (or none), so the language
    // must come from content detection.
    #[test]
    fn test_detect_php_from_opening_tag() {
        let registry = create_default_registry();
        let content = "<?php\nclass Vendor {\n    function get_name() { return $this->name; }\n}\n";
        let plugin = registry
            .get_plugin_with_content("vendor.inc2", content)
            .expect("should detect PHP");
        let entities = plugin.extract_entities(content, "vendor.inc2");
        assert!(entities.iter().any(|e| e.entity_type == "class"));
    }

    #[test]
    fn test_detect_c_from_include() {
        let registry = create_default_registry();
        let content = "#include <stdio.h>\n\nint main() {\n    printf(\"hello\");\n    return 0;\n}\n";
        let plugin = registry
            .get_plugin_with_content("main.xyz", content)
            .expect("should detect C");
        let entities = plugin.extract_entities(content, "main.xyz");
        assert!(entities.iter().any(|e| e.name == "main"));
    }

    #[test]
    fn test_detect_java_from_package() {
        let registry = create_default_registry();
        let content = "package com.example.app;\n\npublic class Main {\n    public static void main(String[] args) {}\n}\n";
        let plugin = registry
            .get_plugin_with_content("Main", content)
            .expect("should detect Java");
        let entities = plugin.extract_entities(content, "Main");
        assert!(entities.iter().any(|e| e.name == "Main"));
    }

    #[test]
    fn test_detect_go_from_package() {
        let registry = create_default_registry();
        let content = "package main\n\nimport \"fmt\"\n\nfunc hello() {\n    fmt.Println(\"hi\")\n}\n";
        let plugin = registry
            .get_plugin_with_content("main", content)
            .expect("should detect Go");
        let entities = plugin.extract_entities(content, "main");
        assert!(entities.iter().any(|e| e.name == "hello"));
    }

    #[test]
    fn test_detect_rust_from_use_std() {
        let registry = create_default_registry();
        let content = "use std::collections::HashMap;\n\nfn process() {\n    let m = HashMap::new();\n}\n";
        let plugin = registry
            .get_plugin_with_content("lib", content)
            .expect("should detect Rust");
        let entities = plugin.extract_entities(content, "lib");
        assert!(entities.iter().any(|e| e.name == "process"));
    }

    // A method declared in a different file from its receiver type must still
    // be parented to that type after `extract_all_entities`.
    #[cfg(feature = "lang-go")]
    #[test]
    fn test_go_method_parent_resolves_across_files() {
        let registry = create_default_registry();
        let dir = TempDir::new().unwrap();
        write_file(&dir, "models.go", "package demo\n\ntype Service struct{}\n");
        write_file(
            &dir,
            "methods.go",
            "package demo\n\nfunc (s *Service) Run() {}\n",
        );

        let entities = registry.extract_all_entities(
            dir.path(),
            &["models.go".to_string(), "methods.go".to_string()],
        );
        let service = entities
            .iter()
            .find(|e| e.name == "Service" && e.file_path == "models.go")
            .expect("Service type should be extracted");
        let run = entities
            .iter()
            .find(|e| e.name == "Run" && e.file_path == "methods.go")
            .expect("Run method should be extracted");

        assert_eq!(run.parent_id.as_deref(), Some(service.id.as_str()));
        assert_eq!(run.id, format!("{}::Run", service.id));
    }

    // Same-named types in different directories (same package name) must not
    // be confused: each method binds to the type in its own directory.
    #[cfg(feature = "lang-go")]
    #[test]
    fn test_go_method_parent_resolution_is_package_directory_scoped() {
        let registry = create_default_registry();
        let dir = TempDir::new().unwrap();
        write_file(&dir, "alpha/models.go", "package demo\n\ntype Service struct{}\n");
        write_file(
            &dir,
            "alpha/methods.go",
            "package demo\n\nfunc (s *Service) Run() {}\n",
        );
        write_file(&dir, "beta/models.go", "package demo\n\ntype Service struct{}\n");
        write_file(
            &dir,
            "beta/methods.go",
            "package demo\n\nfunc (s *Service) Run() {}\n",
        );

        let entities = registry.extract_all_entities(
            dir.path(),
            &[
                "alpha/models.go".to_string(),
                "alpha/methods.go".to_string(),
                "beta/models.go".to_string(),
                "beta/methods.go".to_string(),
            ],
        );

        let alpha_service = entities
            .iter()
            .find(|e| e.name == "Service" && e.file_path == "alpha/models.go")
            .expect("alpha Service type should be extracted");
        let beta_service = entities
            .iter()
            .find(|e| e.name == "Service" && e.file_path == "beta/models.go")
            .expect("beta Service type should be extracted");
        let alpha_run = entities
            .iter()
            .find(|e| e.name == "Run" && e.file_path == "alpha/methods.go")
            .expect("alpha Run method should be extracted");
        let beta_run = entities
            .iter()
            .find(|e| e.name == "Run" && e.file_path == "beta/methods.go")
            .expect("beta Run method should be extracted");

        assert_eq!(alpha_run.parent_id.as_deref(), Some(alpha_service.id.as_str()));
        assert_eq!(beta_run.parent_id.as_deref(), Some(beta_service.id.as_str()));
    }

    // A known extension (.py) must win even when the content looks like PHP.
    #[test]
    fn test_extension_takes_priority_over_heuristics() {
        let registry = create_default_registry();
        let content = "<?php\nclass Foo {}\n";
        let plugin = registry
            .get_plugin_with_content("script.py", content)
            .expect("should use Python parser");
        assert_eq!(plugin.id(), "code"); }

    // A custom `.mypy` -> python mapping parses as Python, while entities keep
    // reporting the original `.mypy` path.
    #[test]
    fn test_custom_extension_mapping_extracts_entities() {
        let mut registry = create_default_registry();
        registry.add_extension_mapping(".mypy", "python");

        let content = "def hello():\n    print(\"hello world\")\n\nclass Calculator:\n    def multiply(self, a, b):\n        return a * b\n";
        let entities = registry.extract_entities("utils.mypy", content);

        assert!(!entities.is_empty(), "Should extract entities via custom mapping");
        assert!(entities.iter().any(|e| e.name == "hello"), "Should find hello function");
        assert!(entities.iter().any(|e| e.name == "Calculator"), "Should find Calculator class");
        assert!(entities.iter().any(|e| e.name == "multiply"), "Should find multiply method");

        for entity in &entities {
            assert_eq!(entity.file_path, "utils.mypy", "Entity file_path should use original extension");
            assert!(entity.id.starts_with("utils.mypy::"), "Entity ID should use original file path");
        }
    }
}
891}