Skip to main content

tsift_graph/
lang.rs

1use anyhow::Result;
2use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator};
3
/// A named definition found in a parsed source file.
///
/// Rows are zero-based; byte offsets index into the source buffer that was parsed,
/// with end offsets exclusive.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct Symbol {
    /// Text of the defining identifier (function name, heading title, alias, ...).
    pub name: String,
    /// Symbol category taken from the query capture, e.g. `function`, `struct`, `heading`.
    pub kind: String,
    /// Zero-based row where the symbol's node starts.
    pub line: usize,
    /// Zero-based row where the symbol's node ends.
    pub end_line: usize,
    /// tree-sitter node kind that spans the symbol.
    pub node_kind: String,
    /// Byte offset where the symbol starts.
    pub start_byte: usize,
    /// Byte offset (exclusive) where the symbol ends.
    pub end_byte: usize,
    /// Start of the symbol's body, when a body was identified.
    pub body_start_byte: Option<usize>,
    /// End (exclusive) of the symbol's body, when a body was identified.
    pub body_end_byte: Option<usize>,
}
16
/// Source languages the symbol extractor understands.
///
/// Each variant is compiled in only when its cargo feature is enabled.
// NOTE(review): `dead_code` is presumably allowed because some variants are never
// constructed under certain feature combinations; confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum Lang {
    /// Rust (`.rs`).
    #[cfg(feature = "lang-rust")]
    Rust,
    /// Python (`.py`, `.pyi`).
    #[cfg(feature = "lang-python")]
    Python,
    /// TypeScript (`.ts`).
    #[cfg(feature = "lang-typescript")]
    TypeScript,
    /// TypeScript with JSX (`.tsx`).
    #[cfg(feature = "lang-typescript")]
    Tsx,
    /// JavaScript (`.js`, `.mjs`, `.cjs`).
    #[cfg(feature = "lang-javascript")]
    JavaScript,
    /// JavaScript with JSX (`.jsx`).
    #[cfg(feature = "lang-javascript")]
    Jsx,
    /// Kotlin (`.kt`, `.kts`).
    #[cfg(feature = "lang-kotlin")]
    Kotlin,
    /// Zig (`.zig`).
    #[cfg(feature = "lang-zig")]
    Zig,
    /// Bash and compatible shells (`.sh`, `.bash`, `.zsh`).
    #[cfg(feature = "lang-bash")]
    Bash,
    /// Markdown (`.md`, `.mdx`).
    #[cfg(feature = "lang-markdown")]
    Markdown,
}
41
42#[allow(dead_code)]
43impl Lang {
44    pub fn from_extension(ext: &str) -> Option<Self> {
45        match ext {
46            #[cfg(feature = "lang-rust")]
47            "rs" => Some(Self::Rust),
48            #[cfg(feature = "lang-python")]
49            "py" | "pyi" => Some(Self::Python),
50            #[cfg(feature = "lang-typescript")]
51            "ts" => Some(Self::TypeScript),
52            #[cfg(feature = "lang-typescript")]
53            "tsx" => Some(Self::Tsx),
54            #[cfg(feature = "lang-javascript")]
55            "js" | "mjs" | "cjs" => Some(Self::JavaScript),
56            #[cfg(feature = "lang-javascript")]
57            "jsx" => Some(Self::Jsx),
58            #[cfg(feature = "lang-kotlin")]
59            "kt" | "kts" => Some(Self::Kotlin),
60            #[cfg(feature = "lang-zig")]
61            "zig" => Some(Self::Zig),
62            #[cfg(feature = "lang-bash")]
63            "sh" | "bash" | "zsh" => Some(Self::Bash),
64            #[cfg(feature = "lang-markdown")]
65            "md" | "mdx" => Some(Self::Markdown),
66            _ => None,
67        }
68    }
69
70    pub fn tree_sitter_language(&self) -> Language {
71        match self {
72            #[cfg(feature = "lang-rust")]
73            Self::Rust => tree_sitter_rust::LANGUAGE.into(),
74            #[cfg(feature = "lang-python")]
75            Self::Python => tree_sitter_python::LANGUAGE.into(),
76            #[cfg(feature = "lang-typescript")]
77            Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
78            #[cfg(feature = "lang-typescript")]
79            Self::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
80            #[cfg(feature = "lang-javascript")]
81            Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
82            #[cfg(feature = "lang-javascript")]
83            Self::Jsx => tree_sitter_javascript::LANGUAGE.into(),
84            #[cfg(feature = "lang-kotlin")]
85            Self::Kotlin => tree_sitter_kotlin_ng::LANGUAGE.into(),
86            #[cfg(feature = "lang-zig")]
87            Self::Zig => tree_sitter_zig::LANGUAGE.into(),
88            #[cfg(feature = "lang-bash")]
89            Self::Bash => tree_sitter_bash::LANGUAGE.into(),
90            #[cfg(feature = "lang-markdown")]
91            Self::Markdown => tree_sitter_md::LANGUAGE.into(),
92        }
93    }
94
95    pub fn name(&self) -> &'static str {
96        match self {
97            #[cfg(feature = "lang-rust")]
98            Self::Rust => "rust",
99            #[cfg(feature = "lang-python")]
100            Self::Python => "python",
101            #[cfg(feature = "lang-typescript")]
102            Self::TypeScript => "typescript",
103            #[cfg(feature = "lang-typescript")]
104            Self::Tsx => "tsx",
105            #[cfg(feature = "lang-javascript")]
106            Self::JavaScript => "javascript",
107            #[cfg(feature = "lang-javascript")]
108            Self::Jsx => "jsx",
109            #[cfg(feature = "lang-kotlin")]
110            Self::Kotlin => "kotlin",
111            #[cfg(feature = "lang-zig")]
112            Self::Zig => "zig",
113            #[cfg(feature = "lang-bash")]
114            Self::Bash => "bash",
115            #[cfg(feature = "lang-markdown")]
116            Self::Markdown => "markdown",
117        }
118    }
119
120    pub fn symbol_query(&self) -> &'static str {
121        match self {
122            #[cfg(feature = "lang-rust")]
123            Self::Rust => {
124                r#"
125                (function_item name: (identifier) @function.name)
126                (struct_item name: (type_identifier) @struct.name)
127                (enum_item name: (type_identifier) @enum.name)
128                (trait_item name: (type_identifier) @trait.name)
129                (impl_item type: (type_identifier) @impl.name)
130                (mod_item name: (identifier) @mod.name)
131                (type_item name: (type_identifier) @type_alias.name)
132                (const_item name: (identifier) @const.name)
133                (static_item name: (identifier) @static.name)
134            "#
135            }
136            #[cfg(feature = "lang-python")]
137            Self::Python => {
138                r#"
139                (function_definition name: (identifier) @function.name)
140                (class_definition name: (identifier) @class.name)
141            "#
142            }
143            #[cfg(feature = "lang-typescript")]
144            Self::TypeScript | Self::Tsx => {
145                r#"
146                (function_declaration name: (identifier) @function.name)
147                (class_declaration name: (type_identifier) @class.name)
148                (interface_declaration name: (type_identifier) @interface.name)
149                (type_alias_declaration name: (type_identifier) @type_alias.name)
150                (enum_declaration name: (identifier) @enum.name)
151                (variable_declarator name: (identifier) @function.name value: (arrow_function))
152            "#
153            }
154            #[cfg(feature = "lang-javascript")]
155            Self::JavaScript | Self::Jsx => {
156                r#"
157                (function_declaration name: (identifier) @function.name)
158                (class_declaration name: (identifier) @class.name)
159                (variable_declarator name: (identifier) @function.name value: (arrow_function))
160            "#
161            }
162            #[cfg(feature = "lang-kotlin")]
163            Self::Kotlin => {
164                r#"
165                (function_declaration name: (identifier) @function.name)
166                (class_declaration "interface" name: (identifier) @interface.name)
167                (class_declaration (modifiers (class_modifier "data")) name: (identifier) @data_class.name)
168                (class_declaration (modifiers (class_modifier "sealed")) name: (identifier) @sealed_class.name)
169                (class_declaration (modifiers (class_modifier "enum")) name: (identifier) @enum_class.name)
170                (class_declaration "class" name: (identifier) @class.name)
171                (object_declaration name: (identifier) @object.name)
172                (companion_object name: (identifier) @companion_object.name)
173            "#
174            }
175            #[cfg(feature = "lang-zig")]
176            Self::Zig => {
177                r#"
178                (function_declaration (identifier) @function.name)
179                (variable_declaration (identifier) @struct.name (struct_declaration))
180                (variable_declaration (identifier) @enum.name (enum_declaration))
181                (variable_declaration (identifier) @union.name (union_declaration))
182                (variable_declaration (identifier) @const.name)
183            "#
184            }
185            #[cfg(feature = "lang-bash")]
186            Self::Bash => {
187                r#"
188                (function_definition name: (word) @function.name)
189            "#
190            }
191            #[cfg(feature = "lang-markdown")]
192            Self::Markdown => {
193                r#"
194                (atx_heading (atx_h1_marker) (inline) @heading.name)
195                (atx_heading (atx_h2_marker) (inline) @heading.name)
196                (atx_heading (atx_h3_marker) (inline) @heading.name)
197                (atx_heading (atx_h4_marker) (inline) @heading.name)
198                (atx_heading (atx_h5_marker) (inline) @heading.name)
199                (atx_heading (atx_h6_marker) (inline) @heading.name)
200                (fenced_code_block (info_string (language) @code_block.name))
201            "#
202            }
203        }
204    }
205
206    pub fn call_query(&self) -> Option<&'static str> {
207        match self {
208            #[cfg(feature = "lang-rust")]
209            Self::Rust => Some(
210                r#"
211                (call_expression function: (identifier) @call.name)
212                (call_expression function: (field_expression field: (field_identifier) @call.name))
213                (call_expression function: (scoped_identifier name: (identifier) @call.name))
214                (macro_invocation macro: (identifier) @call.name)
215            "#,
216            ),
217            #[cfg(feature = "lang-python")]
218            Self::Python => Some(
219                r#"
220                (call function: (identifier) @call.name)
221                (call function: (attribute attribute: (identifier) @call.name))
222            "#,
223            ),
224            #[cfg(feature = "lang-typescript")]
225            Self::TypeScript | Self::Tsx => Some(
226                r#"
227                (call_expression function: (identifier) @call.name)
228                (call_expression function: (member_expression property: (property_identifier) @call.name))
229            "#,
230            ),
231            #[cfg(feature = "lang-javascript")]
232            Self::JavaScript | Self::Jsx => Some(
233                r#"
234                (call_expression function: (identifier) @call.name)
235                (call_expression function: (member_expression property: (property_identifier) @call.name))
236            "#,
237            ),
238            #[cfg(feature = "lang-kotlin")]
239            Self::Kotlin => Some(
240                r#"
241                (call_expression (simple_identifier) @call.name)
242            "#,
243            ),
244            _ => None,
245        }
246    }
247
248    pub fn extract_symbols(&self, source: &[u8]) -> Result<Vec<Symbol>> {
249        let mut parser = Parser::new();
250        let ts_lang = self.tree_sitter_language();
251        parser.set_language(&ts_lang)?;
252        let tree = parser
253            .parse(source, None)
254            .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
255        #[cfg(feature = "lang-markdown")]
256        if *self == Self::Markdown {
257            return Ok(extract_markdown_symbols(&tree, source));
258        }
259        let query = Query::new(&ts_lang, self.symbol_query())?;
260        let mut cursor = QueryCursor::new();
261        let mut symbols = Vec::new();
262        let capture_names: Vec<String> = query
263            .capture_names()
264            .iter()
265            .map(|s| s.to_string())
266            .collect();
267
268        let mut matches = cursor.matches(&query, tree.root_node(), source);
269        while let Some(m) = matches.next() {
270            for capture in m.captures {
271                let capture_name = &capture_names[capture.index as usize];
272                if let Some(kind_str) = capture_name.strip_suffix(".name") {
273                    let name = capture
274                        .node
275                        .utf8_text(source)
276                        .unwrap_or("<invalid utf8>")
277                        .to_string();
278                    let node = symbol_node_for_capture(kind_str, capture.node);
279                    let body_span = symbol_body_span(node);
280                    symbols.push(Symbol {
281                        name,
282                        kind: kind_str.to_string(),
283                        line: node.start_position().row,
284                        end_line: node.end_position().row,
285                        node_kind: node.kind().to_string(),
286                        start_byte: node.start_byte(),
287                        end_byte: node.end_byte(),
288                        body_start_byte: body_span.map(|(start, _)| start),
289                        body_end_byte: body_span.map(|(_, end)| end),
290                    });
291                }
292            }
293        }
294
295        #[cfg(feature = "lang-bash")]
296        if *self == Self::Bash {
297            Self::extract_bash_aliases(&tree, source, &mut symbols);
298        }
299        symbols.sort_by(|a, b| a.line.cmp(&b.line).then(a.name.cmp(&b.name)));
300        symbols.dedup_by(|b, a| {
301            a.name == b.name && a.line == b.line && {
302                let a_generic = matches!(a.kind.as_str(), "variable" | "const");
303                let b_generic = matches!(b.kind.as_str(), "variable" | "const");
304                match (a_generic, b_generic) {
305                    (true, false) => a.kind.clone_from(&b.kind),
306                    (false, true) => {}
307                    _ => {
308                        if b.kind.len() > a.kind.len() {
309                            a.kind.clone_from(&b.kind);
310                        }
311                    }
312                }
313                true
314            }
315        });
316        Ok(symbols)
317    }
318
319    #[cfg(feature = "lang-bash")]
320    fn extract_bash_aliases(tree: &tree_sitter::Tree, source: &[u8], symbols: &mut Vec<Symbol>) {
321        let mut tree_cursor = tree.root_node().walk();
322        if !tree_cursor.goto_first_child() {
323            return;
324        }
325        loop {
326            let node = tree_cursor.node();
327            if node.kind() == "command"
328                && let Some(name_node) = node.child_by_field_name("name")
329            {
330                let cmd = name_node.utf8_text(source).unwrap_or("");
331                if cmd == "alias" {
332                    for i in 0..node.named_child_count() {
333                        if let Some(arg) = node.named_child(i as u32)
334                            && (arg.kind() == "concatenation" || arg.kind() == "word")
335                        {
336                            let text = arg.utf8_text(source).unwrap_or("");
337                            if let Some(alias_name) = text.split('=').next()
338                                && !alias_name.is_empty()
339                                && alias_name != cmd
340                            {
341                                symbols.push(Symbol {
342                                    name: alias_name.to_string(),
343                                    kind: "alias".to_string(),
344                                    line: arg.start_position().row,
345                                    end_line: node.end_position().row,
346                                    node_kind: node.kind().to_string(),
347                                    start_byte: arg.start_byte(),
348                                    end_byte: node.end_byte(),
349                                    body_start_byte: None,
350                                    body_end_byte: None,
351                                });
352                            }
353                        }
354                    }
355                }
356            }
357            if !tree_cursor.goto_next_sibling() {
358                break;
359            }
360        }
361    }
362
363    pub fn all() -> Vec<Self> {
364        vec![
365            #[cfg(feature = "lang-rust")]
366            Self::Rust,
367            #[cfg(feature = "lang-python")]
368            Self::Python,
369            #[cfg(feature = "lang-typescript")]
370            Self::TypeScript,
371            #[cfg(feature = "lang-typescript")]
372            Self::Tsx,
373            #[cfg(feature = "lang-javascript")]
374            Self::JavaScript,
375            #[cfg(feature = "lang-javascript")]
376            Self::Jsx,
377            #[cfg(feature = "lang-kotlin")]
378            Self::Kotlin,
379            #[cfg(feature = "lang-zig")]
380            Self::Zig,
381            #[cfg(feature = "lang-bash")]
382            Self::Bash,
383            #[cfg(feature = "lang-markdown")]
384            Self::Markdown,
385        ]
386    }
387}
388
389fn symbol_node_for_capture<'tree>(
390    kind: &str,
391    name_node: tree_sitter::Node<'tree>,
392) -> tree_sitter::Node<'tree> {
393    let mut node = name_node.parent().unwrap_or(name_node);
394    if kind == "code_block" {
395        while let Some(parent) = node.parent() {
396            node = parent;
397            if node.kind() == "fenced_code_block" {
398                break;
399            }
400        }
401    }
402    node
403}
404
405fn symbol_body_span(node: tree_sitter::Node<'_>) -> Option<(usize, usize)> {
406    if let Some(body) = node.child_by_field_name("body") {
407        return Some((body.start_byte(), body.end_byte()));
408    }
409    for idx in 0..node.named_child_count() {
410        let Some(child) = node.named_child(idx as u32) else {
411            continue;
412        };
413        if matches!(
414            child.kind(),
415            "block"
416                | "declaration_list"
417                | "field_declaration_list"
418                | "enum_variant_list"
419                | "match_block"
420                | "statement_block"
421                | "suite"
422        ) {
423            return Some((child.start_byte(), child.end_byte()));
424        }
425    }
426    None
427}
428
429#[cfg(feature = "lang-markdown")]
/// An ATX heading collected while walking a Markdown tree. Its section extent
/// is computed later from the headings that follow it.
#[derive(Debug)]
#[derive(Clone)]
struct MarkdownHeading {
    /// Trimmed heading title.
    name: String,
    /// Heading depth: the `N` in the grammar's `atx_hN_marker`.
    level: usize,
    /// Byte offset where the heading node starts.
    start_byte: usize,
    /// Byte offset (exclusive) where the heading node ends.
    heading_end_byte: usize,
    /// Zero-based row of the heading.
    start_line: usize,
}
438
/// Builds the symbol list for a Markdown document.
///
/// Each ATX heading becomes a `heading` symbol whose extent runs to the next
/// heading of the same or a shallower level (or the end of the file); its body
/// starts on the line after the heading. Fenced code blocks and list items
/// gathered during the tree walk are included as well. The result is ordered by
/// line, then start byte, kind and name.
#[cfg(feature = "lang-markdown")]
fn extract_markdown_symbols(tree: &tree_sitter::Tree, source: &[u8]) -> Vec<Symbol> {
    let mut symbols = Vec::new();
    let mut headings = Vec::new();
    collect_markdown_symbols(tree.root_node(), source, &mut headings, &mut symbols);
    headings.sort_by(|a, b| {
        (a.start_byte, a.level)
            .cmp(&(b.start_byte, b.level))
            .then_with(|| a.name.cmp(&b.name))
    });

    for (position, heading) in headings.iter().enumerate() {
        // The section closes where a heading of equal or shallower depth begins.
        let section_end_byte = headings[position + 1..]
            .iter()
            .find(|next| next.level <= heading.level)
            .map_or(source.len(), |next| next.start_byte);
        let body_start_byte =
            markdown_next_line_start(source, heading.heading_end_byte).min(section_end_byte);
        symbols.push(Symbol {
            name: heading.name.clone(),
            kind: "heading".to_string(),
            line: heading.start_line,
            end_line: markdown_zero_based_end_line(source, section_end_byte),
            node_kind: "atx_heading".to_string(),
            start_byte: heading.start_byte,
            end_byte: section_end_byte,
            body_start_byte: Some(body_start_byte),
            body_end_byte: Some(section_end_byte),
        });
    }

    symbols.sort_by(|a, b| {
        a.line
            .cmp(&b.line)
            .then(a.start_byte.cmp(&b.start_byte))
            .then_with(|| a.kind.cmp(&b.kind))
            .then_with(|| a.name.cmp(&b.name))
    });
    symbols
}
482
/// Walks `node` and its descendants in document (pre-)order.
///
/// ATX headings are recorded in `headings` (section extents are computed later
/// by the caller). Fenced code blocks and list items are pushed to `symbols`
/// directly. A heading is skipped when its level or a non-empty title cannot
/// be determined.
#[cfg(feature = "lang-markdown")]
fn collect_markdown_symbols(
    node: tree_sitter::Node<'_>,
    source: &[u8],
    headings: &mut Vec<MarkdownHeading>,
    symbols: &mut Vec<Symbol>,
) {
    let found = match node.kind() {
        "atx_heading" => {
            let parsed = markdown_heading_level(node).and_then(|level| {
                markdown_heading_name(node, source).map(|name| (level, name))
            });
            if let Some((level, name)) = parsed {
                headings.push(MarkdownHeading {
                    name,
                    level,
                    start_byte: node.start_byte(),
                    heading_end_byte: node.end_byte(),
                    start_line: node.start_position().row,
                });
            }
            None
        }
        "fenced_code_block" => {
            let label = markdown_fenced_code_language(node, source)
                .filter(|text| !text.is_empty())
                .unwrap_or_else(|| "code".to_string());
            let body = markdown_fenced_code_body_span(node, source);
            Some(Symbol {
                name: label,
                kind: "code_block".to_string(),
                line: node.start_position().row,
                end_line: markdown_zero_based_end_line(source, node.end_byte()),
                node_kind: "fenced_code_block".to_string(),
                start_byte: node.start_byte(),
                end_byte: node.end_byte(),
                body_start_byte: body.map(|(start, _)| start),
                body_end_byte: body.map(|(_, end)| end),
            })
        }
        "list_item" => Some(Symbol {
            name: markdown_list_item_name(node, source),
            kind: "list_item".to_string(),
            line: node.start_position().row,
            end_line: markdown_zero_based_end_line(source, node.end_byte()),
            node_kind: "list_item".to_string(),
            start_byte: node.start_byte(),
            end_byte: node.end_byte(),
            body_start_byte: Some(node.start_byte()),
            body_end_byte: Some(node.end_byte()),
        }),
        _ => None,
    };
    if let Some(symbol) = found {
        symbols.push(symbol);
    }

    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        collect_markdown_symbols(child, source, headings, symbols);
    }
}
543
/// Returns an ATX heading's level, parsed from the `N` in its first
/// `atx_hN_marker` child; `None` if no such child exists.
#[cfg(feature = "lang-markdown")]
fn markdown_heading_level(node: tree_sitter::Node<'_>) -> Option<usize> {
    let mut walker = node.walk();
    node.children(&mut walker).find_map(|child| {
        child
            .kind()
            .strip_prefix("atx_h")?
            .strip_suffix("_marker")?
            .parse::<usize>()
            .ok()
    })
}
559
/// Returns the trimmed title of an ATX heading.
///
/// Uses the first `inline` child with non-empty text. If there is none, the
/// title is taken from the node's first line with the leading `#` run removed.
/// Returns `None` when the text is not valid UTF-8 or the title is empty.
#[cfg(feature = "lang-markdown")]
fn markdown_heading_name(node: tree_sitter::Node<'_>, source: &[u8]) -> Option<String> {
    let mut walker = node.walk();
    for inline in node.children(&mut walker).filter(|child| child.kind() == "inline") {
        let title = inline.utf8_text(source).ok()?.trim();
        if !title.is_empty() {
            return Some(title.to_string());
        }
    }

    let first_line = node.utf8_text(source).ok()?.lines().next()?;
    let title = first_line.trim().trim_start_matches('#').trim();
    if title.is_empty() {
        None
    } else {
        Some(title.to_string())
    }
}
575
576#[cfg(feature = "lang-markdown")]
577fn markdown_fenced_code_language(node: tree_sitter::Node<'_>, source: &[u8]) -> Option<String> {
578    if node.kind() == "language" || node.kind() == "info_string" {
579        let text = node.utf8_text(source).ok()?.trim();
580        if !text.is_empty() {
581            return Some(text.to_string());
582        }
583    }
584    let mut cursor = node.walk();
585    for child in node.children(&mut cursor) {
586        if let Some(language) = markdown_fenced_code_language(child, source) {
587            return Some(language);
588        }
589    }
590    None
591}
592
/// Returns the byte span of a fenced code block's contents, excluding the
/// opening fence line and the closing fence.
///
/// Uses the block's `code_fence_content` child when the grammar produced one.
/// Otherwise (e.g. an empty block) the span starts after the opening fence line
/// and ends at the start of the block's last line. A trailing line terminator
/// is ignored for that search, because block nodes may end just after the
/// closing fence's newline. Returns `None` when the block contains no line
/// break or its range lies outside `source`.
#[cfg(feature = "lang-markdown")]
fn markdown_fenced_code_body_span(
    node: tree_sitter::Node<'_>,
    source: &[u8],
) -> Option<(usize, usize)> {
    let mut cursor = node.walk();
    let content = node
        .children(&mut cursor)
        .find(|child| child.kind() == "code_fence_content");
    if let Some(content) = content {
        return Some((content.start_byte(), content.end_byte()));
    }

    let start = node.start_byte();
    let end = node.end_byte();
    let block = source.get(start..end)?;
    let first_newline = block.iter().position(|&byte| byte == b'\n')?;
    let body_start = start + first_newline + 1;
    // Drop a final `\n` so the search finds the line break *before* the last line;
    // otherwise the closing fence would be counted as part of the body.
    let without_terminator = block.strip_suffix(b"\n").unwrap_or(block);
    let closing_start = without_terminator
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(end, |offset| start + offset + 1);
    Some((body_start.min(closing_start), closing_start))
}
608
/// Builds a display name for a Markdown list item from its first line.
///
/// A leading list marker is removed: a bullet (`-`, `*`, `+`), or an ordered
/// marker of digits followed by `.` or `)` (CommonMark accepts both delimiters).
/// A marker only counts when it is followed by whitespace or ends the line.
/// Returns `"list item"` when no text remains; otherwise returns at most the
/// first 96 characters of the text.
#[cfg(feature = "lang-markdown")]
fn markdown_list_item_name(node: tree_sitter::Node<'_>, source: &[u8]) -> String {
    let text = node.utf8_text(source).unwrap_or("");
    let first_line = text.lines().next().unwrap_or("").trim();
    let label = markdown_strip_list_marker(first_line)
        .unwrap_or(first_line)
        .trim();
    if label.is_empty() {
        "list item".to_string()
    } else {
        label.chars().take(96).collect()
    }
}

