1use anyhow::Result;
2use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator};
3
/// A named definition extracted from a source file.
///
/// Positions come from tree-sitter: lines are zero-based rows and byte
/// offsets index into the source buffer that was parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Symbol {
    /// Symbol name: identifier text, alias name, heading text, code-block
    /// language, or list-item label depending on `kind`.
    pub name: String,
    /// Symbol category. For query-based languages this is the capture name
    /// without its `.name` suffix (`function`, `struct`, ...); Bash aliases and
    /// Markdown symbols set it directly (`alias`, `heading`, `code_block`,
    /// `list_item`).
    pub kind: String,
    /// Zero-based row where the symbol starts.
    pub line: usize,
    /// Zero-based row where the symbol (or Markdown section) ends.
    pub end_line: usize,
    /// tree-sitter node kind of the node that defines the symbol.
    pub node_kind: String,
    /// Byte offset where the symbol span starts.
    pub start_byte: usize,
    /// Byte offset (exclusive) where the symbol span ends.
    pub end_byte: usize,
    /// Start byte of the symbol's body, when a body was identified.
    pub body_start_byte: Option<usize>,
    /// End byte (exclusive) of the symbol's body, when a body was identified.
    pub body_end_byte: Option<usize>,
}
16
/// Languages whose symbols can be extracted.
///
/// Each variant only exists when its Cargo feature (`lang-rust`,
/// `lang-python`, ...) is enabled.
// NOTE(review): `dead_code` is presumably allowed because some variants are
// unused under certain feature combinations — confirm.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Lang {
    /// Rust (`.rs`).
    #[cfg(feature = "lang-rust")]
    Rust,
    /// Python (`.py`, `.pyi`).
    #[cfg(feature = "lang-python")]
    Python,
    /// TypeScript (`.ts`).
    #[cfg(feature = "lang-typescript")]
    TypeScript,
    /// TypeScript with JSX (`.tsx`).
    #[cfg(feature = "lang-typescript")]
    Tsx,
    /// JavaScript (`.js`, `.mjs`, `.cjs`).
    #[cfg(feature = "lang-javascript")]
    JavaScript,
    /// JavaScript with JSX (`.jsx`).
    #[cfg(feature = "lang-javascript")]
    Jsx,
    /// Kotlin (`.kt`, `.kts`).
    #[cfg(feature = "lang-kotlin")]
    Kotlin,
    /// Zig (`.zig`).
    #[cfg(feature = "lang-zig")]
    Zig,
    /// Bash and compatible shells (`.sh`, `.bash`, `.zsh`).
    #[cfg(feature = "lang-bash")]
    Bash,
    /// Markdown (`.md`, `.mdx`).
    #[cfg(feature = "lang-markdown")]
    Markdown,
}
41
42#[allow(dead_code)]
43impl Lang {
44 pub fn from_extension(ext: &str) -> Option<Self> {
45 match ext {
46 #[cfg(feature = "lang-rust")]
47 "rs" => Some(Self::Rust),
48 #[cfg(feature = "lang-python")]
49 "py" | "pyi" => Some(Self::Python),
50 #[cfg(feature = "lang-typescript")]
51 "ts" => Some(Self::TypeScript),
52 #[cfg(feature = "lang-typescript")]
53 "tsx" => Some(Self::Tsx),
54 #[cfg(feature = "lang-javascript")]
55 "js" | "mjs" | "cjs" => Some(Self::JavaScript),
56 #[cfg(feature = "lang-javascript")]
57 "jsx" => Some(Self::Jsx),
58 #[cfg(feature = "lang-kotlin")]
59 "kt" | "kts" => Some(Self::Kotlin),
60 #[cfg(feature = "lang-zig")]
61 "zig" => Some(Self::Zig),
62 #[cfg(feature = "lang-bash")]
63 "sh" | "bash" | "zsh" => Some(Self::Bash),
64 #[cfg(feature = "lang-markdown")]
65 "md" | "mdx" => Some(Self::Markdown),
66 _ => None,
67 }
68 }
69
70 pub fn tree_sitter_language(&self) -> Language {
71 match self {
72 #[cfg(feature = "lang-rust")]
73 Self::Rust => tree_sitter_rust::LANGUAGE.into(),
74 #[cfg(feature = "lang-python")]
75 Self::Python => tree_sitter_python::LANGUAGE.into(),
76 #[cfg(feature = "lang-typescript")]
77 Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
78 #[cfg(feature = "lang-typescript")]
79 Self::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
80 #[cfg(feature = "lang-javascript")]
81 Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
82 #[cfg(feature = "lang-javascript")]
83 Self::Jsx => tree_sitter_javascript::LANGUAGE.into(),
84 #[cfg(feature = "lang-kotlin")]
85 Self::Kotlin => tree_sitter_kotlin_ng::LANGUAGE.into(),
86 #[cfg(feature = "lang-zig")]
87 Self::Zig => tree_sitter_zig::LANGUAGE.into(),
88 #[cfg(feature = "lang-bash")]
89 Self::Bash => tree_sitter_bash::LANGUAGE.into(),
90 #[cfg(feature = "lang-markdown")]
91 Self::Markdown => tree_sitter_md::LANGUAGE.into(),
92 }
93 }
94
95 pub fn name(&self) -> &'static str {
96 match self {
97 #[cfg(feature = "lang-rust")]
98 Self::Rust => "rust",
99 #[cfg(feature = "lang-python")]
100 Self::Python => "python",
101 #[cfg(feature = "lang-typescript")]
102 Self::TypeScript => "typescript",
103 #[cfg(feature = "lang-typescript")]
104 Self::Tsx => "tsx",
105 #[cfg(feature = "lang-javascript")]
106 Self::JavaScript => "javascript",
107 #[cfg(feature = "lang-javascript")]
108 Self::Jsx => "jsx",
109 #[cfg(feature = "lang-kotlin")]
110 Self::Kotlin => "kotlin",
111 #[cfg(feature = "lang-zig")]
112 Self::Zig => "zig",
113 #[cfg(feature = "lang-bash")]
114 Self::Bash => "bash",
115 #[cfg(feature = "lang-markdown")]
116 Self::Markdown => "markdown",
117 }
118 }
119
120 pub fn symbol_query(&self) -> &'static str {
121 match self {
122 #[cfg(feature = "lang-rust")]
123 Self::Rust => {
124 r#"
125 (function_item name: (identifier) @function.name)
126 (struct_item name: (type_identifier) @struct.name)
127 (enum_item name: (type_identifier) @enum.name)
128 (trait_item name: (type_identifier) @trait.name)
129 (impl_item type: (type_identifier) @impl.name)
130 (mod_item name: (identifier) @mod.name)
131 (type_item name: (type_identifier) @type_alias.name)
132 (const_item name: (identifier) @const.name)
133 (static_item name: (identifier) @static.name)
134 "#
135 }
136 #[cfg(feature = "lang-python")]
137 Self::Python => {
138 r#"
139 (function_definition name: (identifier) @function.name)
140 (class_definition name: (identifier) @class.name)
141 "#
142 }
143 #[cfg(feature = "lang-typescript")]
144 Self::TypeScript | Self::Tsx => {
145 r#"
146 (function_declaration name: (identifier) @function.name)
147 (class_declaration name: (type_identifier) @class.name)
148 (interface_declaration name: (type_identifier) @interface.name)
149 (type_alias_declaration name: (type_identifier) @type_alias.name)
150 (enum_declaration name: (identifier) @enum.name)
151 (variable_declarator name: (identifier) @function.name value: (arrow_function))
152 "#
153 }
154 #[cfg(feature = "lang-javascript")]
155 Self::JavaScript | Self::Jsx => {
156 r#"
157 (function_declaration name: (identifier) @function.name)
158 (class_declaration name: (identifier) @class.name)
159 (variable_declarator name: (identifier) @function.name value: (arrow_function))
160 "#
161 }
162 #[cfg(feature = "lang-kotlin")]
163 Self::Kotlin => {
164 r#"
165 (function_declaration name: (identifier) @function.name)
166 (class_declaration "interface" name: (identifier) @interface.name)
167 (class_declaration (modifiers (class_modifier "data")) name: (identifier) @data_class.name)
168 (class_declaration (modifiers (class_modifier "sealed")) name: (identifier) @sealed_class.name)
169 (class_declaration (modifiers (class_modifier "enum")) name: (identifier) @enum_class.name)
170 (class_declaration "class" name: (identifier) @class.name)
171 (object_declaration name: (identifier) @object.name)
172 (companion_object name: (identifier) @companion_object.name)
173 "#
174 }
175 #[cfg(feature = "lang-zig")]
176 Self::Zig => {
177 r#"
178 (function_declaration (identifier) @function.name)
179 (variable_declaration (identifier) @struct.name (struct_declaration))
180 (variable_declaration (identifier) @enum.name (enum_declaration))
181 (variable_declaration (identifier) @union.name (union_declaration))
182 (variable_declaration (identifier) @const.name)
183 "#
184 }
185 #[cfg(feature = "lang-bash")]
186 Self::Bash => {
187 r#"
188 (function_definition name: (word) @function.name)
189 "#
190 }
191 #[cfg(feature = "lang-markdown")]
192 Self::Markdown => {
193 r#"
194 (atx_heading (atx_h1_marker) (inline) @heading.name)
195 (atx_heading (atx_h2_marker) (inline) @heading.name)
196 (atx_heading (atx_h3_marker) (inline) @heading.name)
197 (atx_heading (atx_h4_marker) (inline) @heading.name)
198 (atx_heading (atx_h5_marker) (inline) @heading.name)
199 (atx_heading (atx_h6_marker) (inline) @heading.name)
200 (fenced_code_block (info_string (language) @code_block.name))
201 "#
202 }
203 }
204 }
205
206 pub fn call_query(&self) -> Option<&'static str> {
207 match self {
208 #[cfg(feature = "lang-rust")]
209 Self::Rust => Some(
210 r#"
211 (call_expression function: (identifier) @call.name)
212 (call_expression function: (field_expression field: (field_identifier) @call.name))
213 (call_expression function: (scoped_identifier name: (identifier) @call.name))
214 (macro_invocation macro: (identifier) @call.name)
215 "#,
216 ),
217 #[cfg(feature = "lang-python")]
218 Self::Python => Some(
219 r#"
220 (call function: (identifier) @call.name)
221 (call function: (attribute attribute: (identifier) @call.name))
222 "#,
223 ),
224 #[cfg(feature = "lang-typescript")]
225 Self::TypeScript | Self::Tsx => Some(
226 r#"
227 (call_expression function: (identifier) @call.name)
228 (call_expression function: (member_expression property: (property_identifier) @call.name))
229 "#,
230 ),
231 #[cfg(feature = "lang-javascript")]
232 Self::JavaScript | Self::Jsx => Some(
233 r#"
234 (call_expression function: (identifier) @call.name)
235 (call_expression function: (member_expression property: (property_identifier) @call.name))
236 "#,
237 ),
238 #[cfg(feature = "lang-kotlin")]
239 Self::Kotlin => Some(
240 r#"
241 (call_expression (simple_identifier) @call.name)
242 "#,
243 ),
244 _ => None,
245 }
246 }
247
248 pub fn extract_symbols(&self, source: &[u8]) -> Result<Vec<Symbol>> {
249 let mut parser = Parser::new();
250 let ts_lang = self.tree_sitter_language();
251 parser.set_language(&ts_lang)?;
252 let tree = parser
253 .parse(source, None)
254 .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
255 #[cfg(feature = "lang-markdown")]
256 if *self == Self::Markdown {
257 return Ok(extract_markdown_symbols(&tree, source));
258 }
259 let query = Query::new(&ts_lang, self.symbol_query())?;
260 let mut cursor = QueryCursor::new();
261 let mut symbols = Vec::new();
262 let capture_names: Vec<String> = query
263 .capture_names()
264 .iter()
265 .map(|s| s.to_string())
266 .collect();
267
268 let mut matches = cursor.matches(&query, tree.root_node(), source);
269 while let Some(m) = matches.next() {
270 for capture in m.captures {
271 let capture_name = &capture_names[capture.index as usize];
272 if let Some(kind_str) = capture_name.strip_suffix(".name") {
273 let name = capture
274 .node
275 .utf8_text(source)
276 .unwrap_or("<invalid utf8>")
277 .to_string();
278 let node = symbol_node_for_capture(kind_str, capture.node);
279 let body_span = symbol_body_span(node);
280 symbols.push(Symbol {
281 name,
282 kind: kind_str.to_string(),
283 line: node.start_position().row,
284 end_line: node.end_position().row,
285 node_kind: node.kind().to_string(),
286 start_byte: node.start_byte(),
287 end_byte: node.end_byte(),
288 body_start_byte: body_span.map(|(start, _)| start),
289 body_end_byte: body_span.map(|(_, end)| end),
290 });
291 }
292 }
293 }
294
295 #[cfg(feature = "lang-bash")]
296 if *self == Self::Bash {
297 Self::extract_bash_aliases(&tree, source, &mut symbols);
298 }
299 symbols.sort_by(|a, b| a.line.cmp(&b.line).then(a.name.cmp(&b.name)));
300 symbols.dedup_by(|b, a| {
301 a.name == b.name && a.line == b.line && {
302 let a_generic = matches!(a.kind.as_str(), "variable" | "const");
303 let b_generic = matches!(b.kind.as_str(), "variable" | "const");
304 match (a_generic, b_generic) {
305 (true, false) => a.kind.clone_from(&b.kind),
306 (false, true) => {}
307 _ => {
308 if b.kind.len() > a.kind.len() {
309 a.kind.clone_from(&b.kind);
310 }
311 }
312 }
313 true
314 }
315 });
316 Ok(symbols)
317 }
318
319 #[cfg(feature = "lang-bash")]
320 fn extract_bash_aliases(tree: &tree_sitter::Tree, source: &[u8], symbols: &mut Vec<Symbol>) {
321 let mut tree_cursor = tree.root_node().walk();
322 if !tree_cursor.goto_first_child() {
323 return;
324 }
325 loop {
326 let node = tree_cursor.node();
327 if node.kind() == "command"
328 && let Some(name_node) = node.child_by_field_name("name")
329 {
330 let cmd = name_node.utf8_text(source).unwrap_or("");
331 if cmd == "alias" {
332 for i in 0..node.named_child_count() {
333 if let Some(arg) = node.named_child(i as u32)
334 && (arg.kind() == "concatenation" || arg.kind() == "word")
335 {
336 let text = arg.utf8_text(source).unwrap_or("");
337 if let Some(alias_name) = text.split('=').next()
338 && !alias_name.is_empty()
339 && alias_name != cmd
340 {
341 symbols.push(Symbol {
342 name: alias_name.to_string(),
343 kind: "alias".to_string(),
344 line: arg.start_position().row,
345 end_line: node.end_position().row,
346 node_kind: node.kind().to_string(),
347 start_byte: arg.start_byte(),
348 end_byte: node.end_byte(),
349 body_start_byte: None,
350 body_end_byte: None,
351 });
352 }
353 }
354 }
355 }
356 }
357 if !tree_cursor.goto_next_sibling() {
358 break;
359 }
360 }
361 }
362
363 pub fn all() -> Vec<Self> {
364 vec![
365 #[cfg(feature = "lang-rust")]
366 Self::Rust,
367 #[cfg(feature = "lang-python")]
368 Self::Python,
369 #[cfg(feature = "lang-typescript")]
370 Self::TypeScript,
371 #[cfg(feature = "lang-typescript")]
372 Self::Tsx,
373 #[cfg(feature = "lang-javascript")]
374 Self::JavaScript,
375 #[cfg(feature = "lang-javascript")]
376 Self::Jsx,
377 #[cfg(feature = "lang-kotlin")]
378 Self::Kotlin,
379 #[cfg(feature = "lang-zig")]
380 Self::Zig,
381 #[cfg(feature = "lang-bash")]
382 Self::Bash,
383 #[cfg(feature = "lang-markdown")]
384 Self::Markdown,
385 ]
386 }
387}
388
389fn symbol_node_for_capture<'tree>(
390 kind: &str,
391 name_node: tree_sitter::Node<'tree>,
392) -> tree_sitter::Node<'tree> {
393 let mut node = name_node.parent().unwrap_or(name_node);
394 if kind == "code_block" {
395 while let Some(parent) = node.parent() {
396 node = parent;
397 if node.kind() == "fenced_code_block" {
398 break;
399 }
400 }
401 }
402 node
403}
404
405fn symbol_body_span(node: tree_sitter::Node<'_>) -> Option<(usize, usize)> {
406 if let Some(body) = node.child_by_field_name("body") {
407 return Some((body.start_byte(), body.end_byte()));
408 }
409 for idx in 0..node.named_child_count() {
410 let Some(child) = node.named_child(idx as u32) else {
411 continue;
412 };
413 if matches!(
414 child.kind(),
415 "block"
416 | "declaration_list"
417 | "field_declaration_list"
418 | "enum_variant_list"
419 | "match_block"
420 | "statement_block"
421 | "suite"
422 ) {
423 return Some((child.start_byte(), child.end_byte()));
424 }
425 }
426 None
427}
428
/// An ATX heading recorded during the Markdown tree walk, before its section
/// extent has been computed.
#[cfg(feature = "lang-markdown")]
#[derive(Debug, Clone)]
struct MarkdownHeading {
    /// Trimmed heading text.
    name: String,
    /// Heading level parsed from the `atx_hN_marker` node (1 for `#` .. 6 for `######`).
    level: usize,
    /// Byte offset where the heading node starts.
    start_byte: usize,
    /// Byte offset where the heading node ends.
    heading_end_byte: usize,
    /// Zero-based row of the heading.
    start_line: usize,
}
438
/// Builds the symbol list for a Markdown document.
///
/// Each ATX heading becomes a `"heading"` symbol covering its whole section.
/// The section runs from the heading up to, but not including, the next
/// heading of the same or a higher level (lower or equal level number), or to
/// the end of the document. The section body starts on the line after the
/// heading. Code blocks and list items are gathered by
/// `collect_markdown_symbols`. The result is ordered by line, then start byte,
/// kind and name.
#[cfg(feature = "lang-markdown")]
fn extract_markdown_symbols(tree: &tree_sitter::Tree, source: &[u8]) -> Vec<Symbol> {
    let mut headings = Vec::new();
    let mut symbols = Vec::new();
    collect_markdown_symbols(tree.root_node(), source, &mut headings, &mut symbols);
    headings.sort_by(|left, right| {
        left.start_byte
            .cmp(&right.start_byte)
            .then_with(|| left.level.cmp(&right.level))
            .then_with(|| left.name.cmp(&right.name))
    });

    for (idx, heading) in headings.iter().enumerate() {
        let section_end_byte = headings[idx + 1..]
            .iter()
            .find(|candidate| candidate.level <= heading.level)
            .map_or(source.len(), |candidate| candidate.start_byte);

        // The heading node may already include its line terminator, in which
        // case `heading_end` is the start of the next line. Searching forward
        // for another newline from there would skip the first body line.
        let heading_end = heading.heading_end_byte.min(source.len());
        let body_start_byte = if heading_end > 0 && source[heading_end - 1] == b'\n' {
            heading_end
        } else {
            markdown_next_line_start(source, heading_end)
        };
        let body_start_byte = body_start_byte.min(section_end_byte);

        symbols.push(Symbol {
            name: heading.name.clone(),
            kind: "heading".to_string(),
            line: heading.start_line,
            end_line: markdown_zero_based_end_line(source, section_end_byte),
            node_kind: "atx_heading".to_string(),
            start_byte: heading.start_byte,
            end_byte: section_end_byte,
            body_start_byte: Some(body_start_byte),
            body_end_byte: Some(section_end_byte),
        });
    }

    symbols.sort_by(|left, right| {
        left.line
            .cmp(&right.line)
            .then_with(|| left.start_byte.cmp(&right.start_byte))
            .then_with(|| left.kind.cmp(&right.kind))
            .then_with(|| left.name.cmp(&right.name))
    });
    symbols
}
482
/// Recursively walks `node` in pre-order. ATX headings are recorded into
/// `headings`; `code_block` and `list_item` symbols are pushed into `symbols`.
///
/// Headings are stored without an extent; `extract_markdown_symbols` computes
/// their sections afterwards. Children are visited after their parent, so
/// nested list items and code blocks inside list items are collected too.
#[cfg(feature = "lang-markdown")]
fn collect_markdown_symbols(
    node: tree_sitter::Node<'_>,
    source: &[u8],
    headings: &mut Vec<MarkdownHeading>,
    symbols: &mut Vec<Symbol>,
) {
    let start_byte = node.start_byte();
    let end_byte = node.end_byte();
    let start_line = node.start_position().row;

    match node.kind() {
        "atx_heading" => {
            // Keep the heading only when both a level marker and a title exist.
            let heading = markdown_heading_level(node).and_then(|level| {
                markdown_heading_name(node, source).map(|name| MarkdownHeading {
                    name,
                    level,
                    start_byte,
                    heading_end_byte: end_byte,
                    start_line,
                })
            });
            headings.extend(heading);
        }
        "fenced_code_block" => {
            let name = match markdown_fenced_code_language(node, source) {
                Some(language) if !language.is_empty() => language,
                _ => String::from("code"),
            };
            let (body_start_byte, body_end_byte) =
                markdown_fenced_code_body_span(node, source).unzip();
            symbols.push(Symbol {
                name,
                kind: String::from("code_block"),
                line: start_line,
                end_line: markdown_zero_based_end_line(source, end_byte),
                node_kind: String::from("fenced_code_block"),
                start_byte,
                end_byte,
                body_start_byte,
                body_end_byte,
            });
        }
        "list_item" => {
            symbols.push(Symbol {
                name: markdown_list_item_name(node, source),
                kind: String::from("list_item"),
                line: start_line,
                end_line: markdown_zero_based_end_line(source, end_byte),
                node_kind: String::from("list_item"),
                start_byte,
                end_byte,
                // The whole item, marker included, is treated as its body.
                body_start_byte: Some(start_byte),
                body_end_byte: Some(end_byte),
            });
        }
        _ => {}
    }

    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        collect_markdown_symbols(child, source, headings, symbols);
    }
}
543
/// Returns the heading level encoded in an `atx_hN_marker` child of `node`.
///
/// For example, `atx_h3_marker` yields `Some(3)`. Returns `None` when no
/// marker child is present.
#[cfg(feature = "lang-markdown")]
fn markdown_heading_level(node: tree_sitter::Node<'_>) -> Option<usize> {
    let mut walker = node.walk();
    node.children(&mut walker).find_map(|child| {
        let digits = child.kind().strip_prefix("atx_h")?.strip_suffix("_marker")?;
        digits.parse::<usize>().ok()
    })
}
559
/// Returns the trimmed title of an ATX heading.
///
/// Prefers the first non-empty `inline` child. Otherwise it falls back to the
/// heading's first source line with leading `#` characters removed. Returns
/// `None` when the title is empty or the text is not valid UTF-8.
#[cfg(feature = "lang-markdown")]
fn markdown_heading_name(node: tree_sitter::Node<'_>, source: &[u8]) -> Option<String> {
    let mut walker = node.walk();
    for inline in node.children(&mut walker).filter(|child| child.kind() == "inline") {
        let title = inline.utf8_text(source).ok()?.trim();
        if !title.is_empty() {
            return Some(title.to_owned());
        }
    }

    let first_line = node.utf8_text(source).ok()?.lines().next()?;
    let title = first_line.trim().trim_start_matches('#').trim();
    if title.is_empty() {
        None
    } else {
        Some(title.to_owned())
    }
}
575
576#[cfg(feature = "lang-markdown")]
577fn markdown_fenced_code_language(node: tree_sitter::Node<'_>, source: &[u8]) -> Option<String> {
578 if node.kind() == "language" || node.kind() == "info_string" {
579 let text = node.utf8_text(source).ok()?.trim();
580 if !text.is_empty() {
581 return Some(text.to_string());
582 }
583 }
584 let mut cursor = node.walk();
585 for child in node.children(&mut cursor) {
586 if let Some(language) = markdown_fenced_code_language(child, source) {
587 return Some(language);
588 }
589 }
590 None
591}
592
/// Returns the byte span of a fenced code block's content.
///
/// The span starts at the line after the opening fence. It ends at the start
/// of the closing-fence line; if the last line is not a fence (unterminated
/// block), it ends at the end of the node. Returns `None` when the node text
/// is not valid UTF-8 or has no newline (no content line at all).
#[cfg(feature = "lang-markdown")]
fn markdown_fenced_code_body_span(
    node: tree_sitter::Node<'_>,
    source: &[u8],
) -> Option<(usize, usize)> {
    let text = node.utf8_text(source).ok()?;
    let start = node.start_byte();
    let first_newline = text.find('\n')?;
    let body_start = start + first_newline + 1;

    // Drop a trailing line terminator first. Otherwise the newline that ends
    // the closing fence would be mistaken for the one before it, and the
    // fence would end up inside the body.
    let trimmed = text.strip_suffix('\n').unwrap_or(text);
    let trimmed = trimmed.strip_suffix('\r').unwrap_or(trimmed);
    let last_line_start = trimmed.rfind('\n').map_or(0, |offset| offset + 1);
    let last_line = trimmed[last_line_start..].trim_start();
    let has_closing_fence =
        last_line_start > 0 && (last_line.starts_with("```") || last_line.starts_with("~~~"));
    let closing_start = if has_closing_fence {
        start + last_line_start
    } else {
        start + text.len()
    };
    Some((body_start.min(closing_start), closing_start))
}
608
/// Returns a display label for a list item.
///
/// The label is the item's first line with the list marker removed,
/// truncated to 96 characters. Empty items are labelled `"list item"`.
#[cfg(feature = "lang-markdown")]
fn markdown_list_item_name(node: tree_sitter::Node<'_>, source: &[u8]) -> String {
    let text = node.utf8_text(source).unwrap_or("");
    let first_line = text.lines().next().unwrap_or("").trim();
    let label = markdown_strip_list_marker(first_line).trim();
    if label.is_empty() {
        "list item".to_string()
    } else {
        label.chars().take(96).collect()
    }
}

