// sem_core/parser/plugins/code/mod.rs

mod entity_extractor;
mod languages;

use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;

use crate::model::entity::SemanticEntity;
use crate::parser::plugin::SemanticParserPlugin;
use entity_extractor::extract_entities;
use languages::{get_all_code_extensions, get_language_config};
11
/// Semantic parser plugin for source code.
///
/// The plugin itself carries no state: the grammar is chosen per file from its
/// extension, and tree-sitter parsers live in a per-thread cache rather than on
/// the plugin value, so it is trivially copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct CodeParserPlugin;
13
14// Thread-local parser cache: one Parser per language per thread.
15// Avoids creating a new Parser for every file during parallel graph builds.
16thread_local! {
17    static PARSER_CACHE: RefCell<HashMap<&'static str, tree_sitter::Parser>> = RefCell::new(HashMap::new());
18}
19
20impl SemanticParserPlugin for CodeParserPlugin {
21    fn id(&self) -> &str {
22        "code"
23    }
24
25    fn extensions(&self) -> &[&str] {
26        get_all_code_extensions()
27    }
28
29    fn extract_entities(&self, content: &str, file_path: &str) -> Vec<SemanticEntity> {
30        let ext = std::path::Path::new(file_path)
31            .extension()
32            .and_then(|e| e.to_str())
33            .map(|e| format!(".{}", e.to_lowercase()))
34            .unwrap_or_default();
35
36        let config = match get_language_config(&ext) {
37            Some(c) => c,
38            None => return Vec::new(),
39        };
40
41        let language = match (config.get_language)() {
42            Some(lang) => lang,
43            None => return Vec::new(),
44        };
45
46        PARSER_CACHE.with(|cache| {
47            let mut cache = cache.borrow_mut();
48            let parser = cache.entry(config.id).or_insert_with(|| {
49                let mut p = tree_sitter::Parser::new();
50                let _ = p.set_language(&language);
51                p
52            });
53
54            let tree = match parser.parse(content.as_bytes(), None) {
55                Some(t) => t,
56                None => return Vec::new(),
57            };
58
59            extract_entities(&tree, file_path, config, content)
60        })
61    }
62}
63
#[cfg(test)]
mod tests {
    //! Per-language smoke tests for `CodeParserPlugin`. Each test feeds a small
    //! snippet through the plugin and checks the extracted entity names. Where
    //! relevant, it also checks nesting (`parent_id`) and hashes.

    use super::*;

    /// Runs the code plugin over `code` as though it were the file at `path`.
    fn extract(code: &str, path: &str) -> Vec<SemanticEntity> {
        CodeParserPlugin.extract_entities(code, path)
    }

    /// Entity names, in the order the plugin returned them.
    fn names_of(entities: &[SemanticEntity]) -> Vec<&str> {
        entities.iter().map(|e| e.name.as_str()).collect()
    }

    /// `(name, entity_type)` pairs, used for the diagnostic dumps.
    fn name_type_pairs(entities: &[SemanticEntity]) -> Vec<(&str, &str)> {
        entities
            .iter()
            .map(|e| (e.name.as_str(), e.entity_type.as_str()))
            .collect()
    }