/// Returns the text following a leading list marker in `line`, or `None` if the
/// line does not start with a bullet or ordered-list marker followed by
/// whitespace or end of line.
#[cfg(feature = "lang-markdown")]
fn markdown_strip_list_marker(line: &str) -> Option<&str> {
    let rest = match line.strip_prefix(['-', '*', '+']) {
        Some(rest) => rest,
        None => {
            let digits = line.len() - line.trim_start_matches(|ch: char| ch.is_ascii_digit()).len();
            if digits == 0 {
                return None;
            }
            line[digits..].strip_prefix(['.', ')'])?
        }
    };
    (rest.is_empty() || rest.starts_with([' ', '\t'])).then_some(rest)
}
635
636#[cfg(feature = "lang-markdown")]
/// Returns the byte offset where the line following a span begins.
///
/// `byte` is the span's exclusive end offset, clamped to `source.len()`.
/// If the span already consumed its line terminator (the byte before `byte`
/// is `\n`), the next line starts at `byte` itself. Tree-sitter-md block nodes
/// such as `atx_heading` can end this way. Otherwise the result is just past
/// the next `\n` at or after `byte`, or `byte` when there is none.
fn markdown_next_line_start(source: &[u8], byte: usize) -> usize {
    let byte = byte.min(source.len());
    // Without this check, a span ending right after its newline would skip the
    // whole following line (e.g. the first line of a heading's body).
    if byte > 0 && source[byte - 1] == b'\n' {
        return byte;
    }
    source[byte..]
        .iter()
        .position(|&value| value == b'\n')
        .map_or(byte, |offset| byte + offset + 1)
}
645
646#[cfg(feature = "lang-markdown")]
/// Returns the zero-based row containing the last byte of a span ending
/// (exclusively) at `end_byte`.
///
/// Only bytes before that last byte are inspected, so a newline that
/// terminates the span's final line does not push the result to the next row.
/// `end_byte` is clamped to the source length.
fn markdown_zero_based_end_line(source: &[u8], end_byte: usize) -> usize {
    let last_byte = end_byte.saturating_sub(1).min(source.len());
    source
        .iter()
        .take(last_byte)
        .filter(|&&byte| byte == b'\n')
        .count()
}
654
655#[cfg(test)]
656mod tests {
657    use super::*;
658
659    #[test]
660    fn test_all_grammars_create_parser() {
661        for lang in Lang::all() {
662            let ts_lang = lang.tree_sitter_language();
663            let mut parser = tree_sitter::Parser::new();
664            parser
665                .set_language(&ts_lang)
666                .unwrap_or_else(|e| panic!("failed to set language for {:?}: {}", lang, e));
667        }
668    }
669
670    #[test]
671    fn test_extension_dispatch() {
672        let cases = [
673            ("rs", "rust"),
674            ("py", "python"),
675            ("pyi", "python"),
676            ("ts", "typescript"),
677            ("tsx", "tsx"),
678            ("js", "javascript"),
679            ("mjs", "javascript"),
680            ("cjs", "javascript"),
681            ("jsx", "jsx"),
682            ("kt", "kotlin"),
683            ("kts", "kotlin"),
684            ("zig", "zig"),
685            ("sh", "bash"),
686            ("bash", "bash"),
687            ("zsh", "bash"),
688            ("md", "markdown"),
689            ("mdx", "markdown"),
690        ];
691        for (ext, expected_name) in cases {
692            let lang = Lang::from_extension(ext)
693                .unwrap_or_else(|| panic!("no language for extension: {ext}"));
694            assert_eq!(lang.name(), expected_name, "wrong language for .{ext}");
695        }
696    }
697
698    #[test]
699    fn test_unknown_extension_returns_none() {
700        assert!(Lang::from_extension("xyz").is_none());
701        assert!(Lang::from_extension("").is_none());
702        assert!(Lang::from_extension("txt").is_none());
703    }
704
705    #[cfg(feature = "lang-rust")]
706    #[test]
707    fn test_parse_rust_snippet() {
708        let lang = Lang::Rust;
709        let mut parser = tree_sitter::Parser::new();
710        parser.set_language(&lang.tree_sitter_language()).unwrap();
711        let tree = parser.parse("fn main() {}", None).unwrap();
712        assert_eq!(tree.root_node().kind(), "source_file");
713        assert!(!tree.root_node().has_error());
714    }
715
716    #[cfg(feature = "lang-python")]
717    #[test]
718    fn test_parse_python_snippet() {
719        let lang = Lang::Python;
720        let mut parser = tree_sitter::Parser::new();
721        parser.set_language(&lang.tree_sitter_language()).unwrap();
722        let tree = parser.parse("def hello():\n    pass\n", None).unwrap();
723        assert_eq!(tree.root_node().kind(), "module");
724        assert!(!tree.root_node().has_error());
725    }
726
727    #[cfg(feature = "lang-typescript")]
728    #[test]
729    fn test_parse_typescript_snippet() {
730        let lang = Lang::TypeScript;
731        let mut parser = tree_sitter::Parser::new();
732        parser.set_language(&lang.tree_sitter_language()).unwrap();
733        let tree = parser
734            .parse("function greet(name: string): void {}", None)
735            .unwrap();
736        assert_eq!(tree.root_node().kind(), "program");
737        assert!(!tree.root_node().has_error());
738    }
739
740    #[cfg(feature = "lang-typescript")]
741    #[test]
742    fn test_parse_tsx_snippet() {
743        let lang = Lang::Tsx;
744        let mut parser = tree_sitter::Parser::new();
745        parser.set_language(&lang.tree_sitter_language()).unwrap();
746        let tree = parser
747            .parse("const App = () => <div>hello</div>;", None)
748            .unwrap();
749        assert_eq!(tree.root_node().kind(), "program");
750        assert!(!tree.root_node().has_error());
751    }
752
753    #[cfg(feature = "lang-javascript")]
754    #[test]
755    fn test_parse_javascript_snippet() {
756        let lang = Lang::JavaScript;
757        let mut parser = tree_sitter::Parser::new();
758        parser.set_language(&lang.tree_sitter_language()).unwrap();
759        let tree = parser
760            .parse("function hello() { return 42; }", None)
761            .unwrap();
762        assert_eq!(tree.root_node().kind(), "program");
763        assert!(!tree.root_node().has_error());
764    }
765
766    #[cfg(feature = "lang-kotlin")]
767    #[test]
768    fn test_parse_kotlin_snippet() {
769        let lang = Lang::Kotlin;
770        let mut parser = tree_sitter::Parser::new();
771        parser.set_language(&lang.tree_sitter_language()).unwrap();
772        let tree = parser
773            .parse("fun main() { println(\"hello\") }", None)
774            .unwrap();
775        assert_eq!(tree.root_node().kind(), "source_file");
776        assert!(!tree.root_node().has_error());
777    }
778
779    #[cfg(feature = "lang-zig")]
780    #[test]
781    fn test_parse_zig_snippet() {
782        let lang = Lang::Zig;
783        let mut parser = tree_sitter::Parser::new();
784        parser.set_language(&lang.tree_sitter_language()).unwrap();
785        let tree = parser.parse("pub fn main() !void {}", None).unwrap();
786        assert_eq!(tree.root_node().kind(), "source_file");
787    }
788
789    #[cfg(feature = "lang-bash")]
790    #[test]
791    fn test_parse_bash_snippet() {
792        let lang = Lang::Bash;
793        let mut parser = tree_sitter::Parser::new();
794        parser.set_language(&lang.tree_sitter_language()).unwrap();
795        let tree = parser
796            .parse("#!/bin/bash\nhello() { echo hi; }\n", None)
797            .unwrap();
798        assert_eq!(tree.root_node().kind(), "program");
799        assert!(!tree.root_node().has_error());
800    }
801
802    #[cfg(feature = "lang-markdown")]
803    #[test]
804    fn test_parse_markdown_snippet() {
805        let lang = Lang::Markdown;
806        let mut parser = tree_sitter::Parser::new();
807        parser.set_language(&lang.tree_sitter_language()).unwrap();
808        let tree = parser.parse("# Hello\n\nSome text.\n", None).unwrap();
809        assert_eq!(tree.root_node().kind(), "document");
810        assert!(!tree.root_node().has_error());
811    }
812
813    #[test]
814    fn test_all_symbol_queries_compile() {
815        for lang in Lang::all() {
816            let ts_lang = lang.tree_sitter_language();
817            tree_sitter::Query::new(&ts_lang, lang.symbol_query())
818                .unwrap_or_else(|e| panic!("query compile failed for {:?}: {}", lang, e));
819        }
820    }
821
822    #[cfg(feature = "lang-rust")]
823    #[test]
824    fn test_extract_rust_symbols() {
825        let source = b"fn main() {}\nstruct Foo;\nenum Bar {}\ntrait Baz {}\nconst X: i32 = 1;\nstatic Y: i32 = 2;\nmod inner {}\ntype Alias = i32;\n";
826        let symbols = Lang::Rust.extract_symbols(source).unwrap();
827        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
828        assert!(names.contains(&"main"), "missing main, got {:?}", names);
829        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
830        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
831        assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
832        assert!(names.contains(&"X"), "missing X, got {:?}", names);
833        assert!(names.contains(&"Y"), "missing Y, got {:?}", names);
834        assert!(names.contains(&"inner"), "missing inner, got {:?}", names);
835        assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
836        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
837        assert_eq!(main_sym.kind, "function");
838        let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
839        assert_eq!(foo_sym.kind, "struct");
840    }
841
842    #[cfg(feature = "lang-python")]
843    #[test]
844    fn test_extract_python_symbols() {
845        let source =
846            b"def hello():\n    pass\n\nclass MyClass:\n    def method(self):\n        pass\n";
847        let symbols = Lang::Python.extract_symbols(source).unwrap();
848        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
849        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
850        assert!(
851            names.contains(&"MyClass"),
852            "missing MyClass, got {:?}",
853            names
854        );
855        assert!(names.contains(&"method"), "missing method, got {:?}", names);
856        let cls = symbols.iter().find(|s| s.name == "MyClass").unwrap();
857        assert_eq!(cls.kind, "class");
858    }
859
860    #[cfg(feature = "lang-typescript")]
861    #[test]
862    fn test_extract_typescript_symbols() {
863        let source = b"function greet(name: string): void {}\nclass Foo {}\ninterface Bar {}\ntype Alias = string;\nenum Color { Red, Green }\n";
864        let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
865        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
866        assert!(names.contains(&"greet"), "missing greet, got {:?}", names);
867        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
868        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
869        assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
870        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
871    }
872
873    #[cfg(feature = "lang-javascript")]
874    #[test]
875    fn test_extract_javascript_symbols() {
876        let source = b"function hello() { return 42; }\nclass Widget {}\n";
877        let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
878        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
879        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
880        assert!(names.contains(&"Widget"), "missing Widget, got {:?}", names);
881    }
882
883    #[cfg(feature = "lang-kotlin")]
884    #[test]
885    fn test_extract_kotlin_symbols() {
886        let source = b"fun main() { println(\"hi\") }\nclass Foo\ninterface Bar\ndata class Baz(val x: Int)\nsealed class Qux\nenum class Color { RED, GREEN }\nobject Singleton\n";
887        let symbols = Lang::Kotlin.extract_symbols(source).unwrap();
888        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
889        assert!(names.contains(&"main"), "missing main, got {:?}", names);
890        assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
891        assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
892        assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
893        assert!(names.contains(&"Qux"), "missing Qux, got {:?}", names);
894        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
895        assert!(
896            names.contains(&"Singleton"),
897            "missing Singleton, got {:?}",
898            names
899        );
900        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
901        assert_eq!(main_sym.kind, "function");
902        let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
903        assert_eq!(foo_sym.kind, "class");
904        let bar_sym = symbols.iter().find(|s| s.name == "Bar").unwrap();
905        assert_eq!(bar_sym.kind, "interface");
906        let baz_sym = symbols.iter().find(|s| s.name == "Baz").unwrap();
907        assert_eq!(baz_sym.kind, "data_class");
908        let qux_sym = symbols.iter().find(|s| s.name == "Qux").unwrap();
909        assert_eq!(qux_sym.kind, "sealed_class");
910        let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
911        assert_eq!(color_sym.kind, "enum_class");
912        let singleton_sym = symbols.iter().find(|s| s.name == "Singleton").unwrap();
913        assert_eq!(singleton_sym.kind, "object");
914        assert_eq!(
915            symbols.len(),
916            7,
917            "expected exactly 7 symbols, got {:?}",
918            symbols
919        );
920    }
921
922    #[cfg(feature = "lang-zig")]
923    #[test]
924    fn test_extract_zig_symbols() {
925        let source = b"const std = @import(\"std\");\npub fn main() !void {}\nconst Point = struct { x: i32, y: i32 };\nconst Color = enum { red, green, blue };\nconst Result = union(enum) { ok: i32, err: []const u8 };\nconst MAX: i32 = 100;\n";
926        let symbols = Lang::Zig.extract_symbols(source).unwrap();
927        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
928        assert!(names.contains(&"main"), "missing main, got {:?}", names);
929        assert!(names.contains(&"Point"), "missing Point, got {:?}", names);
930        assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
931        assert!(names.contains(&"Result"), "missing Result, got {:?}", names);
932        assert!(names.contains(&"std"), "missing std, got {:?}", names);
933        assert!(names.contains(&"MAX"), "missing MAX, got {:?}", names);
934        let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
935        assert_eq!(main_sym.kind, "function");
936        let point_sym = symbols.iter().find(|s| s.name == "Point").unwrap();
937        assert_eq!(point_sym.kind, "struct");
938        let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
939        assert_eq!(color_sym.kind, "enum");
940        let result_sym = symbols.iter().find(|s| s.name == "Result").unwrap();
941        assert_eq!(result_sym.kind, "union");
942        let max_sym = symbols.iter().find(|s| s.name == "MAX").unwrap();
943        assert_eq!(max_sym.kind, "const");
944    }
945
946    #[cfg(feature = "lang-bash")]
947    #[test]
948    fn test_extract_bash_symbols() {
949        let source = b"#!/bin/bash\nhello() { echo hi; }\nfunction world { echo world; }\nalias ll='ls -la'\nalias grep='grep --color=auto'\n";
950        let symbols = Lang::Bash.extract_symbols(source).unwrap();
951        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
952        assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
953        assert!(names.contains(&"world"), "missing world, got {:?}", names);
954        assert!(names.contains(&"ll"), "missing alias ll, got {:?}", names);
955        assert!(
956            names.contains(&"grep"),
957            "missing alias grep, got {:?}",
958            names
959        );
960        let hello_sym = symbols.iter().find(|s| s.name == "hello").unwrap();
961        assert_eq!(hello_sym.kind, "function");
962        let ll_sym = symbols.iter().find(|s| s.name == "ll").unwrap();
963        assert_eq!(ll_sym.kind, "alias");
964    }
965
966    #[cfg(feature = "lang-markdown")]
967    #[test]
968    fn test_extract_markdown_symbols() {
969        let source = b"# Title\n\n## Section One\n\nSome text.\n\n- Run setup\n  - Confirm setup\n\n```rust\nfn main() {}\n```\n\n### Subsection\n\n```python\ndef hello():\n    pass\n```\n\n## Next Section\n\nDone.\n";
970        let symbols = Lang::Markdown.extract_symbols(source).unwrap();
971        let headings: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "heading").collect();
972        let code_blocks: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "code_block").collect();
973        let list_items: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "list_item").collect();
974        assert_eq!(headings.len(), 4, "expected 4 headings, got {:?}", headings);
975        assert_eq!(
976            code_blocks.len(),
977            2,
978            "expected 2 code blocks, got {:?}",
979            code_blocks
980        );
981        assert_eq!(
982            list_items.len(),
983            2,
984            "expected 2 list items, got {:?}",
985            list_items
986        );
987        let title = headings.iter().find(|s| s.name == "Title").unwrap();
988        let section = headings.iter().find(|s| s.name == "Section One").unwrap();
989        let next = headings.iter().find(|s| s.name == "Next Section").unwrap();
990        assert_eq!(title.node_kind, "atx_heading");
991        assert!(title.end_byte > next.start_byte);
992        assert_eq!(section.end_byte, next.start_byte);
993        assert!(
994            section.body_start_byte.unwrap() > section.start_byte,
995            "heading body should begin after the marker line"
996        );
997        assert!(
998            code_blocks.iter().any(|s| s.name == "rust"),
999            "missing rust block, got {:?}",
1000            code_blocks
1001        );
1002        assert!(
1003            code_blocks.iter().any(|s| s.name == "python"),
1004            "missing python block, got {:?}",
1005            code_blocks
1006        );
1007        assert!(
1008            list_items.iter().any(|s| s.name == "Run setup"),
1009            "missing top-level list item, got {:?}",
1010            list_items
1011        );
1012    }
1013
1014    #[cfg(feature = "lang-python")]
1015    #[test]
1016    fn test_python_async_def() {
1017        let source = b"async def fetch_data():\n    await get()\n\ndef sync_fn():\n    pass\n";
1018        let symbols = Lang::Python.extract_symbols(source).unwrap();
1019        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1020        assert!(
1021            names.contains(&"fetch_data"),
1022            "missing async function, got {:?}",
1023            names
1024        );
1025        assert!(
1026            names.contains(&"sync_fn"),
1027            "missing sync function, got {:?}",
1028            names
1029        );
1030    }
1031
1032    #[cfg(feature = "lang-python")]
1033    #[test]
1034    fn test_python_decorated_function() {
1035        let source = b"@staticmethod\ndef helper():\n    pass\n\n@property\ndef name(self):\n    return self._name\n";
1036        let symbols = Lang::Python.extract_symbols(source).unwrap();
1037        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1038        assert!(
1039            names.contains(&"helper"),
1040            "missing decorated function, got {:?}",
1041            names
1042        );
1043        assert!(
1044            names.contains(&"name"),
1045            "missing property function, got {:?}",
1046            names
1047        );
1048    }
1049
1050    #[cfg(feature = "lang-typescript")]
1051    #[test]
1052    fn test_typescript_arrow_exports() {
1053        let source = b"export const Foo = () => { return 42; };\nexport const Bar = (x: number): number => x + 1;\nconst local = () => {};\nfunction regular() {}\n";
1054        let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
1055        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1056        assert!(
1057            names.contains(&"Foo"),
1058            "missing arrow export Foo, got {:?}",
1059            names
1060        );
1061        assert!(
1062            names.contains(&"Bar"),
1063            "missing arrow export Bar, got {:?}",
1064            names
1065        );
1066        assert!(
1067            names.contains(&"local"),
1068            "missing local arrow, got {:?}",
1069            names
1070        );
1071        assert!(
1072            names.contains(&"regular"),
1073            "missing regular function, got {:?}",
1074            names
1075        );
1076    }
1077
1078    #[cfg(feature = "lang-typescript")]
1079    #[test]
1080    fn test_tsx_arrow_component() {
1081        let source = b"export const MyComponent = () => <div>hello</div>;\nfunction Other() { return <span/>; }\n";
1082        let symbols = Lang::Tsx.extract_symbols(source).unwrap();
1083        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1084        assert!(
1085            names.contains(&"MyComponent"),
1086            "missing arrow component, got {:?}",
1087            names
1088        );
1089        assert!(
1090            names.contains(&"Other"),
1091            "missing function component, got {:?}",
1092            names
1093        );
1094    }
1095
1096    #[cfg(feature = "lang-javascript")]
1097    #[test]
1098    fn test_javascript_arrow_exports() {
1099        let source = b"export const handler = () => { return 'ok'; };\nconst helper = (x) => x * 2;\nfunction regular() {}\n";
1100        let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
1101        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1102        assert!(
1103            names.contains(&"handler"),
1104            "missing arrow export, got {:?}",
1105            names
1106        );
1107        assert!(
1108            names.contains(&"helper"),
1109            "missing local arrow, got {:?}",
1110            names
1111        );
1112        assert!(
1113            names.contains(&"regular"),
1114            "missing regular function, got {:?}",
1115            names
1116        );
1117    }
1118
1119    #[cfg(feature = "lang-javascript")]
1120    #[test]
1121    fn test_jsx_arrow_component() {
1122        let source = b"const App = () => <div>hi</div>;\nfunction Page() { return <main/>; }\n";
1123        let symbols = Lang::Jsx.extract_symbols(source).unwrap();
1124        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1125        assert!(
1126            names.contains(&"App"),
1127            "missing arrow JSX component, got {:?}",
1128            names
1129        );
1130        assert!(
1131            names.contains(&"Page"),
1132            "missing function component, got {:?}",
1133            names
1134        );
1135    }
1136}