1mod entity_extractor;
2mod languages;
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::plugin::SemanticParserPlugin;
9use languages::{get_all_code_extensions, get_language_config};
10use entity_extractor::extract_entities;
11
/// Parser plugin that extracts semantic entities from source-code files.
///
/// The type carries no state; all behaviour lives in its
/// `SemanticParserPlugin` implementation. The derives let callers
/// debug-print, copy and default-construct the plugin.
#[derive(Debug, Clone, Copy, Default)]
pub struct CodeParserPlugin;
13
14thread_local! {
17 static PARSER_CACHE: RefCell<HashMap<&'static str, tree_sitter::Parser>> = RefCell::new(HashMap::new());
18}
19
20impl SemanticParserPlugin for CodeParserPlugin {
21 fn id(&self) -> &str {
22 "code"
23 }
24
25 fn extensions(&self) -> &[&str] {
26 get_all_code_extensions()
27 }
28
29 fn extract_entities(&self, content: &str, file_path: &str) -> Vec<SemanticEntity> {
30 let ext = std::path::Path::new(file_path)
31 .extension()
32 .and_then(|e| e.to_str())
33 .map(|e| format!(".{}", e.to_lowercase()))
34 .unwrap_or_default();
35
36 let config = match get_language_config(&ext) {
37 Some(c) => c,
38 None => return Vec::new(),
39 };
40
41 let language = match (config.get_language)() {
42 Some(lang) => lang,
43 None => return Vec::new(),
44 };
45
46 PARSER_CACHE.with(|cache| {
47 let mut cache = cache.borrow_mut();
48 let parser = cache.entry(config.id).or_insert_with(|| {
49 let mut p = tree_sitter::Parser::new();
50 let _ = p.set_language(&language);
51 p
52 });
53
54 let tree = match parser.parse(content.as_bytes(), None) {
55 Some(t) => t,
56 None => return Vec::new(),
57 };
58
59 extract_entities(&tree, file_path, config, content)
60 })
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    /// Java: class, interface and enum declarations are reported by name.
    #[test]
    fn test_java_entity_extraction() {
        let source = r#"
package com.example;

import java.util.List;

public class UserService {
 private String name;

 public UserService(String name) {
 this.name = name;
 }

 public List<User> getUsers() {
 return db.findAll();
 }

 public void createUser(User user) {
 db.save(user);
 }
}

interface Repository<T> {
 T findById(String id);
 List<T> findAll();
}

enum Status {
 ACTIVE,
 INACTIVE,
 DELETED
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "UserService.java");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        let named_types: Vec<(&str, &str)> = entities
            .iter()
            .map(|entity| (entity.name.as_str(), entity.entity_type.as_str()))
            .collect();
        eprintln!("Java entities: {:?}", named_types);

        for (expected, why) in [
            ("UserService", "Should find class UserService"),
            ("Repository", "Should find interface Repository"),
            ("Status", "Should find enum Status"),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }
    }

    /// Java: methods are reported and linked to a parent entity.
    #[test]
    fn test_java_nested_methods() {
        let source = r#"
public class Calculator {
 public int add(int a, int b) {
 return a + b;
 }

 public int subtract(int a, int b) {
 return a - b;
 }
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "Calculator.java");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        eprintln!(
            "Java nested: {:?}",
            entities
                .iter()
                .map(|entity| (&entity.name, &entity.entity_type, &entity.parent_id))
                .collect::<Vec<_>>()
        );

        assert!(names.contains(&"Calculator"), "Should find Calculator class");
        for (expected, why) in [
            ("add", "Should find add method"),
            ("subtract", "Should find subtract method"),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }

        let add = entities
            .iter()
            .find(|entity| entity.name == "add")
            .expect("`add` was asserted to be present above");
        assert!(add.parent_id.is_some(), "add should have parent_id");
    }

    /// C: functions, a named struct and a named enum are reported.
    #[test]
    fn test_c_entity_extraction() {
        let source = r#"
#include <stdio.h>

struct Point {
 int x;
 int y;
};

enum Color {
 RED,
 GREEN,
 BLUE
};

typedef struct {
 char name[50];
 int age;
} Person;

void greet(const char* name) {
 printf("Hello, %s!\n", name);
}

int add(int a, int b) {
 return a + b;
}

int main() {
 greet("world");
 return 0;
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "main.c");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        let named_types: Vec<(&str, &str)> = entities
            .iter()
            .map(|entity| (entity.name.as_str(), entity.entity_type.as_str()))
            .collect();
        eprintln!("C entities: {:?}", named_types);

        for (expected, why) in [
            ("greet", "Should find greet function"),
            ("add", "Should find add function"),
            ("main", "Should find main function"),
            ("Point", "Should find Point struct"),
            ("Color", "Should find Color enum"),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }
    }

    /// C++: a namespace, a class inside it and a free function are reported.
    #[test]
    fn test_cpp_entity_extraction() {
        let source = "namespace math {\nclass Vector3 {\npublic:\n float length() const { return 0; }\n};\n}\nvoid greet() {}\n";
        let entities = CodeParserPlugin.extract_entities(source, "main.cpp");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        for expected in ["math", "Vector3", "greet"] {
            assert!(names.contains(&expected), "got: {:?}", names);
        }
    }

    /// Ruby: a module, a class inside it and a top-level method are reported.
    #[test]
    fn test_ruby_entity_extraction() {
        let source = "module Auth\n class User\n def greet\n \"hi\"\n end\n end\nend\ndef helper(x)\n x * 2\nend\n";
        let entities = CodeParserPlugin.extract_entities(source, "auth.rb");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        for expected in ["Auth", "User", "helper"] {
            assert!(names.contains(&expected), "got: {:?}", names);
        }
    }

    /// C#: a namespace plus the class and enum declared in it are reported.
    #[test]
    fn test_csharp_entity_extraction() {
        let source = "namespace MyApp {\npublic class User {\n public string GetName() { return \"\"; }\n}\npublic enum Role { Admin, User }\n}\n";
        let entities = CodeParserPlugin.extract_entities(source, "Models.cs");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        for expected in ["MyApp", "User", "Role"] {
            assert!(names.contains(&expected), "got: {:?}", names);
        }
    }

    /// Swift: class, struct, enum, protocol and free function are reported.
    #[test]
    fn test_swift_entity_extraction() {
        let source = r#"
import Foundation

class UserService {
 var name: String

 init(name: String) {
 self.name = name
 }

 func getUsers() -> [User] {
 return db.findAll()
 }
}

struct Point {
 var x: Double
 var y: Double
}

enum Status {
 case active
 case inactive
 case deleted
}

protocol Repository {
 associatedtype Item
 func findById(id: String) -> Item?
 func findAll() -> [Item]
}

func helper(x: Int) -> Int {
 return x * 2
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "UserService.swift");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        eprintln!(
            "Swift entities: {:?}",
            entities
                .iter()
                .map(|entity| (&entity.name, &entity.entity_type))
                .collect::<Vec<_>>()
        );

        for (expected, why) in [
            ("UserService", "Should find class UserService"),
            ("Point", "Should find struct Point"),
            ("Status", "Should find enum Status"),
            ("Repository", "Should find protocol Repository"),
            ("helper", "Should find function helper"),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }
    }

    /// Elixir: module, def/defp/defmacro and protocol are reported, and
    /// functions are linked to a parent entity.
    #[test]
    fn test_elixir_entity_extraction() {
        let source = r#"
defmodule MyApp.Accounts do
 def create_user(attrs) do
 %User{}
 |> User.changeset(attrs)
 |> Repo.insert()
 end

 defp validate(attrs) do
 # private helper
 :ok
 end

 defmacro is_admin(user) do
 quote do
 unquote(user).role == :admin
 end
 end

 defguard is_positive(x) when is_integer(x) and x > 0
end

defprotocol Printable do
 def to_string(data)
end

defimpl Printable, for: Integer do
 def to_string(i), do: Integer.to_string(i)
end
"#;
        let entities = CodeParserPlugin.extract_entities(source, "accounts.ex");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        let named_types: Vec<(&str, &str)> = entities
            .iter()
            .map(|entity| (entity.name.as_str(), entity.entity_type.as_str()))
            .collect();
        eprintln!("Elixir entities: {:?}", named_types);

        for (expected, why) in [
            ("MyApp.Accounts", "Should find module"),
            ("create_user", "Should find def"),
            ("validate", "Should find defp"),
            ("is_admin", "Should find defmacro"),
            ("Printable", "Should find defprotocol"),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }

        let create_user = entities
            .iter()
            .find(|entity| entity.name == "create_user")
            .expect("`create_user` was asserted to be present above");
        assert!(create_user.parent_id.is_some(), "create_user should be nested under module");
    }

    /// Bash: both function syntaxes are reported and nothing else is.
    #[test]
    fn test_bash_entity_extraction() {
        let source = r#"#!/bin/bash

greet() {
 echo "Hello, $1!"
}

function deploy {
 echo "deploying..."
}

# not a function
echo "main script"
"#;
        let entities = CodeParserPlugin.extract_entities(source, "deploy.sh");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        let named_types: Vec<(&str, &str)> = entities
            .iter()
            .map(|entity| (entity.name.as_str(), entity.entity_type.as_str()))
            .collect();
        eprintln!("Bash entities: {:?}", named_types);

        for (expected, why) in [
            ("greet", "Should find greet()"),
            ("deploy", "Should find function deploy"),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }
        assert_eq!(entities.len(), 2, "Should only find functions, got: {:?}", names);
    }

    /// TypeScript: an exported function and an exported class are reported.
    #[test]
    fn test_typescript_entity_extraction() {
        let source = r#"
export function hello(): string {
 return "hello";
}

export class Greeter {
 greet(name: string): string {
 return `Hello, ${name}!`;
 }
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "test.ts");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        assert!(names.contains(&"hello"), "Should find hello function");
        assert!(names.contains(&"Greeter"), "Should find Greeter class");
    }

    /// TypeScript: a function declared inside another is reported with a parent.
    #[test]
    fn test_nested_functions_typescript() {
        let source = r#"
function outer() {
 function inner() {
 return 42;
 }
 return inner();
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "nested.ts");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        eprintln!(
            "Nested TS: {:?}",
            entities
                .iter()
                .map(|entity| (&entity.name, &entity.entity_type, &entity.parent_id))
                .collect::<Vec<_>>()
        );

        for (expected, why) in [("outer", "Should find outer"), ("inner", "Should find inner")] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }

        let inner = entities
            .iter()
            .find(|entity| entity.name == "inner")
            .expect("`inner` was asserted to be present above");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    /// Python: a function declared inside another is reported with a parent.
    #[test]
    fn test_nested_functions_python() {
        // Python nesting is defined by indentation, so `inner` and its body
        // must each be indented one level deeper than the enclosing `def`.
        let source = "def outer():\n    def inner():\n        return 42\n    return inner()\n";
        let entities = CodeParserPlugin.extract_entities(source, "nested.py");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();

        for expected in ["outer", "inner"] {
            assert!(names.contains(&expected), "got: {:?}", names);
        }

        let inner = entities
            .iter()
            .find(|entity| entity.name == "inner")
            .expect("`inner` was asserted to be present above");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    /// Rust: a function declared inside another is reported with a parent.
    #[test]
    fn test_nested_functions_rust() {
        let source = "fn outer() {\n fn inner() -> i32 {\n 42\n }\n inner();\n}\n";
        let entities = CodeParserPlugin.extract_entities(source, "nested.rs");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();

        for expected in ["outer", "inner"] {
            assert!(names.contains(&expected), "got: {:?}", names);
        }

        let inner = entities
            .iter()
            .find(|entity| entity.name == "inner")
            .expect("`inner` was asserted to be present above");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    /// Rust: each `impl Trait for Type` block gets its own distinct name.
    #[test]
    fn test_rust_impl_blocks_unique_names() {
        let source = r#"
trait Greeting {
 fn greet(&self) -> String;
}

struct Person;
struct Robot;
struct Cat;

impl Greeting for Person {
 fn greet(&self) -> String { "Hello".to_string() }
}

impl Greeting for Robot {
 fn greet(&self) -> String { "Beep".to_string() }
}

impl Greeting for Cat {
 fn greet(&self) -> String { "Meow".to_string() }
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "impls.rs");
        let impl_names: Vec<&str> = entities
            .iter()
            .filter(|entity| entity.entity_type == "impl")
            .map(|entity| entity.name.as_str())
            .collect();

        assert_eq!(impl_names.len(), 3, "Should find 3 impl blocks, got: {:?}", impl_names);
        for expected in ["Greeting for Person", "Greeting for Robot", "Greeting for Cat"] {
            assert!(impl_names.contains(&expected), "got: {:?}", impl_names);
        }
    }

    /// Go: a function whose body only declares a local variable is reported.
    #[test]
    fn test_nested_functions_go() {
        let source = "package main\n\nfunc outer() {\n var x int = 42\n _ = x\n}\n";
        let entities = CodeParserPlugin.extract_entities(source, "nested.go");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();

        assert!(names.contains(&"outer"), "got: {:?}", names);
    }

    /// Renaming a function keeps its structural hash but changes its content hash.
    #[test]
    fn test_renamed_function_same_structural_hash() {
        let original_source = "def get_card():\n return db.query('cards')\n";
        let renamed_source = "def get_card_1():\n return db.query('cards')\n";

        let entities_a = CodeParserPlugin.extract_entities(original_source, "a.py");
        let entities_b = CodeParserPlugin.extract_entities(renamed_source, "b.py");

        assert_eq!(entities_a.len(), 1, "Should find one entity in a");
        assert_eq!(entities_b.len(), 1, "Should find one entity in b");

        // Indexing is safe: both lengths were just asserted to be 1.
        let (original, renamed) = (&entities_a[0], &entities_b[0]);
        assert_eq!(original.name, "get_card");
        assert_eq!(renamed.name, "get_card_1");

        assert_eq!(
            original.structural_hash, renamed.structural_hash,
            "Renamed function with identical body should have same structural_hash"
        );

        assert_ne!(
            original.content_hash, renamed.content_hash,
            "Content hash should differ since raw content includes the name"
        );
    }

    /// HCL: top-level attributes and blocks (with qualified names) are
    /// reported; attributes nested inside blocks are not.
    #[test]
    fn test_hcl_entity_extraction() {
        let source = r#"
region = "eu-west-1"

variable "image_id" {
 type = string
}

resource "aws_instance" "web" {
 ami = var.image_id

 lifecycle {
 create_before_destroy = true
 }
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "main.tf");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        eprintln!(
            "HCL entities: {:?}",
            entities
                .iter()
                .map(|entity| (&entity.name, &entity.entity_type, &entity.parent_id))
                .collect::<Vec<_>>()
        );

        for (expected, why) in [
            ("region", "Should find top-level attribute"),
            ("variable.image_id", "Should find variable block"),
            ("resource.aws_instance.web", "Should find resource block"),
            (
                "resource.aws_instance.web.lifecycle",
                "Should find nested lifecycle block with qualified name",
            ),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }
        for (unexpected, why) in [
            ("ami", "Should skip nested attributes inside blocks"),
            (
                "create_before_destroy",
                "Should skip nested attributes inside nested blocks",
            ),
        ] {
            assert!(!names.contains(&unexpected), "{}, got: {:?}", why, names);
        }

        let lifecycle = entities
            .iter()
            .find(|entity| entity.name == "resource.aws_instance.web.lifecycle")
            .expect("the lifecycle block was asserted to be present above");
        assert!(lifecycle.parent_id.is_some(), "lifecycle should be nested under resource");
        assert!(
            entities.iter().any(|entity| entity.entity_type.as_str() == "attribute"),
            "Should preserve attribute entity type for top-level attributes"
        );
    }

    /// Kotlin: classes, interfaces, objects, member and top-level functions
    /// are reported.
    #[test]
    fn test_kotlin_entity_extraction() {
        let source = r#"
class UserService {
 val name: String = ""

 fun greet(): String {
 return "Hello, $name"
 }

 companion object {
 fun create(): UserService = UserService()
 }
}

interface Repository {
 fun findById(id: Int): Any?
}

object AppConfig {
 val version = "1.0"
}

fun topLevel(x: Int): Int = x * 2
"#;
        let entities = CodeParserPlugin.extract_entities(source, "App.kt");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        eprintln!(
            "Kotlin entities: {:?}",
            entities
                .iter()
                .map(|entity| (&entity.name, &entity.entity_type))
                .collect::<Vec<_>>()
        );
        for expected in [
            "UserService",
            "greet",
            "Repository",
            "findById",
            "AppConfig",
            "topLevel",
        ] {
            assert!(names.contains(&expected), "got: {:?}", names);
        }
    }

    /// XML: the root element and its container children are reported.
    #[test]
    fn test_xml_entity_extraction() {
        let source = r#"<?xml version="1.0" encoding="UTF-8"?>
<project>
 <groupId>com.example</groupId>
 <artifactId>my-app</artifactId>
 <dependencies>
 <dependency>
 <groupId>junit</groupId>
 <artifactId>junit</artifactId>
 </dependency>
 </dependencies>
 <build>
 <plugins>
 <plugin>
 <groupId>org.apache.maven</groupId>
 </plugin>
 </plugins>
 </build>
</project>
"#;
        let entities = CodeParserPlugin.extract_entities(source, "pom.xml");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        eprintln!(
            "XML entities: {:?}",
            entities
                .iter()
                .map(|entity| (&entity.name, &entity.entity_type))
                .collect::<Vec<_>>()
        );
        for expected in ["project", "dependencies", "build"] {
            assert!(names.contains(&expected), "got: {:?}", names);
        }
    }

    /// Go: a type, a package-level `var` and a function are reported.
    #[test]
    fn test_go_var_declaration() {
        let source = r#"package featuremgmt

type FeatureFlag struct {
 Name string
 Description string
 Stage string
}

var standardFeatureFlags = []FeatureFlag{
 {
 Name: "panelTitleSearch",
 Description: "Search for dashboards using panel title",
 Stage: "PublicPreview",
 },
}

func GetFlags() []FeatureFlag {
 return standardFeatureFlags
}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "flags.go");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        let named_types: Vec<(&str, &str)> = entities
            .iter()
            .map(|entity| (entity.name.as_str(), entity.entity_type.as_str()))
            .collect();
        eprintln!("Go entities: {:?}", named_types);

        for (expected, why) in [
            ("FeatureFlag", "Should find type FeatureFlag"),
            ("standardFeatureFlags", "Should find var standardFeatureFlags"),
            ("GetFlags", "Should find func GetFlags"),
        ] {
            assert!(names.contains(&expected), "{}, got: {:?}", why, names);
        }
    }

    /// Go: members of grouped `var (...)` / `const (...)` declarations are reported.
    #[test]
    fn test_go_grouped_var_declaration() {
        let source = r#"package test

var (
 simple = 42
 flags = []string{"a", "b"}
)

const (
 x = 1
 y = 2
)

func main() {}
"#;
        let entities = CodeParserPlugin.extract_entities(source, "test.go");
        let names: Vec<&str> = entities.iter().map(|entity| entity.name.as_str()).collect();
        let named_types: Vec<(&str, &str)> = entities
            .iter()
            .map(|entity| (entity.name.as_str(), entity.entity_type.as_str()))
            .collect();
        eprintln!("Go grouped entities: {:?}", named_types);

        // Either member of the grouped `var` block satisfies this check.
        let found_grouped_var = ["flags", "simple"]
            .iter()
            .any(|candidate| names.contains(candidate));
        assert!(found_grouped_var, "Should find grouped var, got: {:?}", names);
        assert!(names.contains(&"x"), "Should find grouped const x, got: {:?}", names);
        assert!(names.contains(&"main"), "Should find func main, got: {:?}", names);
    }
}