    /// First entity called `name`; panics (failing the test) if there is none.
    fn entity_named<'a>(entities: &'a [SemanticEntity], name: &str) -> &'a SemanticEntity {
        entities.iter().find(|e| e.name == name).unwrap()
    }

    /// Shared check for the nested-function tests. Both `outer` and `inner`
    /// must be extracted, and `inner` must carry a parent id.
    fn assert_inner_is_nested(code: &str, path: &str) {
        let entities = extract(code, path);
        let names = names_of(&entities);

        for wanted in ["outer", "inner"] {
            assert!(names.contains(&wanted), "got: {:?}", names);
        }

        let inner = entity_named(&entities, "inner");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    #[test]
    fn test_java_entity_extraction() {
        let source = r#"
package com.example;

import java.util.List;

public class UserService {
    private String name;

    public UserService(String name) {
        this.name = name;
    }

    public List<User> getUsers() {
        return db.findAll();
    }

    public void createUser(User user) {
        db.save(user);
    }
}

interface Repository<T> {
    T findById(String id);
    List<T> findAll();
}

enum Status {
    ACTIVE,
    INACTIVE,
    DELETED
}
"#;
        let entities = extract(source, "UserService.java");
        let names = names_of(&entities);
        eprintln!("Java entities: {:?}", name_type_pairs(&entities));

        for (wanted, label) in [
            ("UserService", "class UserService"),
            ("Repository", "interface Repository"),
            ("Status", "enum Status"),
        ] {
            assert!(names.contains(&wanted), "Should find {}, got: {:?}", label, names);
        }
    }

    #[test]
    fn test_java_nested_methods() {
        let source = r#"
public class Calculator {
    public int add(int a, int b) {
        return a + b;
    }

    public int subtract(int a, int b) {
        return a - b;
    }
}
"#;
        let entities = extract(source, "Calculator.java");
        let names = names_of(&entities);
        eprintln!(
            "Java nested: {:?}",
            entities
                .iter()
                .map(|e| (&e.name, &e.entity_type, &e.parent_id))
                .collect::<Vec<_>>()
        );

        assert!(names.contains(&"Calculator"), "Should find Calculator class");
        for method in ["add", "subtract"] {
            assert!(names.contains(&method), "Should find {} method, got: {:?}", method, names);
        }

        // Methods are expected to be nested under their class.
        let add = entity_named(&entities, "add");
        assert!(add.parent_id.is_some(), "add should have parent_id");
    }

    #[test]
    fn test_c_entity_extraction() {
        let source = r#"
#include <stdio.h>

struct Point {
    int x;
    int y;
};

enum Color {
    RED,
    GREEN,
    BLUE
};

typedef struct {
    char name[50];
    int age;
} Person;

void greet(const char* name) {
    printf("Hello, %s!\n", name);
}

int add(int a, int b) {
    return a + b;
}

int main() {
    greet("world");
    return 0;
}
"#;
        let entities = extract(source, "main.c");
        let names = names_of(&entities);
        eprintln!("C entities: {:?}", name_type_pairs(&entities));

        for (wanted, kind) in [
            ("greet", "function"),
            ("add", "function"),
            ("main", "function"),
            ("Point", "struct"),
            ("Color", "enum"),
        ] {
            assert!(names.contains(&wanted), "Should find {} {}, got: {:?}", wanted, kind, names);
        }
    }

    #[test]
    fn test_cpp_entity_extraction() {
        let source = "namespace math {\nclass Vector3 {\npublic:\n    float length() const { return 0; }\n};\n}\nvoid greet() {}\n";
        let entities = extract(source, "main.cpp");
        let names = names_of(&entities);
        for wanted in ["math", "Vector3", "greet"] {
            assert!(names.contains(&wanted), "got: {:?}", names);
        }
    }

    #[test]
    fn test_ruby_entity_extraction() {
        let source = "module Auth\n  class User\n    def greet\n      \"hi\"\n    end\n  end\nend\ndef helper(x)\n  x * 2\nend\n";
        let entities = extract(source, "auth.rb");
        let names = names_of(&entities);
        for wanted in ["Auth", "User", "helper"] {
            assert!(names.contains(&wanted), "got: {:?}", names);
        }
    }

    #[test]
    fn test_csharp_entity_extraction() {
        let source = "namespace MyApp {\npublic class User {\n    public string GetName() { return \"\"; }\n}\npublic enum Role { Admin, User }\n}\n";
        let entities = extract(source, "Models.cs");
        let names = names_of(&entities);
        for wanted in ["MyApp", "User", "Role"] {
            assert!(names.contains(&wanted), "got: {:?}", names);
        }
    }

    #[test]
    fn test_swift_entity_extraction() {
        let source = r#"
import Foundation

class UserService {
    var name: String

    init(name: String) {
        self.name = name
    }

    func getUsers() -> [User] {
        return db.findAll()
    }
}

struct Point {
    var x: Double
    var y: Double
}

enum Status {
    case active
    case inactive
    case deleted
}

protocol Repository {
    associatedtype Item
    func findById(id: String) -> Item?
    func findAll() -> [Item]
}

func helper(x: Int) -> Int {
    return x * 2
}
"#;
        let entities = extract(source, "UserService.swift");
        let names = names_of(&entities);
        eprintln!("Swift entities: {:?}", name_type_pairs(&entities));

        for (wanted, label) in [
            ("UserService", "class UserService"),
            ("Point", "struct Point"),
            ("Status", "enum Status"),
            ("Repository", "protocol Repository"),
            ("helper", "function helper"),
        ] {
            assert!(names.contains(&wanted), "Should find {}, got: {:?}", label, names);
        }
    }

    #[test]
    fn test_elixir_entity_extraction() {
        let source = r#"
defmodule MyApp.Accounts do
  def create_user(attrs) do
    %User{}
    |> User.changeset(attrs)
    |> Repo.insert()
  end

  defp validate(attrs) do
    # private helper
    :ok
  end

  defmacro is_admin(user) do
    quote do
      unquote(user).role == :admin
    end
  end

  defguard is_positive(x) when is_integer(x) and x > 0
end

defprotocol Printable do
  def to_string(data)
end

defimpl Printable, for: Integer do
  def to_string(i), do: Integer.to_string(i)
end
"#;
        let entities = extract(source, "accounts.ex");
        let names = names_of(&entities);
        eprintln!("Elixir entities: {:?}", name_type_pairs(&entities));

        for (wanted, construct) in [
            ("MyApp.Accounts", "module"),
            ("create_user", "def"),
            ("validate", "defp"),
            ("is_admin", "defmacro"),
            ("Printable", "defprotocol"),
        ] {
            assert!(names.contains(&wanted), "Should find {}, got: {:?}", construct, names);
        }

        // Functions defined inside `defmodule` should point back at the module.
        let create_user = entity_named(&entities, "create_user");
        assert!(create_user.parent_id.is_some(), "create_user should be nested under module");
    }

    #[test]
    fn test_bash_entity_extraction() {
        let source = r#"#!/bin/bash

greet() {
    echo "Hello, $1!"
}

function deploy {
    echo "deploying..."
}

# not a function
echo "main script"
"#;
        let entities = extract(source, "deploy.sh");
        let names = names_of(&entities);
        eprintln!("Bash entities: {:?}", name_type_pairs(&entities));

        assert!(names.contains(&"greet"), "Should find greet(), got: {:?}", names);
        assert!(names.contains(&"deploy"), "Should find function deploy, got: {:?}", names);
        // Top-level commands such as `echo` must not be reported as entities.
        assert_eq!(entities.len(), 2, "Should only find functions, got: {:?}", names);
    }

    #[test]
    fn test_typescript_entity_extraction() {
        // Regression guard: the original TypeScript support keeps working.
        let source = r#"
export function hello(): string {
    return "hello";
}

export class Greeter {
    greet(name: string): string {
        return `Hello, ${name}!`;
    }
}
"#;
        let found = names_of(&extract(source, "test.ts")).join("\u{0}");
        let _ = found; // names are re-derived below to keep the assertions simple
        let entities = extract(source, "test.ts");
        let names = names_of(&entities);
        assert!(names.contains(&"hello"), "Should find hello function");
        assert!(names.contains(&"Greeter"), "Should find Greeter class");
    }

    #[test]
    fn test_nested_functions_typescript() {
        let source = r#"
function outer() {
    function inner() {
        return 42;
    }
    return inner();
}
"#;
        let entities = extract(source, "nested.ts");
        let names = names_of(&entities);
        eprintln!(
            "Nested TS: {:?}",
            entities
                .iter()
                .map(|e| (&e.name, &e.entity_type, &e.parent_id))
                .collect::<Vec<_>>()
        );

        for wanted in ["outer", "inner"] {
            assert!(names.contains(&wanted), "Should find {}, got: {:?}", wanted, names);
        }

        let inner = entity_named(&entities, "inner");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    #[test]
    fn test_nested_functions_python() {
        assert_inner_is_nested(
            "def outer():\n    def inner():\n        return 42\n    return inner()\n",
            "nested.py",
        );
    }

    #[test]
    fn test_nested_functions_rust() {
        assert_inner_is_nested(
            "fn outer() {\n    fn inner() -> i32 {\n        42\n    }\n    inner();\n}\n",
            "nested.rs",
        );
    }

    #[test]
    fn test_rust_impl_blocks_unique_names() {
        let source = r#"
trait Greeting {
    fn greet(&self) -> String;
}

struct Person;
struct Robot;
struct Cat;

impl Greeting for Person {
    fn greet(&self) -> String { "Hello".to_string() }
}

impl Greeting for Robot {
    fn greet(&self) -> String { "Beep".to_string() }
}

impl Greeting for Cat {
    fn greet(&self) -> String { "Meow".to_string() }
}
"#;
        let entities = extract(source, "impls.rs");
        let impl_names: Vec<&str> = entities
            .iter()
            .filter(|e| e.entity_type == "impl")
            .map(|e| e.name.as_str())
            .collect();

        // Each `impl Trait for Type` must get a distinct, type-qualified name.
        assert_eq!(impl_names.len(), 3, "Should find 3 impl blocks, got: {:?}", impl_names);
        for wanted in ["Greeting for Person", "Greeting for Robot", "Greeting for Cat"] {
            assert!(impl_names.contains(&wanted), "got: {:?}", impl_names);
        }
    }

    #[test]
    fn test_nested_functions_go() {
        // Go has no named nested functions; this only checks that a function
        // containing a local var declaration is still extracted.
        let entities = extract("package main\n\nfunc outer() {\n    var x int = 42\n    _ = x\n}\n", "nested.go");
        let names = names_of(&entities);

        assert!(names.contains(&"outer"), "got: {:?}", names);
    }

    #[test]
    fn test_renamed_function_same_structural_hash() {
        let original = extract("def get_card():\n    return db.query('cards')\n", "a.py");
        let renamed = extract("def get_card_1():\n    return db.query('cards')\n", "b.py");

        assert_eq!(original.len(), 1, "Should find one entity in a");
        assert_eq!(renamed.len(), 1, "Should find one entity in b");

        let (before, after) = (&original[0], &renamed[0]);
        assert_eq!(before.name, "get_card");
        assert_eq!(after.name, "get_card_1");

        // Only the name differs, so the structural hash must match...
        assert_eq!(
            before.structural_hash, after.structural_hash,
            "Renamed function with identical body should have same structural_hash"
        );

        // ...while the content hash, which covers the raw text including the
        // name, must not.
        assert_ne!(
            before.content_hash, after.content_hash,
            "Content hash should differ since raw content includes the name"
        );
    }

    #[test]
    fn test_hcl_entity_extraction() {
        let source = r#"
region = "eu-west-1"

variable "image_id" {
  type = string
}

resource "aws_instance" "web" {
  ami = var.image_id

  lifecycle {
    create_before_destroy = true
  }
}
"#;
        let entities = extract(source, "main.tf");
        let names = names_of(&entities);
        eprintln!(
            "HCL entities: {:?}",
            entities
                .iter()
                .map(|e| (&e.name, &e.entity_type, &e.parent_id))
                .collect::<Vec<_>>()
        );

        assert!(names.contains(&"region"), "Should find top-level attribute, got: {:?}", names);
        assert!(names.contains(&"variable.image_id"), "Should find variable block, got: {:?}", names);
        assert!(names.contains(&"resource.aws_instance.web"), "Should find resource block, got: {:?}", names);
        assert!(
            names.contains(&"resource.aws_instance.web.lifecycle"),
            "Should find nested lifecycle block with qualified name, got: {:?}",
            names
        );

        // Attributes count as entities only at the top level, never inside blocks.
        for (nested_attr, scope) in [("ami", "blocks"), ("create_before_destroy", "nested blocks")] {
            assert!(
                !names.contains(&nested_attr),
                "Should skip nested attributes inside {}, got: {:?}",
                scope,
                names
            );
        }

        let lifecycle = entity_named(&entities, "resource.aws_instance.web.lifecycle");
        assert!(lifecycle.parent_id.is_some(), "lifecycle should be nested under resource");
        assert!(
            entities.iter().any(|e| e.entity_type == "attribute"),
            "Should preserve attribute entity type for top-level attributes"
        );
    }

    #[test]
    fn test_kotlin_entity_extraction() {
        let source = r#"
class UserService {
    val name: String = ""

    fun greet(): String {
        return "Hello, $name"
    }

    companion object {
        fun create(): UserService = UserService()
    }
}

interface Repository {
    fun findById(id: Int): Any?
}

object AppConfig {
    val version = "1.0"
}

fun topLevel(x: Int): Int = x * 2
"#;
        let entities = extract(source, "App.kt");
        let names = names_of(&entities);
        eprintln!("Kotlin entities: {:?}", name_type_pairs(&entities));
        for wanted in ["UserService", "greet", "Repository", "findById", "AppConfig", "topLevel"] {
            assert!(names.contains(&wanted), "got: {:?}", names);
        }
    }

    #[test]
    fn test_xml_entity_extraction() {
        let source = r#"<?xml version="1.0" encoding="UTF-8"?>
<project>
    <groupId>com.example</groupId>
    <artifactId>my-app</artifactId>
    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven</groupId>
            </plugin>
        </plugins>
    </build>
</project>
"#;
        let entities = extract(source, "pom.xml");
        let names = names_of(&entities);
        eprintln!("XML entities: {:?}", name_type_pairs(&entities));
        for wanted in ["project", "dependencies", "build"] {
            assert!(names.contains(&wanted), "got: {:?}", names);
        }
    }

    #[test]
    fn test_go_var_declaration() {
        let source = r#"package featuremgmt

type FeatureFlag struct {
	Name        string
	Description string
	Stage       string
}

var standardFeatureFlags = []FeatureFlag{
	{
		Name:        "panelTitleSearch",
		Description: "Search for dashboards using panel title",
		Stage:       "PublicPreview",
	},
}

func GetFlags() []FeatureFlag {
	return standardFeatureFlags
}
"#;
        let entities = extract(source, "flags.go");
        let names = names_of(&entities);
        eprintln!("Go entities: {:?}", name_type_pairs(&entities));

        for (wanted, label) in [
            ("FeatureFlag", "type FeatureFlag"),
            ("standardFeatureFlags", "var standardFeatureFlags"),
            ("GetFlags", "func GetFlags"),
        ] {
            assert!(names.contains(&wanted), "Should find {}, got: {:?}", label, names);
        }
    }

    #[test]
    fn test_go_grouped_var_declaration() {
        let source = r#"package test

var (
	simple = 42
	flags = []string{"a", "b"}
)

const (
	x = 1
	y = 2
)

func main() {}
"#;
        let entities = extract(source, "test.go");
        let names = names_of(&entities);
        eprintln!("Go grouped entities: {:?}", name_type_pairs(&entities));

        // At least one member of the grouped `var ( ... )` must surface.
        assert!(
            ["flags", "simple"].iter().any(|v| names.contains(v)),
            "Should find grouped var, got: {:?}",
            names
        );
        assert!(names.contains(&"x"), "Should find grouped const x, got: {:?}", names);
        assert!(names.contains(&"main"), "Should find func main, got: {:?}", names);
    }
}