Skip to main content

tsift_graph/
lang.rs

1use anyhow::Result;
2use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator};
3
/// A named definition found in a parsed source file.
///
/// Values are filled in by `Lang::extract_symbols` from tree-sitter nodes, or
/// copied field-for-field from `tsift_md_ast::MdSymbol` for Markdown input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Symbol {
    /// Text of the captured name node (for Bash aliases: the text before `=`).
    pub name: String,
    /// Symbol category: the query capture name with its `.name` suffix removed
    /// (e.g. `function`, `struct`), or `alias` for Bash aliases.
    pub kind: String,
    /// Row of the symbol's start position as reported by tree-sitter
    /// (`start_position().row`, which tree-sitter documents as zero-based).
    pub line: usize,
    /// Row of the symbol's end position as reported by tree-sitter.
    pub end_line: usize,
    /// Grammar node kind (`Node::kind()`) of the node the spans were taken from.
    pub node_kind: String,
    /// Byte offset in the source where the symbol starts.
    pub start_byte: usize,
    /// Byte offset in the source where the symbol ends.
    pub end_byte: usize,
    /// Start byte of the symbol's body node, when one could be located.
    pub body_start_byte: Option<usize>,
    /// End byte of the symbol's body node, when one could be located.
    pub body_end_byte: Option<usize>,
}
16
/// Languages for which a tree-sitter grammar can be compiled in.
///
/// Every variant is gated behind a `lang-*` cargo feature, so the set of
/// variants depends on the enabled features. `TypeScript`/`Tsx` share the
/// `lang-typescript` feature and `JavaScript`/`Jsx` share `lang-javascript`.
// NOTE(review): `dead_code` is presumably allowed because some feature
// combinations leave variants unused — confirm.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Lang {
    #[cfg(feature = "lang-rust")]
    Rust,
    #[cfg(feature = "lang-python")]
    Python,
    #[cfg(feature = "lang-typescript")]
    TypeScript,
    #[cfg(feature = "lang-typescript")]
    Tsx,
    #[cfg(feature = "lang-javascript")]
    JavaScript,
    #[cfg(feature = "lang-javascript")]
    Jsx,
    #[cfg(feature = "lang-kotlin")]
    Kotlin,
    #[cfg(feature = "lang-zig")]
    Zig,
    #[cfg(feature = "lang-bash")]
    Bash,
    #[cfg(feature = "lang-markdown")]
    Markdown,
}
41
42#[allow(dead_code)]
43impl Lang {
44    pub fn from_extension(ext: &str) -> Option<Self> {
45        match ext {
46            #[cfg(feature = "lang-rust")]
47            "rs" => Some(Self::Rust),
48            #[cfg(feature = "lang-python")]
49            "py" | "pyi" => Some(Self::Python),
50            #[cfg(feature = "lang-typescript")]
51            "ts" => Some(Self::TypeScript),
52            #[cfg(feature = "lang-typescript")]
53            "tsx" => Some(Self::Tsx),
54            #[cfg(feature = "lang-javascript")]
55            "js" | "mjs" | "cjs" => Some(Self::JavaScript),
56            #[cfg(feature = "lang-javascript")]
57            "jsx" => Some(Self::Jsx),
58            #[cfg(feature = "lang-kotlin")]
59            "kt" | "kts" => Some(Self::Kotlin),
60            #[cfg(feature = "lang-zig")]
61            "zig" => Some(Self::Zig),
62            #[cfg(feature = "lang-bash")]
63            "sh" | "bash" | "zsh" => Some(Self::Bash),
64            #[cfg(feature = "lang-markdown")]
65            "md" | "mdx" => Some(Self::Markdown),
66            _ => None,
67        }
68    }
69
70    pub fn tree_sitter_language(&self) -> Language {
71        match self {
72            #[cfg(feature = "lang-rust")]
73            Self::Rust => tree_sitter_rust::LANGUAGE.into(),
74            #[cfg(feature = "lang-python")]
75            Self::Python => tree_sitter_python::LANGUAGE.into(),
76            #[cfg(feature = "lang-typescript")]
77            Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
78            #[cfg(feature = "lang-typescript")]
79            Self::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
80            #[cfg(feature = "lang-javascript")]
81            Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
82            #[cfg(feature = "lang-javascript")]
83            Self::Jsx => tree_sitter_javascript::LANGUAGE.into(),
84            #[cfg(feature = "lang-kotlin")]
85            Self::Kotlin => tree_sitter_kotlin_ng::LANGUAGE.into(),
86            #[cfg(feature = "lang-zig")]
87            Self::Zig => tree_sitter_zig::LANGUAGE.into(),
88            #[cfg(feature = "lang-bash")]
89            Self::Bash => tree_sitter_bash::LANGUAGE.into(),
90            #[cfg(feature = "lang-markdown")]
91            Self::Markdown => tsift_md_ast::markdown_language(),
92        }
93    }
94
95    pub fn name(&self) -> &'static str {
96        match self {
97            #[cfg(feature = "lang-rust")]
98            Self::Rust => "rust",
99            #[cfg(feature = "lang-python")]
100            Self::Python => "python",
101            #[cfg(feature = "lang-typescript")]
102            Self::TypeScript => "typescript",
103            #[cfg(feature = "lang-typescript")]
104            Self::Tsx => "tsx",
105            #[cfg(feature = "lang-javascript")]
106            Self::JavaScript => "javascript",
107            #[cfg(feature = "lang-javascript")]
108            Self::Jsx => "jsx",
109            #[cfg(feature = "lang-kotlin")]
110            Self::Kotlin => "kotlin",
111            #[cfg(feature = "lang-zig")]
112            Self::Zig => "zig",
113            #[cfg(feature = "lang-bash")]
114            Self::Bash => "bash",
115            #[cfg(feature = "lang-markdown")]
116            Self::Markdown => "markdown",
117        }
118    }
119
    /// Returns the tree-sitter query that captures symbol definitions for this language.
    ///
    /// Every capture is named `<kind>.name`; `extract_symbols` strips the
    /// `.name` suffix and stores the remainder as the symbol's kind.
    pub fn symbol_query(&self) -> &'static str {
        match self {
            #[cfg(feature = "lang-rust")]
            Self::Rust => {
                r#"
                (function_item name: (identifier) @function.name)
                (struct_item name: (type_identifier) @struct.name)
                (enum_item name: (type_identifier) @enum.name)
                (trait_item name: (type_identifier) @trait.name)
                (impl_item type: (type_identifier) @impl.name)
                (mod_item name: (identifier) @mod.name)
                (type_item name: (type_identifier) @type_alias.name)
                (const_item name: (identifier) @const.name)
                (static_item name: (identifier) @static.name)
            "#
            }
            #[cfg(feature = "lang-python")]
            Self::Python => {
                r#"
                (function_definition name: (identifier) @function.name)
                (class_definition name: (identifier) @class.name)
            "#
            }
            // Arrow functions bound to a variable are reported as functions.
            #[cfg(feature = "lang-typescript")]
            Self::TypeScript | Self::Tsx => {
                r#"
                (function_declaration name: (identifier) @function.name)
                (class_declaration name: (type_identifier) @class.name)
                (interface_declaration name: (type_identifier) @interface.name)
                (type_alias_declaration name: (type_identifier) @type_alias.name)
                (enum_declaration name: (identifier) @enum.name)
                (variable_declarator name: (identifier) @function.name value: (arrow_function))
            "#
            }
            #[cfg(feature = "lang-javascript")]
            Self::JavaScript | Self::Jsx => {
                r#"
                (function_declaration name: (identifier) @function.name)
                (class_declaration name: (identifier) @class.name)
                (variable_declarator name: (identifier) @function.name value: (arrow_function))
            "#
            }
            // A modified class (data/sealed/enum) also matches the plain
            // `"class"` pattern; `extract_symbols` merges same-name, same-line
            // duplicates afterwards.
            #[cfg(feature = "lang-kotlin")]
            Self::Kotlin => {
                r#"
                (function_declaration name: (identifier) @function.name)
                (class_declaration "interface" name: (identifier) @interface.name)
                (class_declaration (modifiers (class_modifier "data")) name: (identifier) @data_class.name)
                (class_declaration (modifiers (class_modifier "sealed")) name: (identifier) @sealed_class.name)
                (class_declaration (modifiers (class_modifier "enum")) name: (identifier) @enum_class.name)
                (class_declaration "class" name: (identifier) @class.name)
                (object_declaration name: (identifier) @object.name)
                (companion_object name: (identifier) @companion_object.name)
            "#
            }
            // The last pattern matches every variable declaration, so struct /
            // enum / union declarations are captured twice (specific kind plus
            // `const`).
            #[cfg(feature = "lang-zig")]
            Self::Zig => {
                r#"
                (function_declaration (identifier) @function.name)
                (variable_declaration (identifier) @struct.name (struct_declaration))
                (variable_declaration (identifier) @enum.name (enum_declaration))
                (variable_declaration (identifier) @union.name (union_declaration))
                (variable_declaration (identifier) @const.name)
            "#
            }
            #[cfg(feature = "lang-bash")]
            Self::Bash => {
                r#"
                (function_definition name: (word) @function.name)
            "#
            }
            #[cfg(feature = "lang-markdown")]
            Self::Markdown => {
                r#"
                (atx_heading (atx_h1_marker) (inline) @heading.name)
                (atx_heading (atx_h2_marker) (inline) @heading.name)
                (atx_heading (atx_h3_marker) (inline) @heading.name)
                (atx_heading (atx_h4_marker) (inline) @heading.name)
                (atx_heading (atx_h5_marker) (inline) @heading.name)
                (atx_heading (atx_h6_marker) (inline) @heading.name)
                (fenced_code_block (info_string (language) @code_block.name))
            "#
            }
        }
    }
