Skip to main content

tsift_graph/
lang.rs

1use anyhow::Result;
2use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator};
3
/// A named definition extracted from a source file by `Lang::extract_symbols`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Symbol {
    /// Text of the captured name node.
    pub name: String,
    /// Symbol category: the query capture name with its `.name` suffix
    /// removed (e.g. `function`, `struct`), or `alias` for Bash aliases.
    pub kind: String,
    /// Row on which the symbol's name starts, as reported by tree-sitter.
    pub line: usize,
    /// Row on which the enclosing node ends, as reported by tree-sitter.
    pub end_line: usize,
}
11
/// Languages that can be parsed; each variant exists only when its
/// `lang-*` Cargo feature is enabled.
// NOTE(review): `dead_code` is presumably allowed because some variants go
// unused under certain feature combinations — confirm.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Lang {
    /// Rust (`lang-rust`).
    #[cfg(feature = "lang-rust")]
    Rust,
    /// Python (`lang-python`).
    #[cfg(feature = "lang-python")]
    Python,
    /// TypeScript (`lang-typescript`).
    #[cfg(feature = "lang-typescript")]
    TypeScript,
    /// TypeScript with JSX (`lang-typescript`).
    #[cfg(feature = "lang-typescript")]
    Tsx,
    /// JavaScript (`lang-javascript`).
    #[cfg(feature = "lang-javascript")]
    JavaScript,
    /// JavaScript with JSX (`lang-javascript`).
    #[cfg(feature = "lang-javascript")]
    Jsx,
    /// Kotlin (`lang-kotlin`).
    #[cfg(feature = "lang-kotlin")]
    Kotlin,
    /// Zig (`lang-zig`).
    #[cfg(feature = "lang-zig")]
    Zig,
    /// Bash-family shell scripts (`lang-bash`).
    #[cfg(feature = "lang-bash")]
    Bash,
    /// Markdown (`lang-markdown`).
    #[cfg(feature = "lang-markdown")]
    Markdown,
}
36
37#[allow(dead_code)]
38impl Lang {
39    pub fn from_extension(ext: &str) -> Option<Self> {
40        match ext {
41            #[cfg(feature = "lang-rust")]
42            "rs" => Some(Self::Rust),
43            #[cfg(feature = "lang-python")]
44            "py" | "pyi" => Some(Self::Python),
45            #[cfg(feature = "lang-typescript")]
46            "ts" => Some(Self::TypeScript),
47            #[cfg(feature = "lang-typescript")]
48            "tsx" => Some(Self::Tsx),
49            #[cfg(feature = "lang-javascript")]
50            "js" | "mjs" | "cjs" => Some(Self::JavaScript),
51            #[cfg(feature = "lang-javascript")]
52            "jsx" => Some(Self::Jsx),
53            #[cfg(feature = "lang-kotlin")]
54            "kt" | "kts" => Some(Self::Kotlin),
55            #[cfg(feature = "lang-zig")]
56            "zig" => Some(Self::Zig),
57            #[cfg(feature = "lang-bash")]
58            "sh" | "bash" | "zsh" => Some(Self::Bash),
59            #[cfg(feature = "lang-markdown")]
60            "md" | "mdx" => Some(Self::Markdown),
61            _ => None,
62        }
63    }
64
65    pub fn tree_sitter_language(&self) -> Language {
66        match self {
67            #[cfg(feature = "lang-rust")]
68            Self::Rust => tree_sitter_rust::LANGUAGE.into(),
69            #[cfg(feature = "lang-python")]
70            Self::Python => tree_sitter_python::LANGUAGE.into(),
71            #[cfg(feature = "lang-typescript")]
72            Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
73            #[cfg(feature = "lang-typescript")]
74            Self::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
75            #[cfg(feature = "lang-javascript")]
76            Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
77            #[cfg(feature = "lang-javascript")]
78            Self::Jsx => tree_sitter_javascript::LANGUAGE.into(),
79            #[cfg(feature = "lang-kotlin")]
80            Self::Kotlin => tree_sitter_kotlin_ng::LANGUAGE.into(),
81            #[cfg(feature = "lang-zig")]
82            Self::Zig => tree_sitter_zig::LANGUAGE.into(),
83            #[cfg(feature = "lang-bash")]
84            Self::Bash => tree_sitter_bash::LANGUAGE.into(),
85            #[cfg(feature = "lang-markdown")]
86            Self::Markdown => tree_sitter_md::LANGUAGE.into(),
87        }
88    }
89
90    pub fn name(&self) -> &'static str {
91        match self {
92            #[cfg(feature = "lang-rust")]
93            Self::Rust => "rust",
94            #[cfg(feature = "lang-python")]
95            Self::Python => "python",
96            #[cfg(feature = "lang-typescript")]
97            Self::TypeScript => "typescript",
98            #[cfg(feature = "lang-typescript")]
99            Self::Tsx => "tsx",
100            #[cfg(feature = "lang-javascript")]
101            Self::JavaScript => "javascript",
102            #[cfg(feature = "lang-javascript")]
103            Self::Jsx => "jsx",
104            #[cfg(feature = "lang-kotlin")]
105            Self::Kotlin => "kotlin",
106            #[cfg(feature = "lang-zig")]
107            Self::Zig => "zig",
108            #[cfg(feature = "lang-bash")]
109            Self::Bash => "bash",
110            #[cfg(feature = "lang-markdown")]
111            Self::Markdown => "markdown",
112        }
113    }
114
115    pub fn symbol_query(&self) -> &'static str {
116        match self {
117            #[cfg(feature = "lang-rust")]
118            Self::Rust => {
119                r#"
120                (function_item name: (identifier) @function.name)
121                (struct_item name: (type_identifier) @struct.name)
122                (enum_item name: (type_identifier) @enum.name)
123                (trait_item name: (type_identifier) @trait.name)
124                (impl_item type: (type_identifier) @impl.name)
125                (mod_item name: (identifier) @mod.name)
126                (type_item name: (type_identifier) @type_alias.name)
127                (const_item name: (identifier) @const.name)
128                (static_item name: (identifier) @static.name)
129            "#
130            }
131            #[cfg(feature = "lang-python")]
132            Self::Python => {
133                r#"
134                (function_definition name: (identifier) @function.name)
135                (class_definition name: (identifier) @class.name)
136            "#
137            }
138            #[cfg(feature = "lang-typescript")]
139            Self::TypeScript | Self::Tsx => {
140                r#"
141                (function_declaration name: (identifier) @function.name)
142                (class_declaration name: (type_identifier) @class.name)
143                (interface_declaration name: (type_identifier) @interface.name)
144                (type_alias_declaration name: (type_identifier) @type_alias.name)
145                (enum_declaration name: (identifier) @enum.name)
146                (variable_declarator name: (identifier) @function.name value: (arrow_function))
147            "#
148            }
149            #[cfg(feature = "lang-javascript")]
150            Self::JavaScript | Self::Jsx => {
151                r#"
152                (function_declaration name: (identifier) @function.name)
153                (class_declaration name: (identifier) @class.name)
154                (variable_declarator name: (identifier) @function.name value: (arrow_function))
155            "#
156            }
157            #[cfg(feature = "lang-kotlin")]
158            Self::Kotlin => {
159                r#"
160                (function_declaration name: (identifier) @function.name)
161                (class_declaration "interface" name: (identifier) @interface.name)
162                (class_declaration (modifiers (class_modifier "data")) name: (identifier) @data_class.name)
163                (class_declaration (modifiers (class_modifier "sealed")) name: (identifier) @sealed_class.name)
164                (class_declaration (modifiers (class_modifier "enum")) name: (identifier) @enum_class.name)
165                (class_declaration "class" name: (identifier) @class.name)
166                (object_declaration name: (identifier) @object.name)
167                (companion_object name: (identifier) @companion_object.name)
168            "#
169            }
170            #[cfg(feature = "lang-zig")]
171            Self::Zig => {
172                r#"
173                (function_declaration (identifier) @function.name)
174                (variable_declaration (identifier) @struct.name (struct_declaration))
175                (variable_declaration (identifier) @enum.name (enum_declaration))
176                (variable_declaration (identifier) @union.name (union_declaration))
177                (variable_declaration (identifier) @const.name)
178            "#
179            }
180            #[cfg(feature = "lang-bash")]
181            Self::Bash => {
182                r#"
183                (function_definition name: (word) @function.name)
184            "#
185            }
186            #[cfg(feature = "lang-markdown")]
187            Self::Markdown => {
188                r#"
189                (atx_heading (atx_h1_marker) (inline) @heading.name)
190                (atx_heading (atx_h2_marker) (inline) @heading.name)
191                (atx_heading (atx_h3_marker) (inline) @heading.name)
192                (atx_heading (atx_h4_marker) (inline) @heading.name)
193                (atx_heading (atx_h5_marker) (inline) @heading.name)
194                (atx_heading (atx_h6_marker) (inline) @heading.name)
195                (fenced_code_block (info_string (language) @code_block.name))
196            "#
197            }
198        }
199    }
200
201    pub fn call_query(&self) -> Option<&'static str> {
202        match self {
203            #[cfg(feature = "lang-rust")]
204            Self::Rust => Some(
205                r#"
206                (call_expression function: (identifier) @call.name)
207                (call_expression function: (field_expression field: (field_identifier) @call.name))
208                (call_expression function: (scoped_identifier name: (identifier) @call.name))
209                (macro_invocation macro: (identifier) @call.name)
210            "#,
211            ),
212            #[cfg(feature = "lang-python")]
213            Self::Python => Some(
214                r#"
215                (call function: (identifier) @call.name)
216                (call function: (attribute attribute: (identifier) @call.name))
217            "#,
218            ),
219            #[cfg(feature = "lang-typescript")]
220            Self::TypeScript | Self::Tsx => Some(
221                r#"
222                (call_expression function: (identifier) @call.name)
223                (call_expression function: (member_expression property: (property_identifier) @call.name))
224            "#,
225            ),
226            #[cfg(feature = "lang-javascript")]
227            Self::JavaScript | Self::Jsx => Some(
228                r#"
229                (call_expression function: (identifier) @call.name)
230                (call_expression function: (member_expression property: (property_identifier) @call.name))
231            "#,
232            ),
233            #[cfg(feature = "lang-kotlin")]
234            Self::Kotlin => Some(
235                r#"
236                (call_expression (simple_identifier) @call.name)
237            "#,
238            ),
239            _ => None,
240        }
241    }
242
243    pub fn extract_symbols(&self, source: &[u8]) -> Result<Vec<Symbol>> {
244        let mut parser = Parser::new();
245        let ts_lang = self.tree_sitter_language();
246        parser.set_language(&ts_lang)?;
247        let tree = parser
248            .parse(source, None)
249            .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
250        let query = Query::new(&ts_lang, self.symbol_query())?;
251        let mut cursor = QueryCursor::new();
252        let mut symbols = Vec::new();
253        let capture_names: Vec<String> = query
254            .capture_names()
255            .iter()
256            .map(|s| s.to_string())
257            .collect();
258
259        let mut matches = cursor.matches(&query, tree.root_node(), source);
260        while let Some(m) = matches.next() {
261            for capture in m.captures {
262                let capture_name = &capture_names[capture.index as usize];
263                if let Some(kind_str) = capture_name.strip_suffix(".name") {
264                    let name = capture
265                        .node
266                        .utf8_text(source)
267                        .unwrap_or("<invalid utf8>")
268                        .to_string();
269                    let parent_end = capture
270                        .node
271                        .parent()
272                        .map(|p| p.end_position().row)
273                        .unwrap_or(capture.node.end_position().row);
274                    symbols.push(Symbol {
275                        name,
276                        kind: kind_str.to_string(),
277                        line: capture.node.start_position().row,
278                        end_line: parent_end,
279                    });
280                }
281            }
282        }
283
284        #[cfg(feature = "lang-bash")]
285        if *self == Self::Bash {
286            Self::extract_bash_aliases(&tree, source, &mut symbols);
287        }
288        symbols.sort_by(|a, b| a.line.cmp(&b.line).then(a.name.cmp(&b.name)));
289        symbols.dedup_by(|b, a| {
290            a.name == b.name && a.line == b.line && {
291                let a_generic = matches!(a.kind.as_str(), "variable" | "const");
292                let b_generic = matches!(b.kind.as_str(), "variable" | "const");
293                match (a_generic, b_generic) {
294                    (true, false) => a.kind.clone_from(&b.kind),
295                    (false, true) => {}
296                    _ => {
297                        if b.kind.len() > a.kind.len() {
298                            a.kind.clone_from(&b.kind);
299                        }
300                    }
301                }
302                true
303            }
304        });
305        Ok(symbols)
306    }
307
308    #[cfg(feature = "lang-bash")]
309    fn extract_bash_aliases(tree: &tree_sitter::Tree, source: &[u8], symbols: &mut Vec<Symbol>) {
310        let mut tree_cursor = tree.root_node().walk();
311        if !tree_cursor.goto_first_child() {
312            return;
313        }
314        loop {
315            let node = tree_cursor.node();
316            if node.kind() == "command"
317                && let Some(name_node) = node.child_by_field_name("name")
318            {
319                let cmd = name_node.utf8_text(source).unwrap_or("");
320                if cmd == "alias" {
321                    for i in 0..node.named_child_count() {
322                        if let Some(arg) = node.named_child(i as u32)
323                            && (arg.kind() == "concatenation" || arg.kind() == "word")
324                        {
325                            let text = arg.utf8_text(source).unwrap_or("");
326                            if let Some(alias_name) = text.split('=').next()
327                                && !alias_name.is_empty()
328                                && alias_name != cmd
329                            {
330                                symbols.push(Symbol {
331                                    name: alias_name.to_string(),
332                                    kind: "alias".to_string(),
333                                    line: arg.start_position().row,
334                                    end_line: node.end_position().row,
335                                });
336                            }
337                        }
338                    }
339                }
340            }
341            if !tree_cursor.goto_next_sibling() {
342                break;
343            }
344        }
345    }
346
347    pub fn all() -> Vec<Self> {
348        vec![
349            #[cfg(feature = "lang-rust")]
350            Self::Rust,
351            #[cfg(feature = "lang-python")]
352            Self::Python,
353            #[cfg(feature = "lang-typescript")]
354            Self::TypeScript,
355            #[cfg(feature = "lang-typescript")]
356            Self::Tsx,
357            #[cfg(feature = "lang-javascript")]
358            Self::JavaScript,
359            #[cfg(feature = "lang-javascript")]
360            Self::Jsx,
361            #[cfg(feature = "lang-kotlin")]
362            Self::Kotlin,
363            #[cfg(feature = "lang-zig")]
364            Self::Zig,
365            #[cfg(feature = "lang-bash")]
366            Self::Bash,
367            #[cfg(feature = "lang-markdown")]
368            Self::Markdown,
369        ]
370    }
371}
372
373#[cfg(test)]
374mod tests {
375    use super::*;
376
377    #[test]
378    fn test_all_grammars_create_parser() {
379        for lang in Lang::all() {
380            let ts_lang = lang.tree_sitter_language();
381            let mut parser = tree_sitter::Parser::new();
382            parser
383                .set_language(&ts_lang)
384                .unwrap_or_else(|e| panic!("failed to set language for {:?}: {}", lang, e));
385        }
386    }
387
388    #[test]
389    fn test_extension_dispatch() {
390        let cases = [
391            ("rs", "rust"),
392            ("py", "python"),
393            ("pyi", "python"),
394            ("ts", "typescript"),
395            ("tsx", "tsx"),
396            ("js", "javascript"),
397            ("mjs", "javascript"),
398            ("cjs", "javascript"),
399            ("jsx", "jsx"),
400            ("kt", "kotlin"),
401            ("kts", "kotlin"),
402            ("zig", "zig"),
403            ("sh", "bash"),
404            ("bash", "bash"),
405            ("zsh", "bash"),
406            ("md", "markdown"),
407            ("mdx", "markdown"),
408        ];
409        for (ext, expected_name) in cases {
410            let lang = Lang::from_extension(ext)
411                .unwrap_or_else(|| panic!("no language for extension: {ext}"));
412            assert_eq!(lang.name(), expected_name, "wrong language for .{ext}");
413        }
414    }
415
416    #[test]
417    fn test_unknown_extension_returns_none() {
418        assert!(Lang::from_extension("xyz").is_none());
419        assert!(Lang::from_extension("").is_none());
420        assert!(Lang::from_extension("txt").is_none());
421    }
422
423    #[cfg(feature = "lang-rust")]
424    #[test]
425    fn test_parse_rust_snippet() {
426        let lang = Lang::Rust;
427        let mut parser = tree_sitter::Parser::new();
428        parser.set_language(&lang.tree_sitter_language()).unwrap();
429        let tree = parser.parse("fn main() {}", None).unwrap();
430        assert_eq!(tree.root_node().kind(), "source_file");
431        assert!(!tree.root_node().has_error());
432    }
433
434    #[cfg(feature = "lang-python")]
435    #[test]
436    fn test_parse_python_snippet() {
437        let lang = Lang::Python;
438        let mut parser = tree_sitter::Parser::new();
439        parser.set_language(&lang.tree_sitter_language()).unwrap();
440        let tree = parser.parse("def hello():\n    pass\n", None).unwrap();
441        assert_eq!(tree.root_node().kind(), "module");
442        assert!(!tree.root_node().has_error());
443    }
444
445    #[cfg(feature = "lang-typescript")]
446    #[test]
447    fn test_parse_typescript_snippet() {
448        let lang = Lang::TypeScript;
449        let mut parser = tree_sitter::Parser::new();
450        parser.set_language(&lang.tree_sitter_language()).unwrap();
451        let tree = parser
452            .parse("function greet(name: string): void {}", None)
453            .unwrap();
454        assert_eq!(tree.root_node().kind(), "program");
455        assert!(!tree.root_node().has_error());
456    }
457
458    #[cfg(feature = "lang-typescript")]
459    #[test]
460    fn test_parse_tsx_snippet() {
461        let lang = Lang::Tsx;
462        let mut parser = tree_sitter::Parser::new();
463        parser.set_language(&lang.tree_sitter_language()).unwrap();
464        let tree = parser
465            .parse("const App = () => <div>hello</div>;", None)
466            .unwrap();
467        assert_eq!(tree.root_node().kind(), "program");
468        assert!(!tree.root_node().has_error());
469    }
470
471    #[cfg(feature = "lang-javascript")]
472    #[test]
473    fn test_parse_javascript_snippet() {
474        let lang = Lang::JavaScript;
475        let mut parser = tree_sitter::Parser::new();
476        parser.set_language(&lang.tree_sitter_language()).unwrap();
477        let tree = parser
478            .parse("function hello() { return 42; }", None)
479            .unwrap();
480        assert_eq!(tree.root_node().kind(), "program");
481        assert!(!tree.root_node().has_error());
482    }
483
484    #[cfg(feature = "lang-kotlin")]
485    #[test]
486    fn test_parse_kotlin_snippet() {
487        let lang = Lang::Kotlin;
488        let mut parser = tree_sitter::Parser::new();
489        parser.set_language(&lang.tree_sitter_language()).unwrap();
490        let tree = parser
491            .parse("fun main() { println(\"hello\") }", None)
492            .unwrap();
493        assert_eq!(tree.root_node().kind(), "source_file");
494        assert!(!tree.root_node().has_error());
495    }
496
497    #[cfg(feature = "lang-zig")]
498    #[test]
499    fn test_parse_zig_snippet() {
500        let lang = Lang::Zig;
501        let mut parser = tree_sitter::Parser::new();
502        parser.set_language(&lang.tree_sitter_language()).unwrap();
503        let tree = parser.parse("pub fn main() !void {}", None).unwrap();
504        assert_eq!(tree.root_node().kind(), "source_file");
505    }
506
507    #[cfg(feature = "lang-bash")]
508    #[test]
509    fn test_parse_bash_snippet() {
510        let lang = Lang::Bash;
511        let mut parser = tree_sitter::Parser::new();
512        parser.set_language(&lang.tree_sitter_language()).unwrap();
513        let tree = parser
514            .parse("#!/bin/bash\nhello() { echo hi; }\n", None)
515            .unwrap();
516        assert_eq!(tree.root_node().kind(), "program");
517        assert!(!tree.root_node().has_error());
518    }
519
520    #[cfg(feature = "lang-markdown")]
521    #[test]
522    fn test_parse_markdown_snippet() {
523        let lang = Lang::Markdown;
524        let mut parser = tree_sitter::Parser::new();
525        parser.set_language(&lang.tree_sitter_language()).unwrap();
526        let tree = parser.parse("# Hello\n\nSome text.\n", None).unwrap();
527        assert_eq!(tree.root_node().kind(), "document");
528        assert!(!tree.root_node().has_error());
529    }
530
531    #[test]
532    fn test_all_symbol_queries_compile() {
533        for lang in Lang::all() {
534            let ts_lang = lang.tree_sitter_language();
535            tree_sitter::Query::new(&ts_lang, lang.symbol_query())
536                .unwrap_or_else(|e| panic!("query compile failed for {:?}: {}", lang, e));
537        }
538    }
539
540    #[cfg(feature = "lang-rust")]
541    #[test]
542    fn test_extract_rust_symbols() {
543        let source = b"fn main() {}\nstruct Foo;\nenum Bar {}\ntrait Baz {}\nconst X: i32 = 1;\nstatic Y: i32 = 2;\nmod inner {}\ntype Alias = i32;\n";
544        let symbols = Lang::Rust.extract_symbols(source).unwrap();
545        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
546        assert!(names.contains(&"main"), "missing main, got {:?}", names);
547        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
548        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
549        assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
550        assert!(names.contains(&"X"), "missing X, got {:?}", names);
551        assert!(names.contains(&"Y"), "missing Y, got {:?}", names);
552        assert!(names.contains(&"inner"), "missing inner, got {:?}", names);
553        assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
554        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
555        assert_eq!(main_sym.kind, "function");
556        let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
557        assert_eq!(foo_sym.kind, "struct");
558    }
559
560    #[cfg(feature = "lang-python")]
561    #[test]
562    fn test_extract_python_symbols() {
563        let source =
564            b"def hello():\n    pass\n\nclass MyClass:\n    def method(self):\n        pass\n";
565        let symbols = Lang::Python.extract_symbols(source).unwrap();
566        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
567        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
568        assert!(
569            names.contains(&"MyClass"),
570            "missing MyClass, got {:?}",
571            names
572        );
573        assert!(names.contains(&"method"), "missing method, got {:?}", names);
574        let cls = symbols.iter().find(|s| s.name == "MyClass").unwrap();
575        assert_eq!(cls.kind, "class");
576    }
577
578    #[cfg(feature = "lang-typescript")]
579    #[test]
580    fn test_extract_typescript_symbols() {
581        let source = b"function greet(name: string): void {}\nclass Foo {}\ninterface Bar {}\ntype Alias = string;\nenum Color { Red, Green }\n";
582        let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
583        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
584        assert!(names.contains(&"greet"), "missing greet, got {:?}", names);
585        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
586        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
587        assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
588        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
589    }
590
591    #[cfg(feature = "lang-javascript")]
592    #[test]
593    fn test_extract_javascript_symbols() {
594        let source = b"function hello() { return 42; }\nclass Widget {}\n";
595        let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
596        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
597        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
598        assert!(names.contains(&"Widget"), "missing Widget, got {:?}", names);
599    }
600
601    #[cfg(feature = "lang-kotlin")]
602    #[test]
603    fn test_extract_kotlin_symbols() {
604        let source = b"fun main() { println(\"hi\") }\nclass Foo\ninterface Bar\ndata class Baz(val x: Int)\nsealed class Qux\nenum class Color { RED, GREEN }\nobject Singleton\n";
605        let symbols = Lang::Kotlin.extract_symbols(source).unwrap();
606        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
607        assert!(names.contains(&"main"), "missing main, got {:?}", names);
608        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
609        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
610        assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
611        assert!(names.contains(&"Qux"), "missing Qux, got {:?}", names);
612        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
613        assert!(
614            names.contains(&"Singleton"),
615            "missing Singleton, got {:?}",
616            names
617        );
618        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
619        assert_eq!(main_sym.kind, "function");
620        let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
621        assert_eq!(foo_sym.kind, "class");
622        let bar_sym = symbols.iter().find(|s| s.name == "Bar").unwrap();
623        assert_eq!(bar_sym.kind, "interface");
624        let baz_sym = symbols.iter().find(|s| s.name == "Baz").unwrap();
625        assert_eq!(baz_sym.kind, "data_class");
626        let qux_sym = symbols.iter().find(|s| s.name == "Qux").unwrap();
627        assert_eq!(qux_sym.kind, "sealed_class");
628        let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
629        assert_eq!(color_sym.kind, "enum_class");
630        let singleton_sym = symbols.iter().find(|s| s.name == "Singleton").unwrap();
631        assert_eq!(singleton_sym.kind, "object");
632        assert_eq!(
633            symbols.len(),
634            7,
635            "expected exactly 7 symbols, got {:?}",
636            symbols
637        );
638    }
639
640    #[cfg(feature = "lang-zig")]
641    #[test]
642    fn test_extract_zig_symbols() {
643        let source = b"const std = @import(\"std\");\npub fn main() !void {}\nconst Point = struct { x: i32, y: i32 };\nconst Color = enum { red, green, blue };\nconst Result = union(enum) { ok: i32, err: []const u8 };\nconst MAX: i32 = 100;\n";
644        let symbols = Lang::Zig.extract_symbols(source).unwrap();
645        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
646        assert!(names.contains(&"main"), "missing main, got {:?}", names);
647        assert!(names.contains(&"Point"), "missing Point, got {:?}", names);
648        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
649        assert!(names.contains(&"Result"), "missing Result, got {:?}", names);
650        assert!(names.contains(&"std"), "missing std, got {:?}", names);
651        assert!(names.contains(&"MAX"), "missing MAX, got {:?}", names);
652        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
653        assert_eq!(main_sym.kind, "function");
654        let point_sym = symbols.iter().find(|s| s.name == "Point").unwrap();
655        assert_eq!(point_sym.kind, "struct");
656        let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
657        assert_eq!(color_sym.kind, "enum");
658        let result_sym = symbols.iter().find(|s| s.name == "Result").unwrap();
659        assert_eq!(result_sym.kind, "union");
660        let max_sym = symbols.iter().find(|s| s.name == "MAX").unwrap();
661        assert_eq!(max_sym.kind, "const");
662    }
663
664    #[cfg(feature = "lang-bash")]
665    #[test]
666    fn test_extract_bash_symbols() {
667        let source = b"#!/bin/bash\nhello() { echo hi; }\nfunction world { echo world; }\nalias ll='ls -la'\nalias grep='grep --color=auto'\n";
668        let symbols = Lang::Bash.extract_symbols(source).unwrap();
669        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
670        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
671        assert!(names.contains(&"world"), "missing world, got {:?}", names);
672        assert!(names.contains(&"ll"), "missing alias ll, got {:?}", names);
673        assert!(
674            names.contains(&"grep"),
675            "missing alias grep, got {:?}",
676            names
677        );
678        let hello_sym = symbols.iter().find(|s| s.name == "hello").unwrap();
679        assert_eq!(hello_sym.kind, "function");
680        let ll_sym = symbols.iter().find(|s| s.name == "ll").unwrap();
681        assert_eq!(ll_sym.kind, "alias");
682    }
683
684    #[cfg(feature = "lang-markdown")]
685    #[test]
686    fn test_extract_markdown_symbols() {
687        let source = b"# Title\n\n## Section One\n\nSome text.\n\n```rust\nfn main() {}\n```\n\n### Subsection\n\n```python\ndef hello():\n    pass\n```\n";
688        let symbols = Lang::Markdown.extract_symbols(source).unwrap();
689        let headings: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "heading").collect();
690        let code_blocks: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "code_block").collect();
691        assert_eq!(headings.len(), 3, "expected 3 headings, got {:?}", headings);
692        assert_eq!(
693            code_blocks.len(),
694            2,
695            "expected 2 code blocks, got {:?}",
696            code_blocks
697        );
698        assert!(
699            code_blocks.iter().any(|s| s.name == "rust"),
700            "missing rust block, got {:?}",
701            code_blocks
702        );
703        assert!(
704            code_blocks.iter().any(|s| s.name == "python"),
705            "missing python block, got {:?}",
706            code_blocks
707        );
708    }
709
710    #[cfg(feature = "lang-python")]
711    #[test]
712    fn test_python_async_def() {
713        let source = b"async def fetch_data():\n    await get()\n\ndef sync_fn():\n    pass\n";
714        let symbols = Lang::Python.extract_symbols(source).unwrap();
715        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
716        assert!(
717            names.contains(&"fetch_data"),
718            "missing async function, got {:?}",
719            names
720        );
721        assert!(
722            names.contains(&"sync_fn"),
723            "missing sync function, got {:?}",
724            names
725        );
726    }
727
728    #[cfg(feature = "lang-python")]
729    #[test]
730    fn test_python_decorated_function() {
731        let source = b"@staticmethod\ndef helper():\n    pass\n\n@property\ndef name(self):\n    return self._name\n";
732        let symbols = Lang::Python.extract_symbols(source).unwrap();
733        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
734        assert!(
735            names.contains(&"helper"),
736            "missing decorated function, got {:?}",
737            names
738        );
739        assert!(
740            names.contains(&"name"),
741            "missing property function, got {:?}",
742            names
743        );
744    }
745
746    #[cfg(feature = "lang-typescript")]
747    #[test]
748    fn test_typescript_arrow_exports() {
749        let source = b"export const Foo = () => { return 42; };\nexport const Bar = (x: number): number => x + 1;\nconst local = () => {};\nfunction regular() {}\n";
750        let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
751        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
752        assert!(
753            names.contains(&"Foo"),
754            "missing arrow export Foo, got {:?}",
755            names
756        );
757        assert!(
758            names.contains(&"Bar"),
759            "missing arrow export Bar, got {:?}",
760            names
761        );
762        assert!(
763            names.contains(&"local"),
764            "missing local arrow, got {:?}",
765            names
766        );
767        assert!(
768            names.contains(&"regular"),
769            "missing regular function, got {:?}",
770            names
771        );
772    }
773
774    #[cfg(feature = "lang-typescript")]
775    #[test]
776    fn test_tsx_arrow_component() {
777        let source = b"export const MyComponent = () => <div>hello</div>;\nfunction Other() { return <span/>; }\n";
778        let symbols = Lang::Tsx.extract_symbols(source).unwrap();
779        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
780        assert!(
781            names.contains(&"MyComponent"),
782            "missing arrow component, got {:?}",
783            names
784        );
785        assert!(
786            names.contains(&"Other"),
787            "missing function component, got {:?}",
788            names
789        );
790    }
791
792    #[cfg(feature = "lang-javascript")]
793    #[test]
794    fn test_javascript_arrow_exports() {
795        let source = b"export const handler = () => { return 'ok'; };\nconst helper = (x) => x * 2;\nfunction regular() {}\n";
796        let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
797        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
798        assert!(
799            names.contains(&"handler"),
800            "missing arrow export, got {:?}",
801            names
802        );
803        assert!(
804            names.contains(&"helper"),
805            "missing local arrow, got {:?}",
806            names
807        );
808        assert!(
809            names.contains(&"regular"),
810            "missing regular function, got {:?}",
811            names
812        );
813    }
814
815    #[cfg(feature = "lang-javascript")]
816    #[test]
817    fn test_jsx_arrow_component() {
818        let source = b"const App = () => <div>hi</div>;\nfunction Page() { return <main/>; }\n";
819        let symbols = Lang::Jsx.extract_symbols(source).unwrap();
820        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
821        assert!(
822            names.contains(&"App"),
823            "missing arrow JSX component, got {:?}",
824            names
825        );
826        assert!(
827            names.contains(&"Page"),
828            "missing function component, got {:?}",
829            names
830        );
831    }
832}