Skip to main content

sem_core/parser/plugins/code/mod.rs

1mod entity_extractor;
2mod languages;
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::plugin::SemanticParserPlugin;
9use languages::{get_all_code_extensions, get_language_config};
10use entity_extractor::extract_entities;
11
/// Semantic parser plugin for source-code files.
///
/// The grammar is chosen from the file extension (via `get_language_config`), and entities
/// are extracted from the resulting tree-sitter syntax tree. The type carries no state of its
/// own; tree-sitter parsers are reused through the thread-local `PARSER_CACHE`.
#[derive(Debug, Default, Clone, Copy)]
pub struct CodeParserPlugin;
13
14// Thread-local parser cache: one Parser per language per thread.
15// Avoids creating a new Parser for every file during parallel graph builds.
16thread_local! {
17    static PARSER_CACHE: RefCell<HashMap<&'static str, tree_sitter::Parser>> = RefCell::new(HashMap::new());
18}
19
20impl SemanticParserPlugin for CodeParserPlugin {
21    fn id(&self) -> &str {
22        "code"
23    }
24
25    fn extensions(&self) -> &[&str] {
26        get_all_code_extensions()
27    }
28
29    fn extract_entities(&self, content: &str, file_path: &str) -> Vec<SemanticEntity> {
30        let ext = std::path::Path::new(file_path)
31            .extension()
32            .and_then(|e| e.to_str())
33            .map(|e| format!(".{}", e.to_lowercase()))
34            .unwrap_or_default();
35
36        let config = match get_language_config(&ext) {
37            Some(c) => c,
38            None => return Vec::new(),
39        };
40
41        let language = match (config.get_language)() {
42            Some(lang) => lang,
43            None => return Vec::new(),
44        };
45
46        PARSER_CACHE.with(|cache| {
47            let mut cache = cache.borrow_mut();
48            let parser = cache.entry(config.id).or_insert_with(|| {
49                let mut p = tree_sitter::Parser::new();
50                let _ = p.set_language(&language);
51                p
52            });
53
54            let tree = match parser.parse(content.as_bytes(), None) {
55                Some(t) => t,
56                None => return Vec::new(),
57            };
58
59            extract_entities(&tree, file_path, config, content)
60        })
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the code plugin over `code` as though it were stored at `file_path`.
    fn extract(code: &str, file_path: &str) -> Vec<SemanticEntity> {
        CodeParserPlugin.extract_entities(code, file_path)
    }

    /// Entity names, in extraction order.
    fn names_of(entities: &[SemanticEntity]) -> Vec<&str> {
        entities.iter().map(|e| e.name.as_str()).collect()
    }

    /// Entity types, in extraction order (index-aligned with `names_of`).
    fn types_of(entities: &[SemanticEntity]) -> Vec<&str> {
        entities.iter().map(|e| e.entity_type.as_str()).collect()
    }

    /// Finds the entity with exactly `name`. If it is missing, the test fails with the list
    /// of names that were found, rather than with a bare `unwrap` panic.
    fn find_entity<'a>(entities: &'a [SemanticEntity], name: &str) -> &'a SemanticEntity {
        entities
            .iter()
            .find(|e| e.name == name)
            .unwrap_or_else(|| panic!("no entity named {:?}, got: {:?}", name, names_of(entities)))
    }

    #[test]
    fn test_java_entity_extraction() {
        let code = r#"
package com.example;

import java.util.List;

public class UserService {
    private String name;

    public UserService(String name) {
        this.name = name;
    }

    public List<User> getUsers() {
        return db.findAll();
    }

    public void createUser(User user) {
        db.save(user);
    }
}

interface Repository<T> {
    T findById(String id);
    List<T> findAll();
}

enum Status {
    ACTIVE,
    INACTIVE,
    DELETED
}
"#;
        let entities = extract(code, "UserService.java");
        let names = names_of(&entities);
        let types = types_of(&entities);
        eprintln!("Java entities: {:?}", names.iter().zip(types.iter()).collect::<Vec<_>>());

        assert!(names.contains(&"UserService"), "Should find class UserService, got: {:?}", names);
        assert!(names.contains(&"Repository"), "Should find interface Repository, got: {:?}", names);
        assert!(names.contains(&"Status"), "Should find enum Status, got: {:?}", names);
    }

    #[test]
    fn test_java_nested_methods() {
        let code = r#"
public class Calculator {
    public int add(int a, int b) {
        return a + b;
    }

    public int subtract(int a, int b) {
        return a - b;
    }
}
"#;
        let entities = extract(code, "Calculator.java");
        let names = names_of(&entities);
        eprintln!("Java nested: {:?}", entities.iter().map(|e| (&e.name, &e.entity_type, &e.parent_id)).collect::<Vec<_>>());

        assert!(names.contains(&"Calculator"), "Should find Calculator class");
        assert!(names.contains(&"add"), "Should find add method, got: {:?}", names);
        assert!(names.contains(&"subtract"), "Should find subtract method, got: {:?}", names);

        // Methods declared inside Calculator must be attached to some parent entity.
        let add = find_entity(&entities, "add");
        assert!(add.parent_id.is_some(), "add should have parent_id");
    }

    #[test]
    fn test_c_entity_extraction() {
        let code = r#"
#include <stdio.h>

struct Point {
    int x;
    int y;
};

enum Color {
    RED,
    GREEN,
    BLUE
};

typedef struct {
    char name[50];
    int age;
} Person;

void greet(const char* name) {
    printf("Hello, %s!\n", name);
}

int add(int a, int b) {
    return a + b;
}

int main() {
    greet("world");
    return 0;
}
"#;
        let entities = extract(code, "main.c");
        let names = names_of(&entities);
        let types = types_of(&entities);
        eprintln!("C entities: {:?}", names.iter().zip(types.iter()).collect::<Vec<_>>());

        assert!(names.contains(&"greet"), "Should find greet function, got: {:?}", names);
        assert!(names.contains(&"add"), "Should find add function, got: {:?}", names);
        assert!(names.contains(&"main"), "Should find main function, got: {:?}", names);
        assert!(names.contains(&"Point"), "Should find Point struct, got: {:?}", names);
        assert!(names.contains(&"Color"), "Should find Color enum, got: {:?}", names);
    }

    #[test]
    fn test_cpp_entity_extraction() {
        let code = "namespace math {\nclass Vector3 {\npublic:\n    float length() const { return 0; }\n};\n}\nvoid greet() {}\n";
        let entities = extract(code, "main.cpp");
        let names = names_of(&entities);
        assert!(names.contains(&"math"), "got: {:?}", names);
        assert!(names.contains(&"Vector3"), "got: {:?}", names);
        assert!(names.contains(&"greet"), "got: {:?}", names);
    }

    #[test]
    fn test_ruby_entity_extraction() {
        let code = "module Auth\n  class User\n    def greet\n      \"hi\"\n    end\n  end\nend\ndef helper(x)\n  x * 2\nend\n";
        let entities = extract(code, "auth.rb");
        let names = names_of(&entities);
        assert!(names.contains(&"Auth"), "got: {:?}", names);
        assert!(names.contains(&"User"), "got: {:?}", names);
        assert!(names.contains(&"helper"), "got: {:?}", names);
    }

    #[test]
    fn test_csharp_entity_extraction() {
        let code = "namespace MyApp {\npublic class User {\n    public string GetName() { return \"\"; }\n}\npublic enum Role { Admin, User }\n}\n";
        let entities = extract(code, "Models.cs");
        let names = names_of(&entities);
        assert!(names.contains(&"MyApp"), "got: {:?}", names);
        assert!(names.contains(&"User"), "got: {:?}", names);
        assert!(names.contains(&"Role"), "got: {:?}", names);
    }

    #[test]
    fn test_swift_entity_extraction() {
        let code = r#"
import Foundation

class UserService {
    var name: String

    init(name: String) {
        self.name = name
    }

    func getUsers() -> [User] {
        return db.findAll()
    }
}

struct Point {
    var x: Double
    var y: Double
}

enum Status {
    case active
    case inactive
    case deleted
}

protocol Repository {
    associatedtype Item
    func findById(id: String) -> Item?
    func findAll() -> [Item]
}

func helper(x: Int) -> Int {
    return x * 2
}
"#;
        let entities = extract(code, "UserService.swift");
        let names = names_of(&entities);
        eprintln!("Swift entities: {:?}", entities.iter().map(|e| (&e.name, &e.entity_type)).collect::<Vec<_>>());

        assert!(names.contains(&"UserService"), "Should find class UserService, got: {:?}", names);
        assert!(names.contains(&"Point"), "Should find struct Point, got: {:?}", names);
        assert!(names.contains(&"Status"), "Should find enum Status, got: {:?}", names);
        assert!(names.contains(&"Repository"), "Should find protocol Repository, got: {:?}", names);
        assert!(names.contains(&"helper"), "Should find function helper, got: {:?}", names);
    }

    #[test]
    fn test_elixir_entity_extraction() {
        let code = r#"
defmodule MyApp.Accounts do
  def create_user(attrs) do
    %User{}
    |> User.changeset(attrs)
    |> Repo.insert()
  end

  defp validate(attrs) do
    # private helper
    :ok
  end

  defmacro is_admin(user) do
    quote do
      unquote(user).role == :admin
    end
  end

  defguard is_positive(x) when is_integer(x) and x > 0
end

defprotocol Printable do
  def to_string(data)
end

defimpl Printable, for: Integer do
  def to_string(i), do: Integer.to_string(i)
end
"#;
        let entities = extract(code, "accounts.ex");
        let names = names_of(&entities);
        let types = types_of(&entities);
        eprintln!("Elixir entities: {:?}", names.iter().zip(types.iter()).collect::<Vec<_>>());

        assert!(names.contains(&"MyApp.Accounts"), "Should find module, got: {:?}", names);
        assert!(names.contains(&"create_user"), "Should find def, got: {:?}", names);
        assert!(names.contains(&"validate"), "Should find defp, got: {:?}", names);
        assert!(names.contains(&"is_admin"), "Should find defmacro, got: {:?}", names);
        assert!(names.contains(&"Printable"), "Should find defprotocol, got: {:?}", names);

        // Verify nesting: create_user should have MyApp.Accounts as parent
        let create_user = find_entity(&entities, "create_user");
        assert!(create_user.parent_id.is_some(), "create_user should be nested under module");
    }

    #[test]
    fn test_bash_entity_extraction() {
        let code = r#"#!/bin/bash

greet() {
    echo "Hello, $1!"
}

function deploy {
    echo "deploying..."
}

# not a function
echo "main script"
"#;
        let entities = extract(code, "deploy.sh");
        let names = names_of(&entities);
        let types = types_of(&entities);
        eprintln!("Bash entities: {:?}", names.iter().zip(types.iter()).collect::<Vec<_>>());

        assert!(names.contains(&"greet"), "Should find greet(), got: {:?}", names);
        assert!(names.contains(&"deploy"), "Should find function deploy, got: {:?}", names);
        assert_eq!(entities.len(), 2, "Should only find functions, got: {:?}", names);
    }

    #[test]
    fn test_typescript_entity_extraction() {
        // Existing language should still work
        let code = r#"
export function hello(): string {
    return "hello";
}

export class Greeter {
    greet(name: string): string {
        return `Hello, ${name}!`;
    }
}
"#;
        let entities = extract(code, "test.ts");
        let names = names_of(&entities);
        assert!(names.contains(&"hello"), "Should find hello function");
        assert!(names.contains(&"Greeter"), "Should find Greeter class");
    }

    #[test]
    fn test_nested_functions_typescript() {
        let code = r#"
function outer() {
    function inner() {
        return 42;
    }
    return inner();
}
"#;
        let entities = extract(code, "nested.ts");
        let names = names_of(&entities);
        eprintln!("Nested TS: {:?}", entities.iter().map(|e| (&e.name, &e.entity_type, &e.parent_id)).collect::<Vec<_>>());

        assert!(names.contains(&"outer"), "Should find outer, got: {:?}", names);
        assert!(names.contains(&"inner"), "Should find inner, got: {:?}", names);

        let inner = find_entity(&entities, "inner");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    #[test]
    fn test_nested_functions_python() {
        let code = "def outer():\n    def inner():\n        return 42\n    return inner()\n";
        let entities = extract(code, "nested.py");
        let names = names_of(&entities);

        assert!(names.contains(&"outer"), "got: {:?}", names);
        assert!(names.contains(&"inner"), "got: {:?}", names);

        let inner = find_entity(&entities, "inner");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    #[test]
    fn test_nested_functions_rust() {
        let code = "fn outer() {\n    fn inner() -> i32 {\n        42\n    }\n    inner();\n}\n";
        let entities = extract(code, "nested.rs");
        let names = names_of(&entities);

        assert!(names.contains(&"outer"), "got: {:?}", names);
        assert!(names.contains(&"inner"), "got: {:?}", names);

        let inner = find_entity(&entities, "inner");
        assert!(inner.parent_id.is_some(), "inner should have parent_id");
    }

    #[test]
    fn test_nested_functions_go() {
        // Go has no named nested functions. Only check that a function whose body holds
        // local declarations is still extracted.
        let code = "package main\n\nfunc outer() {\n    var x int = 42\n    _ = x\n}\n";
        let entities = extract(code, "nested.go");
        let names = names_of(&entities);

        assert!(names.contains(&"outer"), "got: {:?}", names);
    }

    #[test]
    fn test_renamed_function_same_structural_hash() {
        let code_a = "def get_card():\n    return db.query('cards')\n";
        let code_b = "def get_card_1():\n    return db.query('cards')\n";

        let entities_a = extract(code_a, "a.py");
        let entities_b = extract(code_b, "b.py");

        assert_eq!(entities_a.len(), 1, "Should find one entity in a");
        assert_eq!(entities_b.len(), 1, "Should find one entity in b");
        assert_eq!(entities_a[0].name, "get_card");
        assert_eq!(entities_b[0].name, "get_card_1");

        // Structural hash should match since only the name differs
        assert_eq!(
            entities_a[0].structural_hash, entities_b[0].structural_hash,
            "Renamed function with identical body should have same structural_hash"
        );

        // Content hash should differ (it includes the name)
        assert_ne!(
            entities_a[0].content_hash, entities_b[0].content_hash,
            "Content hash should differ since raw content includes the name"
        );
    }

    #[test]
    fn test_hcl_entity_extraction() {
        let code = r#"
region = "eu-west-1"

variable "image_id" {
  type = string
}

resource "aws_instance" "web" {
  ami = var.image_id

  lifecycle {
    create_before_destroy = true
  }
}
"#;
        let entities = extract(code, "main.tf");
        let names = names_of(&entities);
        let types = types_of(&entities);
        eprintln!("HCL entities: {:?}", entities.iter().map(|e| (&e.name, &e.entity_type, &e.parent_id)).collect::<Vec<_>>());

        assert!(names.contains(&"region"), "Should find top-level attribute, got: {:?}", names);
        assert!(names.contains(&"variable.image_id"), "Should find variable block, got: {:?}", names);
        assert!(names.contains(&"resource.aws_instance.web"), "Should find resource block, got: {:?}", names);
        assert!(
            names.contains(&"resource.aws_instance.web.lifecycle"),
            "Should find nested lifecycle block with qualified name, got: {:?}",
            names
        );
        assert!(!names.contains(&"ami"), "Should skip nested attributes inside blocks, got: {:?}", names);
        assert!(
            !names.contains(&"create_before_destroy"),
            "Should skip nested attributes inside nested blocks, got: {:?}",
            names
        );

        let lifecycle = find_entity(&entities, "resource.aws_instance.web.lifecycle");
        assert!(lifecycle.parent_id.is_some(), "lifecycle should be nested under resource");
        assert!(types.contains(&"attribute"), "Should preserve attribute entity type for top-level attributes");
    }

    #[test]
    fn test_kotlin_entity_extraction() {
        let code = r#"
class UserService {
    val name: String = ""

    fun greet(): String {
        return "Hello, $name"
    }

    companion object {
        fun create(): UserService = UserService()
    }
}

interface Repository {
    fun findById(id: Int): Any?
}

object AppConfig {
    val version = "1.0"
}

fun topLevel(x: Int): Int = x * 2
"#;
        let entities = extract(code, "App.kt");
        let names = names_of(&entities);
        eprintln!("Kotlin entities: {:?}", entities.iter().map(|e| (&e.name, &e.entity_type)).collect::<Vec<_>>());
        assert!(names.contains(&"UserService"), "got: {:?}", names);
        assert!(names.contains(&"greet"), "got: {:?}", names);
        assert!(names.contains(&"Repository"), "got: {:?}", names);
        assert!(names.contains(&"findById"), "got: {:?}", names);
        assert!(names.contains(&"AppConfig"), "got: {:?}", names);
        assert!(names.contains(&"topLevel"), "got: {:?}", names);
    }

    #[test]
    fn test_xml_entity_extraction() {
        let code = r#"<?xml version="1.0" encoding="UTF-8"?>
<project>
    <groupId>com.example</groupId>
    <artifactId>my-app</artifactId>
    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven</groupId>
            </plugin>
        </plugins>
    </build>
</project>
"#;
        let entities = extract(code, "pom.xml");
        let names = names_of(&entities);
        eprintln!("XML entities: {:?}", entities.iter().map(|e| (&e.name, &e.entity_type)).collect::<Vec<_>>());
        assert!(names.contains(&"project"), "got: {:?}", names);
        assert!(names.contains(&"dependencies"), "got: {:?}", names);
        assert!(names.contains(&"build"), "got: {:?}", names);
    }
}