/// Strips a leading CommonMark list marker (`-`, `*`, `+`, `1.` or `1)`) from
/// `line`.
///
/// A marker only counts when it is followed by whitespace or the end of the
/// line. Text such as `-foo` or `1.5` is returned unchanged.
#[cfg(feature = "lang-markdown")]
fn markdown_strip_list_marker(line: &str) -> &str {
    let rest = if let Some(rest) = line.strip_prefix(['-', '*', '+']) {
        rest
    } else {
        let digits_end = line
            .find(|ch: char| !ch.is_ascii_digit())
            .unwrap_or(line.len());
        if digits_end == 0 {
            return line;
        }
        match line[digits_end..].strip_prefix(['.', ')']) {
            Some(rest) => rest,
            None => return line,
        }
    };
    if rest.is_empty() || rest.starts_with(char::is_whitespace) {
        rest
    } else {
        line
    }
}
635
636#[cfg(feature = "lang-markdown")]
/// Returns the offset just past the first `\n` at or after `byte`.
///
/// `byte` is clamped to `source.len()`. If no newline follows, the clamped
/// offset itself is returned.
fn markdown_next_line_start(source: &[u8], byte: usize) -> usize {
    let from = std::cmp::min(byte, source.len());
    match source[from..].iter().position(|&b| b == b'\n') {
        Some(offset) => from + offset + 1,
        None => from,
    }
}
645
646#[cfg(feature = "lang-markdown")]
/// Returns the zero-based line index of the last byte before `end_byte`.
///
/// `end_byte` is an exclusive end offset, so the final byte of the span is
/// not counted. A span that ends just after a newline therefore reports the
/// line that newline terminates. Offsets past the end are clamped.
fn markdown_zero_based_end_line(source: &[u8], end_byte: usize) -> usize {
    let limit = std::cmp::min(end_byte.saturating_sub(1), source.len());
    source.iter().take(limit).filter(|&&b| b == b'\n').count()
}
654
655#[cfg(test)]
656mod tests {
657 use super::*;
658
659 #[test]
660 fn test_all_grammars_create_parser() {
661 for lang in Lang::all() {
662 let ts_lang = lang.tree_sitter_language();
663 let mut parser = tree_sitter::Parser::new();
664 parser
665 .set_language(&ts_lang)
666 .unwrap_or_else(|e| panic!("failed to set language for {:?}: {}", lang, e));
667 }
668 }
669
670 #[test]
671 fn test_extension_dispatch() {
672 let cases = [
673 ("rs", "rust"),
674 ("py", "python"),
675 ("pyi", "python"),
676 ("ts", "typescript"),
677 ("tsx", "tsx"),
678 ("js", "javascript"),
679 ("mjs", "javascript"),
680 ("cjs", "javascript"),
681 ("jsx", "jsx"),
682 ("kt", "kotlin"),
683 ("kts", "kotlin"),
684 ("zig", "zig"),
685 ("sh", "bash"),
686 ("bash", "bash"),
687 ("zsh", "bash"),
688 ("md", "markdown"),
689 ("mdx", "markdown"),
690 ];
691 for (ext, expected_name) in cases {
692 let lang = Lang::from_extension(ext)
693 .unwrap_or_else(|| panic!("no language for extension: {ext}"));
694 assert_eq!(lang.name(), expected_name, "wrong language for .{ext}");
695 }
696 }
697
698 #[test]
699 fn test_unknown_extension_returns_none() {
700 assert!(Lang::from_extension("xyz").is_none());
701 assert!(Lang::from_extension("").is_none());
702 assert!(Lang::from_extension("txt").is_none());
703 }
704
705 #[cfg(feature = "lang-rust")]
706 #[test]
707 fn test_parse_rust_snippet() {
708 let lang = Lang::Rust;
709 let mut parser = tree_sitter::Parser::new();
710 parser.set_language(&lang.tree_sitter_language()).unwrap();
711 let tree = parser.parse("fn main() {}", None).unwrap();
712 assert_eq!(tree.root_node().kind(), "source_file");
713 assert!(!tree.root_node().has_error());
714 }
715
716 #[cfg(feature = "lang-python")]
717 #[test]
718 fn test_parse_python_snippet() {
719 let lang = Lang::Python;
720 let mut parser = tree_sitter::Parser::new();
721 parser.set_language(&lang.tree_sitter_language()).unwrap();
722 let tree = parser.parse("def hello():\n pass\n", None).unwrap();
723 assert_eq!(tree.root_node().kind(), "module");
724 assert!(!tree.root_node().has_error());
725 }
726
727 #[cfg(feature = "lang-typescript")]
728 #[test]
729 fn test_parse_typescript_snippet() {
730 let lang = Lang::TypeScript;
731 let mut parser = tree_sitter::Parser::new();
732 parser.set_language(&lang.tree_sitter_language()).unwrap();
733 let tree = parser
734 .parse("function greet(name: string): void {}", None)
735 .unwrap();
736 assert_eq!(tree.root_node().kind(), "program");
737 assert!(!tree.root_node().has_error());
738 }
739
740 #[cfg(feature = "lang-typescript")]
741 #[test]
742 fn test_parse_tsx_snippet() {
743 let lang = Lang::Tsx;
744 let mut parser = tree_sitter::Parser::new();
745 parser.set_language(&lang.tree_sitter_language()).unwrap();
746 let tree = parser
747 .parse("const App = () => <div>hello</div>;", None)
748 .unwrap();
749 assert_eq!(tree.root_node().kind(), "program");
750 assert!(!tree.root_node().has_error());
751 }
752
753 #[cfg(feature = "lang-javascript")]
754 #[test]
755 fn test_parse_javascript_snippet() {
756 let lang = Lang::JavaScript;
757 let mut parser = tree_sitter::Parser::new();
758 parser.set_language(&lang.tree_sitter_language()).unwrap();
759 let tree = parser
760 .parse("function hello() { return 42; }", None)
761 .unwrap();
762 assert_eq!(tree.root_node().kind(), "program");
763 assert!(!tree.root_node().has_error());
764 }
765
766 #[cfg(feature = "lang-kotlin")]
767 #[test]
768 fn test_parse_kotlin_snippet() {
769 let lang = Lang::Kotlin;
770 let mut parser = tree_sitter::Parser::new();
771 parser.set_language(&lang.tree_sitter_language()).unwrap();
772 let tree = parser
773 .parse("fun main() { println(\"hello\") }", None)
774 .unwrap();
775 assert_eq!(tree.root_node().kind(), "source_file");
776 assert!(!tree.root_node().has_error());
777 }
778
779 #[cfg(feature = "lang-zig")]
780 #[test]
781 fn test_parse_zig_snippet() {
782 let lang = Lang::Zig;
783 let mut parser = tree_sitter::Parser::new();
784 parser.set_language(&lang.tree_sitter_language()).unwrap();
785 let tree = parser.parse("pub fn main() !void {}", None).unwrap();
786 assert_eq!(tree.root_node().kind(), "source_file");
787 }
788
789 #[cfg(feature = "lang-bash")]
790 #[test]
791 fn test_parse_bash_snippet() {
792 let lang = Lang::Bash;
793 let mut parser = tree_sitter::Parser::new();
794 parser.set_language(&lang.tree_sitter_language()).unwrap();
795 let tree = parser
796 .parse("#!/bin/bash\nhello() { echo hi; }\n", None)
797 .unwrap();
798 assert_eq!(tree.root_node().kind(), "program");
799 assert!(!tree.root_node().has_error());
800 }
801
802 #[cfg(feature = "lang-markdown")]
803 #[test]
804 fn test_parse_markdown_snippet() {
805 let lang = Lang::Markdown;
806 let mut parser = tree_sitter::Parser::new();
807 parser.set_language(&lang.tree_sitter_language()).unwrap();
808 let tree = parser.parse("# Hello\n\nSome text.\n", None).unwrap();
809 assert_eq!(tree.root_node().kind(), "document");
810 assert!(!tree.root_node().has_error());
811 }
812
813 #[test]
814 fn test_all_symbol_queries_compile() {
815 for lang in Lang::all() {
816 let ts_lang = lang.tree_sitter_language();
817 tree_sitter::Query::new(&ts_lang, lang.symbol_query())
818 .unwrap_or_else(|e| panic!("query compile failed for {:?}: {}", lang, e));
819 }
820 }
821
822 #[cfg(feature = "lang-rust")]
823 #[test]
824 fn test_extract_rust_symbols() {
825 let source = b"fn main() {}\nstruct Foo;\nenum Bar {}\ntrait Baz {}\nconst X: i32 = 1;\nstatic Y: i32 = 2;\nmod inner {}\ntype Alias = i32;\n";
826 let symbols = Lang::Rust.extract_symbols(source).unwrap();
827 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
828 assert!(names.contains(&"main"), "missing main, got {:?}", names);
829 assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
830 assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
831 assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
832 assert!(names.contains(&"X"), "missing X, got {:?}", names);
833 assert!(names.contains(&"Y"), "missing Y, got {:?}", names);
834 assert!(names.contains(&"inner"), "missing inner, got {:?}", names);
835 assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
836 let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
837 assert_eq!(main_sym.kind, "function");
838 let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
839 assert_eq!(foo_sym.kind, "struct");
840 }
841
842 #[cfg(feature = "lang-python")]
843 #[test]
844 fn test_extract_python_symbols() {
845 let source =
846 b"def hello():\n pass\n\nclass MyClass:\n def method(self):\n pass\n";
847 let symbols = Lang::Python.extract_symbols(source).unwrap();
848 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
849 assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
850 assert!(
851 names.contains(&"MyClass"),
852 "missing MyClass, got {:?}",
853 names
854 );
855 assert!(names.contains(&"method"), "missing method, got {:?}", names);
856 let cls = symbols.iter().find(|s| s.name == "MyClass").unwrap();
857 assert_eq!(cls.kind, "class");
858 }
859
860 #[cfg(feature = "lang-typescript")]
861 #[test]
862 fn test_extract_typescript_symbols() {
863 let source = b"function greet(name: string): void {}\nclass Foo {}\ninterface Bar {}\ntype Alias = string;\nenum Color { Red, Green }\n";
864 let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
865 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
866 assert!(names.contains(&"greet"), "missing greet, got {:?}", names);
867 assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
868 assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
869 assert!(names.contains(&"Alias"), "missing Alias, got {:?}", names);
870 assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
871 }
872
873 #[cfg(feature = "lang-javascript")]
874 #[test]
875 fn test_extract_javascript_symbols() {
876 let source = b"function hello() { return 42; }\nclass Widget {}\n";
877 let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
878 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
879 assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
880 assert!(names.contains(&"Widget"), "missing Widget, got {:?}", names);
881 }
882
883 #[cfg(feature = "lang-kotlin")]
884 #[test]
885 fn test_extract_kotlin_symbols() {
886 let source = b"fun main() { println(\"hi\") }\nclass Foo\ninterface Bar\ndata class Baz(val x: Int)\nsealed class Qux\nenum class Color { RED, GREEN }\nobject Singleton\n";
887 let symbols = Lang::Kotlin.extract_symbols(source).unwrap();
888 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
889 assert!(names.contains(&"main"), "missing main, got {:?}", names);
890 assert!(names.contains(&"Foo"), "missing Foo, got {:?}", names);
891 assert!(names.contains(&"Bar"), "missing Bar, got {:?}", names);
892 assert!(names.contains(&"Baz"), "missing Baz, got {:?}", names);
893 assert!(names.contains(&"Qux"), "missing Qux, got {:?}", names);
894 assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
895 assert!(
896 names.contains(&"Singleton"),
897 "missing Singleton, got {:?}",
898 names
899 );
900 let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
901 assert_eq!(main_sym.kind, "function");
902 let foo_sym = symbols.iter().find(|s| s.name == "Foo").unwrap();
903 assert_eq!(foo_sym.kind, "class");
904 let bar_sym = symbols.iter().find(|s| s.name == "Bar").unwrap();
905 assert_eq!(bar_sym.kind, "interface");
906 let baz_sym = symbols.iter().find(|s| s.name == "Baz").unwrap();
907 assert_eq!(baz_sym.kind, "data_class");
908 let qux_sym = symbols.iter().find(|s| s.name == "Qux").unwrap();
909 assert_eq!(qux_sym.kind, "sealed_class");
910 let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
911 assert_eq!(color_sym.kind, "enum_class");
912 let singleton_sym = symbols.iter().find(|s| s.name == "Singleton").unwrap();
913 assert_eq!(singleton_sym.kind, "object");
914 assert_eq!(
915 symbols.len(),
916 7,
917 "expected exactly 7 symbols, got {:?}",
918 symbols
919 );
920 }
921
922 #[cfg(feature = "lang-zig")]
923 #[test]
924 fn test_extract_zig_symbols() {
925 let source = b"const std = @import(\"std\");\npub fn main() !void {}\nconst Point = struct { x: i32, y: i32 };\nconst Color = enum { red, green, blue };\nconst Result = union(enum) { ok: i32, err: []const u8 };\nconst MAX: i32 = 100;\n";
926 let symbols = Lang::Zig.extract_symbols(source).unwrap();
927 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
928 assert!(names.contains(&"main"), "missing main, got {:?}", names);
929 assert!(names.contains(&"Point"), "missing Point, got {:?}", names);
930 assert!(names.contains(&"Color"), "missing Color, got {:?}", names);
931 assert!(names.contains(&"Result"), "missing Result, got {:?}", names);
932 assert!(names.contains(&"std"), "missing std, got {:?}", names);
933 assert!(names.contains(&"MAX"), "missing MAX, got {:?}", names);
934 let main_sym = symbols.iter().find(|s| s.name == "main").unwrap();
935 assert_eq!(main_sym.kind, "function");
936 let point_sym = symbols.iter().find(|s| s.name == "Point").unwrap();
937 assert_eq!(point_sym.kind, "struct");
938 let color_sym = symbols.iter().find(|s| s.name == "Color").unwrap();
939 assert_eq!(color_sym.kind, "enum");
940 let result_sym = symbols.iter().find(|s| s.name == "Result").unwrap();
941 assert_eq!(result_sym.kind, "union");
942 let max_sym = symbols.iter().find(|s| s.name == "MAX").unwrap();
943 assert_eq!(max_sym.kind, "const");
944 }
945
946 #[cfg(feature = "lang-bash")]
947 #[test]
948 fn test_extract_bash_symbols() {
949 let source = b"#!/bin/bash\nhello() { echo hi; }\nfunction world { echo world; }\nalias ll='ls -la'\nalias grep='grep --color=auto'\n";
950 let symbols = Lang::Bash.extract_symbols(source).unwrap();
951 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
952 assert!(names.contains(&"hello"), "missing hello, got {:?}", names);
953 assert!(names.contains(&"world"), "missing world, got {:?}", names);
954 assert!(names.contains(&"ll"), "missing alias ll, got {:?}", names);
955 assert!(
956 names.contains(&"grep"),
957 "missing alias grep, got {:?}",
958 names
959 );
960 let hello_sym = symbols.iter().find(|s| s.name == "hello").unwrap();
961 assert_eq!(hello_sym.kind, "function");
962 let ll_sym = symbols.iter().find(|s| s.name == "ll").unwrap();
963 assert_eq!(ll_sym.kind, "alias");
964 }
965
966 #[cfg(feature = "lang-markdown")]
967 #[test]
968 fn test_extract_markdown_symbols() {
969 let source = b"# Title\n\n## Section One\n\nSome text.\n\n- Run setup\n - Confirm setup\n\n```rust\nfn main() {}\n```\n\n### Subsection\n\n```python\ndef hello():\n pass\n```\n\n## Next Section\n\nDone.\n";
970 let symbols = Lang::Markdown.extract_symbols(source).unwrap();
971 let headings: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "heading").collect();
972 let code_blocks: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "code_block").collect();
973 let list_items: Vec<&Symbol> = symbols.iter().filter(|s| s.kind == "list_item").collect();
974 assert_eq!(headings.len(), 4, "expected 4 headings, got {:?}", headings);
975 assert_eq!(
976 code_blocks.len(),
977 2,
978 "expected 2 code blocks, got {:?}",
979 code_blocks
980 );
981 assert_eq!(
982 list_items.len(),
983 2,
984 "expected 2 list items, got {:?}",
985 list_items
986 );
987 let title = headings.iter().find(|s| s.name == "Title").unwrap();
988 let section = headings.iter().find(|s| s.name == "Section One").unwrap();
989 let next = headings.iter().find(|s| s.name == "Next Section").unwrap();
990 assert_eq!(title.node_kind, "atx_heading");
991 assert!(title.end_byte > next.start_byte);
992 assert_eq!(section.end_byte, next.start_byte);
993 assert!(
994 section.body_start_byte.unwrap() > section.start_byte,
995 "heading body should begin after the marker line"
996 );
997 assert!(
998 code_blocks.iter().any(|s| s.name == "rust"),
999 "missing rust block, got {:?}",
1000 code_blocks
1001 );
1002 assert!(
1003 code_blocks.iter().any(|s| s.name == "python"),
1004 "missing python block, got {:?}",
1005 code_blocks
1006 );
1007 assert!(
1008 list_items.iter().any(|s| s.name == "Run setup"),
1009 "missing top-level list item, got {:?}",
1010 list_items
1011 );
1012 }
1013
1014 #[cfg(feature = "lang-python")]
1015 #[test]
1016 fn test_python_async_def() {
1017 let source = b"async def fetch_data():\n await get()\n\ndef sync_fn():\n pass\n";
1018 let symbols = Lang::Python.extract_symbols(source).unwrap();
1019 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1020 assert!(
1021 names.contains(&"fetch_data"),
1022 "missing async function, got {:?}",
1023 names
1024 );
1025 assert!(
1026 names.contains(&"sync_fn"),
1027 "missing sync function, got {:?}",
1028 names
1029 );
1030 }
1031
1032 #[cfg(feature = "lang-python")]
1033 #[test]
1034 fn test_python_decorated_function() {
1035 let source = b"@staticmethod\ndef helper():\n pass\n\n@property\ndef name(self):\n return self._name\n";
1036 let symbols = Lang::Python.extract_symbols(source).unwrap();
1037 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1038 assert!(
1039 names.contains(&"helper"),
1040 "missing decorated function, got {:?}",
1041 names
1042 );
1043 assert!(
1044 names.contains(&"name"),
1045 "missing property function, got {:?}",
1046 names
1047 );
1048 }
1049
1050 #[cfg(feature = "lang-typescript")]
1051 #[test]
1052 fn test_typescript_arrow_exports() {
1053 let source = b"export const Foo = () => { return 42; };\nexport const Bar = (x: number): number => x + 1;\nconst local = () => {};\nfunction regular() {}\n";
1054 let symbols = Lang::TypeScript.extract_symbols(source).unwrap();
1055 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1056 assert!(
1057 names.contains(&"Foo"),
1058 "missing arrow export Foo, got {:?}",
1059 names
1060 );
1061 assert!(
1062 names.contains(&"Bar"),
1063 "missing arrow export Bar, got {:?}",
1064 names
1065 );
1066 assert!(
1067 names.contains(&"local"),
1068 "missing local arrow, got {:?}",
1069 names
1070 );
1071 assert!(
1072 names.contains(&"regular"),
1073 "missing regular function, got {:?}",
1074 names
1075 );
1076 }
1077
1078 #[cfg(feature = "lang-typescript")]
1079 #[test]
1080 fn test_tsx_arrow_component() {
1081 let source = b"export const MyComponent = () => <div>hello</div>;\nfunction Other() { return <span/>; }\n";
1082 let symbols = Lang::Tsx.extract_symbols(source).unwrap();
1083 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1084 assert!(
1085 names.contains(&"MyComponent"),
1086 "missing arrow component, got {:?}",
1087 names
1088 );
1089 assert!(
1090 names.contains(&"Other"),
1091 "missing function component, got {:?}",
1092 names
1093 );
1094 }
1095
1096 #[cfg(feature = "lang-javascript")]
1097 #[test]
1098 fn test_javascript_arrow_exports() {
1099 let source = b"export const handler = () => { return 'ok'; };\nconst helper = (x) => x * 2;\nfunction regular() {}\n";
1100 let symbols = Lang::JavaScript.extract_symbols(source).unwrap();
1101 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1102 assert!(
1103 names.contains(&"handler"),
1104 "missing arrow export, got {:?}",
1105 names
1106 );
1107 assert!(
1108 names.contains(&"helper"),
1109 "missing local arrow, got {:?}",
1110 names
1111 );
1112 assert!(
1113 names.contains(&"regular"),
1114 "missing regular function, got {:?}",
1115 names
1116 );
1117 }
1118
1119 #[cfg(feature = "lang-javascript")]
1120 #[test]
1121 fn test_jsx_arrow_component() {
1122 let source = b"const App = () => <div>hi</div>;\nfunction Page() { return <main/>; }\n";
1123 let symbols = Lang::Jsx.extract_symbols(source).unwrap();
1124 let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
1125 assert!(
1126 names.contains(&"App"),
1127 "missing arrow JSX component, got {:?}",
1128 names
1129 );
1130 assert!(
1131 names.contains(&"Page"),
1132 "missing function component, got {:?}",
1133 names
1134 );
1135 }
1136}