205
206    pub fn call_query(&self) -> Option<&'static str> {
207        match self {
208            #[cfg(feature = "lang-rust")]
209            Self::Rust => Some(
210                r#"
211                (call_expression function: (identifier) @call.name)
212                (call_expression function: (field_expression field: (field_identifier) @call.name))
213                (call_expression function: (scoped_identifier name: (identifier) @call.name))
214                (macro_invocation macro: (identifier) @call.name)
215            "#,
216            ),
217            #[cfg(feature = "lang-python")]
218            Self::Python => Some(
219                r#"
220                (call function: (identifier) @call.name)
221                (call function: (attribute attribute: (identifier) @call.name))
222            "#,
223            ),
224            #[cfg(feature = "lang-typescript")]
225            Self::TypeScript | Self::Tsx => Some(
226                r#"
227                (call_expression function: (identifier) @call.name)
228                (call_expression function: (member_expression property: (property_identifier) @call.name))
229            "#,
230            ),
231            #[cfg(feature = "lang-javascript")]
232            Self::JavaScript | Self::Jsx => Some(
233                r#"
234                (call_expression function: (identifier) @call.name)
235                (call_expression function: (member_expression property: (property_identifier) @call.name))
236            "#,
237            ),
238            #[cfg(feature = "lang-kotlin")]
239            Self::Kotlin => Some(
240                r#"
241                (call_expression (simple_identifier) @call.name)
242            "#,
243            ),
244            _ => None,
245        }
246    }
247
248    pub fn extract_symbols(&self, source: &[u8]) -> Result<Vec<Symbol>> {
249        let mut parser = Parser::new();
250        let ts_lang = self.tree_sitter_language();
251        parser.set_language(&ts_lang)?;
252        let tree = parser
253            .parse(source, None)
254            .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
255        #[cfg(feature = "lang-markdown")]
256        if *self == Self::Markdown {
257            return Ok(tsift_md_ast::markdown_symbols_from_tree(&tree, source)
258                .into_iter()
259                .map(md_symbol_to_symbol)
260                .collect());
261        }
262        let query = Query::new(&ts_lang, self.symbol_query())?;
263        let mut cursor = QueryCursor::new();
264        let mut symbols = Vec::new();
265        let capture_names: Vec<String> = query
266            .capture_names()
267            .iter()
268            .map(|s| s.to_string())
269            .collect();
270
271        let mut matches = cursor.matches(&query, tree.root_node(), source);
272        while let Some(m) = matches.next() {
273            for capture in m.captures {
274                let capture_name = &capture_names[capture.index as usize];
275                if let Some(kind_str) = capture_name.strip_suffix(".name") {
276                    let name = capture
277                        .node
278                        .utf8_text(source)
279                        .unwrap_or("<invalid utf8>")
280                        .to_string();
281                    let node = symbol_node_for_capture(kind_str, capture.node);
282                    let body_span = symbol_body_span(node);
283                    symbols.push(Symbol {
284                        name,
285                        kind: kind_str.to_string(),
286                        line: node.start_position().row,
287                        end_line: node.end_position().row,
288                        node_kind: node.kind().to_string(),
289                        start_byte: node.start_byte(),
290                        end_byte: node.end_byte(),
291                        body_start_byte: body_span.map(|(start, _)| start),
292                        body_end_byte: body_span.map(|(_, end)| end),
293                    });
294                }
295            }
296        }
297
298        #[cfg(feature = "lang-bash")]
299        if *self == Self::Bash {
300            Self::extract_bash_aliases(&tree, source, &mut symbols);
301        }
302        symbols.sort_by(|a, b| a.line.cmp(&b.line).then(a.name.cmp(&b.name)));
303        symbols.dedup_by(|b, a| {
304            a.name == b.name && a.line == b.line && {
305                let a_generic = matches!(a.kind.as_str(), "variable" | "const");
306                let b_generic = matches!(b.kind.as_str(), "variable" | "const");
307                match (a_generic, b_generic) {
308                    (true, false) => a.kind.clone_from(&b.kind),
309                    (false, true) => {}
310                    _ => {
311                        if b.kind.len() > a.kind.len() {
312                            a.kind.clone_from(&b.kind);
313                        }
314                    }
315                }
316                true
317            }
318        });
319        Ok(symbols)
320    }
321
322    #[cfg(feature = "lang-bash")]
323    fn extract_bash_aliases(tree: &tree_sitter::Tree, source: &[u8], symbols: &mut Vec<Symbol>) {
324        let mut tree_cursor = tree.root_node().walk();
325        if !tree_cursor.goto_first_child() {
326            return;
327        }
328        loop {
329            let node = tree_cursor.node();
330            if node.kind() == "command"
331                && let Some(name_node) = node.child_by_field_name("name")
332            {
333                let cmd = name_node.utf8_text(source).unwrap_or("");
334                if cmd == "alias" {
335                    for i in 0..node.named_child_count() {
336                        if let Some(arg) = node.named_child(i as u32)
337                            && (arg.kind() == "concatenation" || arg.kind() == "word")
338                        {
339                            let text = arg.utf8_text(source).unwrap_or("");
340                            if let Some(alias_name) = text.split('=').next()
341                                && !alias_name.is_empty()
342                                && alias_name != cmd
343                            {
344                                symbols.push(Symbol {
345                                    name: alias_name.to_string(),
346                                    kind: "alias".to_string(),
347                                    line: arg.start_position().row,
348                                    end_line: node.end_position().row,
349                                    node_kind: node.kind().to_string(),
350                                    start_byte: arg.start_byte(),
351                                    end_byte: node.end_byte(),
352                                    body_start_byte: None,
353                                    body_end_byte: None,
354                                });
355                            }
356                        }
357                    }
358                }
359            }
360            if !tree_cursor.goto_next_sibling() {
361                break;
362            }
363        }
364    }
365
366    pub fn all() -> Vec<Self> {
367        vec![
368            #[cfg(feature = "lang-rust")]
369            Self::Rust,
370            #[cfg(feature = "lang-python")]
371            Self::Python,
372            #[cfg(feature = "lang-typescript")]
373            Self::TypeScript,
374            #[cfg(feature = "lang-typescript")]
375            Self::Tsx,
376            #[cfg(feature = "lang-javascript")]
377            Self::JavaScript,
378            #[cfg(feature = "lang-javascript")]
379            Self::Jsx,
380            #[cfg(feature = "lang-kotlin")]
381            Self::Kotlin,
382            #[cfg(feature = "lang-zig")]
383            Self::Zig,
384            #[cfg(feature = "lang-bash")]
385            Self::Bash,
386            #[cfg(feature = "lang-markdown")]
387            Self::Markdown,
388        ]
389    }
390}
391
392fn symbol_node_for_capture<'tree>(
393    kind: &str,
394    name_node: tree_sitter::Node<'tree>,
395) -> tree_sitter::Node<'tree> {
396    let mut node = name_node.parent().unwrap_or(name_node);
397    if kind == "code_block" {
398        while let Some(parent) = node.parent() {
399            node = parent;
400            if node.kind() == "fenced_code_block" {
401                break;
402            }
403        }
404    }
405    node
406}
407
408fn symbol_body_span(node: tree_sitter::Node<'_>) -> Option<(usize, usize)> {
409    if let Some(body) = node.child_by_field_name("body") {
410        return Some((body.start_byte(), body.end_byte()));
411    }
412    for idx in 0..node.named_child_count() {
413        let Some(child) = node.named_child(idx as u32) else {
414            continue;
415        };
416        if matches!(
417            child.kind(),
418            "block"
419                | "declaration_list"
420                | "field_declaration_list"
421                | "enum_variant_list"
422                | "match_block"
423                | "statement_block"
424                | "suite"
425        ) {
426            return Some((child.start_byte(), child.end_byte()));
427        }
428    }
429    None
430}
431
/// Converts a Markdown symbol produced by `tsift_md_ast` into this crate's
/// [`Symbol`], moving every field across unchanged.
#[cfg(feature = "lang-markdown")]
fn md_symbol_to_symbol(md: tsift_md_ast::MdSymbol) -> Symbol {
    let tsift_md_ast::MdSymbol {
        name,
        kind,
        line,
        end_line,
        node_kind,
        start_byte,
        end_byte,
        body_start_byte,
        body_end_byte,
        ..
    } = md;
    Symbol {
        name,
        kind,
        line,
        end_line,
        node_kind,
        start_byte,
        end_byte,
        body_start_byte,
        body_end_byte,
    }
}
446
447#[cfg(test)]
448mod tests {
449    use super::*;
450
451    #[test]
452    fn test_all_grammars_create_parser() {
453        for lang in Lang::all() {
454            let ts_lang = lang.tree_sitter_language();
455            let mut parser = tree_sitter::Parser::new();
456            parser
457                .set_language(&ts_lang)
458                .unwrap_or_else(|e| panic!("failed to set language for {:?}: {}", lang, e));
459        }
460    }
461
462    #[test]
463    fn test_extension_dispatch() {
464        let cases = [
465            ("rs", "rust"),
466            ("py", "python"),
467            ("pyi", "python"),
468            ("ts", "typescript"),
469            ("tsx", "tsx"),
470            ("js", "javascript"),
471            ("mjs", "javascript"),
472            ("cjs", "javascript"),
473            ("jsx", "jsx"),
474            ("kt", "kotlin"),
475            ("kts", "kotlin"),
476            ("zig", "zig"),
477            ("sh", "bash"),
478            ("bash", "bash"),
479            ("zsh", "bash"),
480            ("md", "markdown"),
481            ("mdx", "markdown"),
482        ];
483        for (ext, expected_name) in cases {
484            let lang = Lang::from_extension(ext)
485                .unwrap_or_else(|| panic!("no language for extension: {ext}"));
486            assert_eq!(lang.name(), expected_name, "wrong language for .{ext}");
487        }
488    }
489
490    #[test]
491    fn test_unknown_extension_returns_none() {
492        assert!(Lang::from_extension("xyz").is_none());
493        assert!(Lang::from_extension("").is_none());
494        assert!(Lang::from_extension("txt").is_none());
495    }
496
497    #[cfg(feature = "lang-rust")]
498    #[test]
499    fn test_parse_rust_snippet() {
500        let lang = Lang::Rust;
501        let mut parser = tree_sitter::Parser::new();
502        parser.set_language(&lang.tree_sitter_language()).unwrap();
503        let tree = parser.parse("fn main() {}", None).unwrap();
504        assert_eq!(tree.root_node().kind(), "source_file");
505        assert!(!tree.root_node().has_error());
506    }
507
508    #[cfg(feature = "lang-python")]
509    #[test]
510    fn test_parse_python_snippet() {
511        let lang = Lang::Python;
512        let mut parser = tree_sitter::Parser::new();
513        parser.set_language(&lang.tree_sitter_language()).unwrap();
514        let tree = parser.parse("def hello():\n    pass\n", None).unwrap();
515        assert_eq!(tree.root_node().kind(), "module");
516        assert!(!tree.root_node().has_error());
517    }
518
519    #[cfg(feature = "lang-typescript")]
520    #[test]
521    fn test_parse_typescript_snippet() {
522        let lang = Lang::TypeScript;
523        let mut parser = tree_sitter::Parser::new();
524        parser.set_language(&lang.tree_sitter_language()).unwrap();
525        let tree = parser
526            .parse("function greet(name: string): void {}", None)
527            .unwrap();
528        assert_eq!(tree.root_node().kind(), "program");
529        assert!(!tree.root_node().has_error());
530    }
531
532    #[cfg(feature = "lang-typescript")]
533    #[test]
534    fn test_parse_tsx_snippet() {
535        let lang = Lang::Tsx;
536        let mut parser = tree_sitter::Parser::new();
537        parser.set_language(&lang.tree_sitter_language()).unwrap();
538        let tree = parser
539            .parse("const App = () => <div>hello</div>;", None)
540            .unwrap();
541        assert_eq!(tree.root_node().kind(), "program");
542        assert!(!tree.root_node().has_error());
543    }
544
545    #[cfg(feature = "lang-javascript")]
546    #[test]
547    fn test_parse_javascript_snippet() {
548        let lang = Lang::JavaScript;
549        let mut parser = tree_sitter::Parser::new();
550        parser.set_language(&lang.tree_sitter_language()).unwrap();
551        let tree = parser
552            .parse("function hello() { return 42; }", None)
553            .unwrap();
554        assert_eq!(tree.root_node().kind(), "program");
555        assert!(!tree.root_node().has_error());
556    }
557
558    #[cfg(feature = "lang-kotlin")]
559    #[test]
560    fn test_parse_kotlin_snippet() {
561        let lang = Lang::Kotlin;
562        let mut parser = tree_sitter::Parser::new();
563        parser.set_language(&lang.tree_sitter_language()).unwrap();
564        let tree = parser
565            .parse("fun main() { println(\"hello\") }", None)
566            .unwrap();
567        assert_eq!(tree.root_node().kind(), "source_file");
568        assert!(!tree.root_node().has_error());
569    }
570
571    #[cfg(feature = "lang-zig")]
572    #[test]
573    fn test_parse_zig_snippet() {
574        let lang = Lang::Zig;
575        let mut parser = tree_sitter::Parser::new();
576        parser.set_language(&lang.tree_sitter_language()).unwrap();
577        let tree = parser.parse("pub fn main() !void {}", None).unwrap();
578        assert_eq!(tree.root_node().kind(), "source_file");
579    }
580
581    #[cfg(feature = "lang-bash")]
582    #[test]
583    fn test_parse_bash_snippet() {
584        let lang = Lang::Bash;
585        let mut parser = tree_sitter::Parser::new();
586        parser.set_language(&lang.tree_sitter_language()).unwrap();
587        let tree = parser
588            .parse("#!/bin/bash\nhello() { echo hi; }\n", None)
589            .unwrap();
590        assert_eq!(tree.root_node().kind(), "program");
591        assert!(!tree.root_node().has_error());
592    }
593
594    #[cfg(feature = "lang-markdown")]
595    #[test]
596    fn test_parse_markdown_snippet() {
597        let lang = Lang::Markdown;
598        let mut parser = tree_sitter::Parser::new();
599        parser.set_language(&lang.tree_sitter_language()).unwrap();
600        let tree = parser.parse("# Hello\n\nSome text.\n", None).unwrap();
601        assert_eq!(tree.root_node().kind(), "document");
602        assert!(!tree.root_node().has_error());
603    }
604
605    #[test]
606    fn test_all_symbol_queries_compile() {
607        for lang in Lang::all() {
608            let ts_lang = lang.tree_sitter_language();
609            tree_sitter::Query::new(&ts_lang, lang.symbol_query())
610                .unwrap_or_else(|e| panic!("query compile failed for {:?}: {}", lang, e));
611        }
612    }
613
614    #[cfg(feature = "lang-rust")]
615    #[test]
616    fn test_extract_rust_symbols() {
617        let source = b"fn main() {}\nstruct Foo;\nenum Bar {}\ntrait Baz {}\nconst X: i32 = 1;\nstatic Y: i32 = 2;\nmod inner {}\ntype Alias = i32;\n";
618        let symbols = Lang::Rust.extract_symbols(source).unwrap();
619        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
620        assert!(names.contains(&"main"), "missing main, got {:?}", names);
621        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
622        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
623        assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
624        assert!(names.contains(&"X"), "missing X, got {:?}", names);
625        assert!(names.contains(&"Y"), "missing Y, got {:?}", names);
626        assert!(names.contains(&"inner"), "missing inner, got {:?}", names);
627        assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
628        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
629        assert_eq!(main_sym.kind, "function");
630        let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
631        assert_eq!(foo_sym.kind, "struct");
632    }
633
634    #[cfg(feature = "lang-python")]
635    #[test]
636    fn test_extract_python_symbols() {
637        let source =
638            b"def hello():\n    pass\n\nclass MyClass:\n    def method(self):\n        pass\n";
639        let symbols = Lang::Python.extract_symbols(source).unwrap();
640        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
641        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
642        assert!(
643            names.contains(&"MyClass"),
644            "missing MyClass, got {:?}",
645            names
646        );
647        assert!(names.contains(&"method"), "missing method, got {:?}", names);
648        let cls = symbols.iter().find(|s| s.name == "MyClass").unwrap();
649        assert_eq!(cls.kind, "class");
650    }
651
652    #[cfg(feature = "lang-typescript")]
653    #[test]
654    fn test_extract_typescript_symbols() {
655        let source = b"function greet(name: string): void {}\nclass Foo {}\ninterface Bar {}\ntype Alias = string;\nenum Color { Red, Green }\n";
656        let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
657        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
658        assert!(names.contains(&"greet"), "missing greet, got {:?}", names);
659        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
660        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
661        assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
662        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
663    }
664
665    #[cfg(feature = "lang-javascript")]
666    #[test]
667    fn test_extract_javascript_symbols() {
668        let source = b"function hello() { return 42; }\nclass Widget {}\n";
669        let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
670        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
671        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
672        assert!(names.contains(&"Widget"), "missing Widget, got {:?}", names);
673    }
674
675    #[cfg(feature = "lang-kotlin")]
676    #[test]
677    fn test_extract_kotlin_symbols() {
678        let source = b"fun main() { println(\"hi\") }\nclass Foo\ninterface Bar\ndata class Baz(val x: Int)\nsealed class Qux\nenum class Color { RED, GREEN }\nobject Singleton\n";
679        let symbols = Lang::Kotlin.extract_symbols(source).unwrap();
680        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
681        assert!(names.contains(&"main"), "missing main, got {:?}", names);
682        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
683        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
684        assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
685        assert!(names.contains(&"Qux"), "missing Qux, got {:?}", names);
686        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
687        assert!(
688            names.contains(&"Singleton"),
689            "missing Singleton, got {:?}",
690            names
691        );
692        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
693        assert_eq!(main_sym.kind, "function");
694        let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
695        assert_eq!(foo_sym.kind, "class");
696        let bar_sym = symbols.iter().find(|s| s.name == "Bar").unwrap();
697        assert_eq!(bar_sym.kind, "interface");
698        let baz_sym = symbols.iter().find(|s| s.name == "Baz").unwrap();
699        assert_eq!(baz_sym.kind, "data_class");
700        let qux_sym = symbols.iter().find(|s| s.name == "Qux").unwrap();
701        assert_eq!(qux_sym.kind, "sealed_class");
702        let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
703        assert_eq!(color_sym.kind, "enum_class");
704        let singleton_sym = symbols.iter().find(|s| s.name == "Singleton").unwrap();
705        assert_eq!(singleton_sym.kind, "object");
706        assert_eq!(
707            symbols.len(),
708            7,
709            "expected exactly 7 symbols, got {:?}",
710            symbols
711        );
712    }
713
714    #[cfg(feature = "lang-zig")]
715    #[test]
716    fn test_extract_zig_symbols() {
717        let source = b"const std = @import(\"std\");\npub fn main() !void {}\nconst Point = struct { x: i32, y: i32 };\nconst Color = enum { red, green, blue };\nconst Result = union(enum) { ok: i32, err: []const u8 };\nconst MAX: i32 = 100;\n";
718        let symbols = Lang::Zig.extract_symbols(source).unwrap();
719        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
720        assert!(names.contains(&"main"), "missing main, got {:?}", names);
721        assert!(names.contains(&"Point"), "missing Point, got {:?}", names);
722        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
723        assert!(names.contains(&"Result"), "missing Result, got {:?}", names);
724        assert!(names.contains(&"std"), "missing std, got {:?}", names);
725        assert!(names.contains(&"MAX"), "missing MAX, got {:?}", names);
726        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
727        assert_eq!(main_sym.kind, "function");
728        let point_sym = symbols.iter().find(|s| s.name == "Point").unwrap();
729        assert_eq!(point_sym.kind, "struct");
730        let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
731        assert_eq!(color_sym.kind, "enum");
732        let result_sym = symbols.iter().find(|s| s.name == "Result").unwrap();
733        assert_eq!(result_sym.kind, "union");
734        let max_sym = symbols.iter().find(|s| s.name == "MAX").unwrap();
735        assert_eq!(max_sym.kind, "const");
736    }
737
738    #[cfg(feature = "lang-bash")]
739    #[test]
740    fn test_extract_bash_symbols() {
741        let source = b"#!/bin/bash\nhello() { echo hi; }\nfunction world { echo world; }\nalias ll='ls -la'\nalias grep='grep --color=auto'\n";
742        let symbols = Lang::Bash.extract_symbols(source).unwrap();
743        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
744        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
745        assert!(names.contains(&"world"), "missing world, got {:?}", names);
746        assert!(names.contains(&"ll"), "missing alias ll, got {:?}", names);
747        assert!(
748            names.contains(&"grep"),
749            "missing alias grep, got {:?}",
750            names
751        );
752        let hello_sym = symbols.iter().find(|s| s.name == "hello").unwrap();
753        assert_eq!(hello_sym.kind, "function");
754        let ll_sym = symbols.iter().find(|s| s.name == "ll").unwrap();
755        assert_eq!(ll_sym.kind, "alias");
756    }
757
758    #[cfg(feature = "lang-markdown")]
759    #[test]
760    fn test_extract_markdown_symbols() {
761        let source = b"# Title\n\n## Section One\n\nSome text.\n\n- Run setup\n  - Confirm setup\n\n```rust\nfn main() {}\n```\n\n### Subsection\n\n```python\ndef hello():\n    pass\n```\n\n## Next Section\n\nDone.\n";
762        let symbols = Lang::Markdown.extract_symbols(source).unwrap();
763        let headings: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "heading").collect();
764        let code_blocks: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "code_block").collect();
765        let list_items: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "list_item").collect();
766        assert_eq!(headings.len(), 4, "expected 4 headings, got {:?}", headings);
767        assert_eq!(
768            code_blocks.len(),
769            2,
770            "expected 2 code blocks, got {:?}",
771            code_blocks
772        );
773        assert_eq!(
774            list_items.len(),
775            2,
776            "expected 2 list items, got {:?}",
777            list_items
778        );
779        let title = headings.iter().find(|s| s.name == "Title").unwrap();
780        let section = headings.iter().find(|s| s.name == "Section One").unwrap();
781        let next = headings.iter().find(|s| s.name == "Next Section").unwrap();
782        assert_eq!(title.node_kind, "atx_heading");
783        assert!(title.end_byte > next.start_byte);
784        assert_eq!(section.end_byte, next.start_byte);
785        assert!(
786            section.body_start_byte.unwrap() > section.start_byte,
787            "heading body should begin after the marker line"
788        );
789        assert!(
790            code_blocks.iter().any(|s| s.name == "rust"),
791            "missing rust block, got {:?}",
792            code_blocks
793        );
794        assert!(
795            code_blocks.iter().any(|s| s.name == "python"),
796            "missing python block, got {:?}",
797            code_blocks
798        );
799        assert!(
800            list_items.iter().any(|s| s.name == "Run setup"),
801            "missing top-level list item, got {:?}",
802            list_items
803        );
804    }
805
806    #[cfg(feature = "lang-python")]
807    #[test]
808    fn test_python_async_def() {
809        let source = b"async def fetch_data():\n    await get()\n\ndef sync_fn():\n    pass\n";
810        let symbols = Lang::Python.extract_symbols(source).unwrap();
811        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
812        assert!(
813            names.contains(&"fetch_data"),
814            "missing async function, got {:?}",
815            names
816        );
817        assert!(
818            names.contains(&"sync_fn"),
819            "missing sync function, got {:?}",
820            names
821        );
822    }
823
824    #[cfg(feature = "lang-python")]
825    #[test]
826    fn test_python_decorated_function() {
827        let source = b"@staticmethod\ndef helper():\n    pass\n\n@property\ndef name(self):\n    return self._name\n";
828        let symbols = Lang::Python.extract_symbols(source).unwrap();
829        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
830        assert!(
831            names.contains(&"helper"),
832            "missing decorated function, got {:?}",
833            names
834        );
835        assert!(
836            names.contains(&"name"),
837            "missing property function, got {:?}",
838            names
839        );
840    }
841
842    #[cfg(feature = "lang-typescript")]
843    #[test]
844    fn test_typescript_arrow_exports() {
845        let source = b"export const Foo = () => { return 42; };\nexport const Bar = (x: number): number => x + 1;\nconst local = () => {};\nfunction regular() {}\n";
846        let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
847        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
848        assert!(
849            names.contains(&"Foo"),
850            "missing arrow export Foo, got {:?}",
851            names
852        );
853        assert!(
854            names.contains(&"Bar"),
855            "missing arrow export Bar, got {:?}",
856            names
857        );
858        assert!(
859            names.contains(&"local"),
860            "missing local arrow, got {:?}",
861            names
862        );
863        assert!(
864            names.contains(&"regular"),
865            "missing regular function, got {:?}",
866            names
867        );
868    }
869
870    #[cfg(feature = "lang-typescript")]
871    #[test]
872    fn test_tsx_arrow_component() {
873        let source = b"export const MyComponent = () => <div>hello</div>;\nfunction Other() { return <span/>; }\n";
874        let symbols = Lang::Tsx.extract_symbols(source).unwrap();
875        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
876        assert!(
877            names.contains(&"MyComponent"),
878            "missing arrow component, got {:?}",
879            names
880        );
881        assert!(
882            names.contains(&"Other"),
883            "missing function component, got {:?}",
884            names
885        );
886    }
887
888    #[cfg(feature = "lang-javascript")]
889    #[test]
890    fn test_javascript_arrow_exports() {
891        let source = b"export const handler = () => { return 'ok'; };\nconst helper = (x) => x * 2;\nfunction regular() {}\n";
892        let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
893        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
894        assert!(
895            names.contains(&"handler"),
896            "missing arrow export, got {:?}",
897            names
898        );
899        assert!(
900            names.contains(&"helper"),
901            "missing local arrow, got {:?}",
902            names
903        );
904        assert!(
905            names.contains(&"regular"),
906            "missing regular function, got {:?}",
907            names
908        );
909    }
910
911    #[cfg(feature = "lang-javascript")]
912    #[test]
913    fn test_jsx_arrow_component() {
914        let source = b"const App = () => <div>hi</div>;\nfunction Page() { return <main/>; }\n";
915        let symbols = Lang::Jsx.extract_symbols(source).unwrap();
916        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
917        assert!(
918            names.contains(&"App"),
919            "missing arrow JSX component, got {:?}",
920            names
921        );
922        assert!(
923            names.contains(&"Page"),
924            "missing function component, got {:?}",
925            names
926        );
927    }
928}