1use anyhow::Result;
2use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator};
3
/// A named definition found in a parsed source file.
///
/// Instances are produced by `Lang::extract_symbols`, either from tree-sitter
/// query captures, from Bash `alias` commands, or by converting a
/// `tsift_md_ast::MdSymbol` for Markdown input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Symbol {
    /// Source text of the captured name node (or the alias name for Bash aliases).
    pub name: String,
    /// Symbol category: the query capture name with its `.name` suffix removed
    /// (for example `function` or `struct`), or `alias` for Bash aliases.
    pub kind: String,
    /// Row on which the symbol's node starts, exactly as reported by tree-sitter.
    // NOTE(review): tree-sitter rows are zero-based — confirm consumers expect that.
    pub line: usize,
    /// Row on which the symbol's node ends, exactly as reported by tree-sitter.
    pub end_line: usize,
    /// Tree-sitter kind of the node that represents the whole symbol.
    pub node_kind: String,
    /// Byte offset at which the symbol starts in the source buffer.
    pub start_byte: usize,
    /// Byte offset at which the symbol ends in the source buffer.
    pub end_byte: usize,
    /// Start byte of the symbol's body, or `None` when no body node was found.
    pub body_start_byte: Option<usize>,
    /// End byte of the symbol's body, or `None` when no body node was found.
    pub body_end_byte: Option<usize>,
}
16
/// Source languages that can be parsed for symbol extraction.
///
/// Every variant is compiled in only when its Cargo feature is enabled, so the
/// set of variants differs between builds. `TypeScript` and `Tsx` share the
/// `lang-typescript` feature; `JavaScript` and `Jsx` share `lang-javascript`.
// NOTE(review): `dead_code` is presumably allowed because some variants go
// unused under certain feature combinations — confirm and record the reason.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Lang {
    /// Rust sources (`rs`).
    #[cfg(feature = "lang-rust")]
    Rust,
    /// Python sources and stubs (`py`, `pyi`).
    #[cfg(feature = "lang-python")]
    Python,
    /// TypeScript sources (`ts`).
    #[cfg(feature = "lang-typescript")]
    TypeScript,
    /// TypeScript with JSX (`tsx`).
    #[cfg(feature = "lang-typescript")]
    Tsx,
    /// JavaScript sources (`js`, `mjs`, `cjs`).
    #[cfg(feature = "lang-javascript")]
    JavaScript,
    /// JavaScript with JSX (`jsx`).
    #[cfg(feature = "lang-javascript")]
    Jsx,
    /// Kotlin sources and scripts (`kt`, `kts`).
    #[cfg(feature = "lang-kotlin")]
    Kotlin,
    /// Zig sources (`zig`).
    #[cfg(feature = "lang-zig")]
    Zig,
    /// Shell scripts parsed with the Bash grammar (`sh`, `bash`, `zsh`).
    #[cfg(feature = "lang-bash")]
    Bash,
    /// Markdown documents (`md`, `mdx`).
    #[cfg(feature = "lang-markdown")]
    Markdown,
}
41
42#[allow(dead_code)]
43impl Lang {
44 pub fn from_extension(ext: &str) -> Option<Self> {
45 match ext {
46 #[cfg(feature = "lang-rust")]
47 "rs" => Some(Self::Rust),
48 #[cfg(feature = "lang-python")]
49 "py" | "pyi" => Some(Self::Python),
50 #[cfg(feature = "lang-typescript")]
51 "ts" => Some(Self::TypeScript),
52 #[cfg(feature = "lang-typescript")]
53 "tsx" => Some(Self::Tsx),
54 #[cfg(feature = "lang-javascript")]
55 "js" | "mjs" | "cjs" => Some(Self::JavaScript),
56 #[cfg(feature = "lang-javascript")]
57 "jsx" => Some(Self::Jsx),
58 #[cfg(feature = "lang-kotlin")]
59 "kt" | "kts" => Some(Self::Kotlin),
60 #[cfg(feature = "lang-zig")]
61 "zig" => Some(Self::Zig),
62 #[cfg(feature = "lang-bash")]
63 "sh" | "bash" | "zsh" => Some(Self::Bash),
64 #[cfg(feature = "lang-markdown")]
65 "md" | "mdx" => Some(Self::Markdown),
66 _ => None,
67 }
68 }
69
70 pub fn tree_sitter_language(&self) -> Language {
71 match self {
72 #[cfg(feature = "lang-rust")]
73 Self::Rust => tree_sitter_rust::LANGUAGE.into(),
74 #[cfg(feature = "lang-python")]
75 Self::Python => tree_sitter_python::LANGUAGE.into(),
76 #[cfg(feature = "lang-typescript")]
77 Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
78 #[cfg(feature = "lang-typescript")]
79 Self::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
80 #[cfg(feature = "lang-javascript")]
81 Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
82 #[cfg(feature = "lang-javascript")]
83 Self::Jsx => tree_sitter_javascript::LANGUAGE.into(),
84 #[cfg(feature = "lang-kotlin")]
85 Self::Kotlin => tree_sitter_kotlin_ng::LANGUAGE.into(),
86 #[cfg(feature = "lang-zig")]
87 Self::Zig => tree_sitter_zig::LANGUAGE.into(),
88 #[cfg(feature = "lang-bash")]
89 Self::Bash => tree_sitter_bash::LANGUAGE.into(),
90 #[cfg(feature = "lang-markdown")]
91 Self::Markdown => tsift_md_ast::markdown_language(),
92 }
93 }
94
95 pub fn name(&self) -> &'static str {
96 match self {
97 #[cfg(feature = "lang-rust")]
98 Self::Rust => "rust",
99 #[cfg(feature = "lang-python")]
100 Self::Python => "python",
101 #[cfg(feature = "lang-typescript")]
102 Self::TypeScript => "typescript",
103 #[cfg(feature = "lang-typescript")]
104 Self::Tsx => "tsx",
105 #[cfg(feature = "lang-javascript")]
106 Self::JavaScript => "javascript",
107 #[cfg(feature = "lang-javascript")]
108 Self::Jsx => "jsx",
109 #[cfg(feature = "lang-kotlin")]
110 Self::Kotlin => "kotlin",
111 #[cfg(feature = "lang-zig")]
112 Self::Zig => "zig",
113 #[cfg(feature = "lang-bash")]
114 Self::Bash => "bash",
115 #[cfg(feature = "lang-markdown")]
116 Self::Markdown => "markdown",
117 }
118 }
119
120 pub fn symbol_query(&self) -> &'static str {
121 match self {
122 #[cfg(feature = "lang-rust")]
123 Self::Rust => {
124 r#"
125 (function_item name: (identifier) @function.name)
126 (struct_item name: (type_identifier) @struct.name)
127 (enum_item name: (type_identifier) @enum.name)
128 (trait_item name: (type_identifier) @trait.name)
129 (impl_item type: (type_identifier) @impl.name)
130 (mod_item name: (identifier) @mod.name)
131 (type_item name: (type_identifier) @type_alias.name)
132 (const_item name: (identifier) @const.name)
133 (static_item name: (identifier) @static.name)
134 "#
135 }
136 #[cfg(feature = "lang-python")]
137 Self::Python => {
138 r#"
139 (function_definition name: (identifier) @function.name)
140 (class_definition name: (identifier) @class.name)
141 "#
142 }
143 #[cfg(feature = "lang-typescript")]
144 Self::TypeScript | Self::Tsx => {
145 r#"
146 (function_declaration name: (identifier) @function.name)
147 (class_declaration name: (type_identifier) @class.name)
148 (interface_declaration name: (type_identifier) @interface.name)
149 (type_alias_declaration name: (type_identifier) @type_alias.name)
150 (enum_declaration name: (identifier) @enum.name)
151 (variable_declarator name: (identifier) @function.name value: (arrow_function))
152 "#
153 }
154 #[cfg(feature = "lang-javascript")]
155 Self::JavaScript | Self::Jsx => {
156 r#"
157 (function_declaration name: (identifier) @function.name)
158 (class_declaration name: (identifier) @class.name)
159 (variable_declarator name: (identifier) @function.name value: (arrow_function))
160 "#
161 }
162 #[cfg(feature = "lang-kotlin")]
163 Self::Kotlin => {
164 r#"
165 (function_declaration name: (identifier) @function.name)
166 (class_declaration "interface" name: (identifier) @interface.name)
167 (class_declaration (modifiers (class_modifier "data")) name: (identifier) @data_class.name)
168 (class_declaration (modifiers (class_modifier "sealed")) name: (identifier) @sealed_class.name)
169 (class_declaration (modifiers (class_modifier "enum")) name: (identifier) @enum_class.name)
170 (class_declaration "class" name: (identifier) @class.name)
171 (object_declaration name: (identifier) @object.name)
172 (companion_object name: (identifier) @companion_object.name)
173 "#
174 }
175 #[cfg(feature = "lang-zig")]
176 Self::Zig => {
177 r#"
178 (function_declaration (identifier) @function.name)
179 (variable_declaration (identifier) @struct.name (struct_declaration))
180 (variable_declaration (identifier) @enum.name (enum_declaration))
181 (variable_declaration (identifier) @union.name (union_declaration))
182 (variable_declaration (identifier) @const.name)
183 "#
184 }
185 #[cfg(feature = "lang-bash")]
186 Self::Bash => {
187 r#"
188 (function_definition name: (word) @function.name)
189 "#
190 }
191 #[cfg(feature = "lang-markdown")]
192 Self::Markdown => {
193 r#"
194 (atx_heading (atx_h1_marker) (inline) @heading.name)
195 (atx_heading (atx_h2_marker) (inline) @heading.name)
196 (atx_heading (atx_h3_marker) (inline) @heading.name)
197 (atx_heading (atx_h4_marker) (inline) @heading.name)
198 (atx_heading (atx_h5_marker) (inline) @heading.name)
199 (atx_heading (atx_h6_marker) (inline) @heading.name)
200 (fenced_code_block (info_string (language) @code_block.name))
201 "#
202 }
203 }
204 }
205
206 pub fn call_query(&self) -> Option<&'static str> {
207 match self {
208 #[cfg(feature = "lang-rust")]
209 Self::Rust => Some(
210 r#"
211 (call_expression function: (identifier) @call.name)
212 (call_expression function: (field_expression field: (field_identifier) @call.name))
213 (call_expression function: (scoped_identifier name: (identifier) @call.name))
214 (macro_invocation macro: (identifier) @call.name)
215 "#,
216 ),
217 #[cfg(feature = "lang-python")]
218 Self::Python => Some(
219 r#"
220 (call function: (identifier) @call.name)
221 (call function: (attribute attribute: (identifier) @call.name))
222 "#,
223 ),
224 #[cfg(feature = "lang-typescript")]
225 Self::TypeScript | Self::Tsx => Some(
226 r#"
227 (call_expression function: (identifier) @call.name)
228 (call_expression function: (member_expression property: (property_identifier) @call.name))
229 "#,
230 ),
231 #[cfg(feature = "lang-javascript")]
232 Self::JavaScript | Self::Jsx => Some(
233 r#"
234 (call_expression function: (identifier) @call.name)
235 (call_expression function: (member_expression property: (property_identifier) @call.name))
236 "#,
237 ),
238 #[cfg(feature = "lang-kotlin")]
239 Self::Kotlin => Some(
240 r#"
241 (call_expression (simple_identifier) @call.name)
242 "#,
243 ),
244 _ => None,
245 }
246 }
247
248 pub fn extract_symbols(&self, source: &[u8]) -> Result<Vec<Symbol>> {
249 let mut parser = Parser::new();
250 let ts_lang = self.tree_sitter_language();
251 parser.set_language(&ts_lang)?;
252 let tree = parser
253 .parse(source, None)
254 .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
255 #[cfg(feature = "lang-markdown")]
256 if *self == Self::Markdown {
257 return Ok(tsift_md_ast::markdown_symbols_from_tree(&tree, source)
258 .into_iter()
259 .map(md_symbol_to_symbol)
260 .collect());
261 }
262 let query = Query::new(&ts_lang, self.symbol_query())?;
263 let mut cursor = QueryCursor::new();
264 let mut symbols = Vec::new();
265 let capture_names: Vec<String> = query
266 .capture_names()
267 .iter()
268 .map(|s| s.to_string())
269 .collect();
270
271 let mut matches = cursor.matches(&query, tree.root_node(), source);
272 while let Some(m) = matches.next() {
273 for capture in m.captures {
274 let capture_name = &capture_names[capture.index as usize];
275 if let Some(kind_str) = capture_name.strip_suffix(".name") {
276 let name = capture
277 .node
278 .utf8_text(source)
279 .unwrap_or("<invalid utf8>")
280 .to_string();
281 let node = symbol_node_for_capture(kind_str, capture.node);
282 let body_span = symbol_body_span(node);
283 symbols.push(Symbol {
284 name,
285 kind: kind_str.to_string(),
286 line: node.start_position().row,
287 end_line: node.end_position().row,
288 node_kind: node.kind().to_string(),
289 start_byte: node.start_byte(),
290 end_byte: node.end_byte(),
291 body_start_byte: body_span.map(|(start, _)| start),
292 body_end_byte: body_span.map(|(_, end)| end),
293 });
294 }
295 }
296 }
297
298 #[cfg(feature = "lang-bash")]
299 if *self == Self::Bash {
300 Self::extract_bash_aliases(&tree, source, &mut symbols);
301 }
302 symbols.sort_by(|a, b| a.line.cmp(&b.line).then(a.name.cmp(&b.name)));
303 symbols.dedup_by(|b, a| {
304 a.name == b.name && a.line == b.line && {
305 let a_generic = matches!(a.kind.as_str(), "variable" | "const");
306 let b_generic = matches!(b.kind.as_str(), "variable" | "const");
307 match (a_generic, b_generic) {
308 (true, false) => a.kind.clone_from(&b.kind),
309 (false, true) => {}
310 _ => {
311 if b.kind.len() > a.kind.len() {
312 a.kind.clone_from(&b.kind);
313 }
314 }
315 }
316 true
317 }
318 });
319 Ok(symbols)
320 }
321
322 #[cfg(feature = "lang-bash")]
323 fn extract_bash_aliases(tree: &tree_sitter::Tree, source: &[u8], symbols: &mut Vec<Symbol>) {
324 let mut tree_cursor = tree.root_node().walk();
325 if !tree_cursor.goto_first_child() {
326 return;
327 }
328 loop {
329 let node = tree_cursor.node();
330 if node.kind() == "command"
331 && let Some(name_node) = node.child_by_field_name("name")
332 {
333 let cmd = name_node.utf8_text(source).unwrap_or("");
334 if cmd == "alias" {
335 for i in 0..node.named_child_count() {
336 if let Some(arg) = node.named_child(i as u32)
337 && (arg.kind() == "concatenation" || arg.kind() == "word")
338 {
339 let text = arg.utf8_text(source).unwrap_or("");
340 if let Some(alias_name) = text.split('=').next()
341 && !alias_name.is_empty()
342 && alias_name != cmd
343 {
344 symbols.push(Symbol {
345 name: alias_name.to_string(),
346 kind: "alias".to_string(),
347 line: arg.start_position().row,
348 end_line: node.end_position().row,
349 node_kind: node.kind().to_string(),
350 start_byte: arg.start_byte(),
351 end_byte: node.end_byte(),
352 body_start_byte: None,
353 body_end_byte: None,
354 });
355 }
356 }
357 }
358 }
359 }
360 if !tree_cursor.goto_next_sibling() {
361 break;
362 }
363 }
364 }
365
366 pub fn all() -> Vec<Self> {
367 vec![
368 #[cfg(feature = "lang-rust")]
369 Self::Rust,
370 #[cfg(feature = "lang-python")]
371 Self::Python,
372 #[cfg(feature = "lang-typescript")]
373 Self::TypeScript,
374 #[cfg(feature = "lang-typescript")]
375 Self::Tsx,
376 #[cfg(feature = "lang-javascript")]
377 Self::JavaScript,
378 #[cfg(feature = "lang-javascript")]
379 Self::Jsx,
380 #[cfg(feature = "lang-kotlin")]
381 Self::Kotlin,
382 #[cfg(feature = "lang-zig")]
383 Self::Zig,
384 #[cfg(feature = "lang-bash")]
385 Self::Bash,
386 #[cfg(feature = "lang-markdown")]
387 Self::Markdown,
388 ]
389 }
390}
391
392fn symbol_node_for_capture<'tree>(
393 kind: &str,
394 name_node: tree_sitter::Node<'tree>,
395) -> tree_sitter::Node<'tree> {
396 let mut node = name_node.parent().unwrap_or(name_node);
397 if kind == "code_block" {
398 while let Some(parent) = node.parent() {
399 node = parent;
400 if node.kind() == "fenced_code_block" {
401 break;
402 }
403 }
404 }
405 node
406}
407
408fn symbol_body_span(node: tree_sitter::Node<'_>) -> Option<(usize, usize)> {
409 if let Some(body) = node.child_by_field_name("body") {
410 return Some((body.start_byte(), body.end_byte()));
411 }
412 for idx in 0..node.named_child_count() {
413 let Some(child) = node.named_child(idx as u32) else {
414 continue;
415 };
416 if matches!(
417 child.kind(),
418 "block"
419 | "declaration_list"
420 | "field_declaration_list"
421 | "enum_variant_list"
422 | "match_block"
423 | "statement_block"
424 | "suite"
425 ) {
426 return Some((child.start_byte(), child.end_byte()));
427 }
428 }
429 None
430}
431
/// Converts a Markdown symbol from `tsift_md_ast` into this module's
/// [`Symbol`], carrying every field over unchanged.
#[cfg(feature = "lang-markdown")]
fn md_symbol_to_symbol(md: tsift_md_ast::MdSymbol) -> Symbol {
    let tsift_md_ast::MdSymbol {
        name,
        kind,
        line,
        end_line,
        node_kind,
        start_byte,
        end_byte,
        body_start_byte,
        body_end_byte,
        ..
    } = md;

    Symbol {
        name,
        kind,
        line,
        end_line,
        node_kind,
        start_byte,
        end_byte,
        body_start_byte,
        body_end_byte,
    }
}
446
#[cfg(test)]
mod tests {
    use super::*;

    // The helpers below are shared by feature-gated tests, so any of them can
    // be unused in a build that disables the corresponding language features.

    /// Parses `source` with the grammar for `lang`, panicking on failure.
    #[allow(dead_code)]
    fn parse(lang: Lang, source: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&lang.tree_sitter_language())
            .unwrap_or_else(|e| panic!("failed to set language for {lang:?}: {e}"));
        parser
            .parse(source, None)
            .unwrap_or_else(|| panic!("parser produced no tree for {lang:?}"))
    }

    /// Asserts that `source` parses without errors into a tree rooted at
    /// `root_kind`.
    #[allow(dead_code)]
    fn assert_parses_cleanly(lang: Lang, source: &str, root_kind: &str) {
        let tree = parse(lang, source);
        assert_eq!(tree.root_node().kind(), root_kind);
        assert!(
            !tree.root_node().has_error(),
            "unexpected syntax error parsing {lang:?} snippet"
        );
    }

    /// Runs symbol extraction, panicking with the error on failure.
    #[allow(dead_code)]
    fn extract(lang: Lang, source: &[u8]) -> Vec<Symbol> {
        lang.extract_symbols(source)
            .unwrap_or_else(|e| panic!("symbol extraction failed for {lang:?}: {e}"))
    }

    /// Asserts that every name in `expected` occurs among `symbols`.
    #[allow(dead_code)]
    fn assert_names(symbols: &[Symbol], expected: &[&str]) {
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();
        for want in expected {
            assert!(names.contains(want), "missing {want}, got {names:?}");
        }
    }

    /// Asserts that the first symbol called `name` has kind `expected_kind`.
    #[allow(dead_code)]
    fn assert_kind(symbols: &[Symbol], name: &str, expected_kind: &str) {
        let symbol = symbols
            .iter()
            .find(|s| s.name == name)
            .unwrap_or_else(|| panic!("missing {name}, got {symbols:?}"));
        assert_eq!(symbol.kind, expected_kind, "wrong kind for {name}");
    }

    #[test]
    fn test_all_grammars_create_parser() {
        for lang in Lang::all() {
            let ts_lang = lang.tree_sitter_language();
            let mut parser = tree_sitter::Parser::new();
            parser
                .set_language(&ts_lang)
                .unwrap_or_else(|e| panic!("failed to set language for {:?}: {}", lang, e));
        }
    }

    #[test]
    fn test_extension_dispatch() {
        // Each case is gated like the variant it exercises; ungated cases
        // would panic in builds that leave a language feature switched off.
        let cases: &[(&str, &str)] = &[
            #[cfg(feature = "lang-rust")]
            ("rs", "rust"),
            #[cfg(feature = "lang-python")]
            ("py", "python"),
            #[cfg(feature = "lang-python")]
            ("pyi", "python"),
            #[cfg(feature = "lang-typescript")]
            ("ts", "typescript"),
            #[cfg(feature = "lang-typescript")]
            ("tsx", "tsx"),
            #[cfg(feature = "lang-javascript")]
            ("js", "javascript"),
            #[cfg(feature = "lang-javascript")]
            ("mjs", "javascript"),
            #[cfg(feature = "lang-javascript")]
            ("cjs", "javascript"),
            #[cfg(feature = "lang-javascript")]
            ("jsx", "jsx"),
            #[cfg(feature = "lang-kotlin")]
            ("kt", "kotlin"),
            #[cfg(feature = "lang-kotlin")]
            ("kts", "kotlin"),
            #[cfg(feature = "lang-zig")]
            ("zig", "zig"),
            #[cfg(feature = "lang-bash")]
            ("sh", "bash"),
            #[cfg(feature = "lang-bash")]
            ("bash", "bash"),
            #[cfg(feature = "lang-bash")]
            ("zsh", "bash"),
            #[cfg(feature = "lang-markdown")]
            ("md", "markdown"),
            #[cfg(feature = "lang-markdown")]
            ("mdx", "markdown"),
        ];
        for &(ext, expected_name) in cases {
            let lang = Lang::from_extension(ext)
                .unwrap_or_else(|| panic!("no language for extension: {ext}"));
            assert_eq!(lang.name(), expected_name, "wrong language for .{ext}");
        }
    }

    #[test]
    fn test_unknown_extension_returns_none() {
        assert!(Lang::from_extension("xyz").is_none());
        assert!(Lang::from_extension("").is_none());
        assert!(Lang::from_extension("txt").is_none());
    }

    #[cfg(feature = "lang-rust")]
    #[test]
    fn test_parse_rust_snippet() {
        assert_parses_cleanly(Lang::Rust, "fn main() {}", "source_file");
    }

    #[cfg(feature = "lang-python")]
    #[test]
    fn test_parse_python_snippet() {
        assert_parses_cleanly(Lang::Python, "def hello():\n pass\n", "module");
    }

    #[cfg(feature = "lang-typescript")]
    #[test]
    fn test_parse_typescript_snippet() {
        assert_parses_cleanly(
            Lang::TypeScript,
            "function greet(name: string): void {}",
            "program",
        );
    }

    #[cfg(feature = "lang-typescript")]
    #[test]
    fn test_parse_tsx_snippet() {
        assert_parses_cleanly(Lang::Tsx, "const App = () => <div>hello</div>;", "program");
    }

    #[cfg(feature = "lang-javascript")]
    #[test]
    fn test_parse_javascript_snippet() {
        assert_parses_cleanly(
            Lang::JavaScript,
            "function hello() { return 42; }",
            "program",
        );
    }

    #[cfg(feature = "lang-kotlin")]
    #[test]
    fn test_parse_kotlin_snippet() {
        assert_parses_cleanly(
            Lang::Kotlin,
            "fun main() { println(\"hello\") }",
            "source_file",
        );
    }

    #[cfg(feature = "lang-zig")]
    #[test]
    fn test_parse_zig_snippet() {
        // Only the root kind is checked here; this test has never asserted an
        // error-free parse for Zig.
        let tree = parse(Lang::Zig, "pub fn main() !void {}");
        assert_eq!(tree.root_node().kind(), "source_file");
    }

    #[cfg(feature = "lang-bash")]
    #[test]
    fn test_parse_bash_snippet() {
        assert_parses_cleanly(
            Lang::Bash,
            "#!/bin/bash\nhello() { echo hi; }\n",
            "program",
        );
    }

    #[cfg(feature = "lang-markdown")]
    #[test]
    fn test_parse_markdown_snippet() {
        assert_parses_cleanly(Lang::Markdown, "# Hello\n\nSome text.\n", "document");
    }

    #[test]
    fn test_all_symbol_queries_compile() {
        for lang in Lang::all() {
            let ts_lang = lang.tree_sitter_language();
            tree_sitter::Query::new(&ts_lang, lang.symbol_query())
                .unwrap_or_else(|e| panic!("query compile failed for {:?}: {}", lang, e));
        }
    }

    #[cfg(feature = "lang-rust")]
    #[test]
    fn test_extract_rust_symbols() {
        let source = b"fn main() {}\nstruct Foo;\nenum Bar {}\ntrait Baz {}\nconst X: i32 = 1;\nstatic Y: i32 = 2;\nmod inner {}\ntype Alias = i32;\n";
        let symbols = extract(Lang::Rust, source);
        assert_names(
            &symbols,
            &["main", "Foo", "Bar", "Baz", "X", "Y", "inner", "Alias"],
        );
        assert_kind(&symbols, "main", "function");
        assert_kind(&symbols, "Foo", "struct");
    }

    #[cfg(feature = "lang-python")]
    #[test]
    fn test_extract_python_symbols() {
        // The method body must be indented deeper than its `def`, otherwise
        // the fixture is not valid Python.
        let source =
            b"def hello():\n    pass\n\nclass MyClass:\n    def method(self):\n        pass\n";
        let symbols = extract(Lang::Python, source);
        assert_names(&symbols, &["hello", "MyClass", "method"]);
        assert_kind(&symbols, "MyClass", "class");
    }

    #[cfg(feature = "lang-typescript")]
    #[test]
    fn test_extract_typescript_symbols() {
        let source = b"function greet(name: string): void {}\nclass Foo {}\ninterface Bar {}\ntype Alias = string;\nenum Color { Red, Green }\n";
        let symbols = extract(Lang::TypeScript, source);
        assert_names(&symbols, &["greet", "Foo", "Bar", "Alias", "Color"]);
    }

    #[cfg(feature = "lang-javascript")]
    #[test]
    fn test_extract_javascript_symbols() {
        let source = b"function hello() { return 42; }\nclass Widget {}\n";
        let symbols = extract(Lang::JavaScript, source);
        assert_names(&symbols, &["hello", "Widget"]);
    }

    #[cfg(feature = "lang-kotlin")]
    #[test]
    fn test_extract_kotlin_symbols() {
        let source = b"fun main() { println(\"hi\") }\nclass Foo\ninterface Bar\ndata class Baz(val x: Int)\nsealed class Qux\nenum class Color { RED, GREEN }\nobject Singleton\n";
        let symbols = extract(Lang::Kotlin, source);
        assert_names(
            &symbols,
            &["main", "Foo", "Bar", "Baz", "Qux", "Color", "Singleton"],
        );
        assert_kind(&symbols, "main", "function");
        assert_kind(&symbols, "Foo", "class");
        assert_kind(&symbols, "Bar", "interface");
        assert_kind(&symbols, "Baz", "data_class");
        assert_kind(&symbols, "Qux", "sealed_class");
        assert_kind(&symbols, "Color", "enum_class");
        assert_kind(&symbols, "Singleton", "object");
        // Overlapping class patterns must have been merged into one entry each.
        assert_eq!(
            symbols.len(),
            7,
            "expected exactly 7 symbols, got {:?}",
            symbols
        );
    }

    #[cfg(feature = "lang-zig")]
    #[test]
    fn test_extract_zig_symbols() {
        let source = b"const std = @import(\"std\");\npub fn main() !void {}\nconst Point = struct { x: i32, y: i32 };\nconst Color = enum { red, green, blue };\nconst Result = union(enum) { ok: i32, err: []const u8 };\nconst MAX: i32 = 100;\n";
        let symbols = extract(Lang::Zig, source);
        assert_names(
            &symbols,
            &["main", "Point", "Color", "Result", "std", "MAX"],
        );
        assert_kind(&symbols, "main", "function");
        assert_kind(&symbols, "Point", "struct");
        assert_kind(&symbols, "Color", "enum");
        assert_kind(&symbols, "Result", "union");
        assert_kind(&symbols, "MAX", "const");
    }

    #[cfg(feature = "lang-bash")]
    #[test]
    fn test_extract_bash_symbols() {
        let source = b"#!/bin/bash\nhello() { echo hi; }\nfunction world { echo world; }\nalias ll='ls -la'\nalias grep='grep --color=auto'\n";
        let symbols = extract(Lang::Bash, source);
        assert_names(&symbols, &["hello", "world", "ll", "grep"]);
        assert_kind(&symbols, "hello", "function");
        assert_kind(&symbols, "ll", "alias");
    }

    #[cfg(feature = "lang-markdown")]
    #[test]
    fn test_extract_markdown_symbols() {
        let source = b"# Title\n\n## Section One\n\nSome text.\n\n- Run setup\n - Confirm setup\n\n```rust\nfn main() {}\n```\n\n### Subsection\n\n```python\ndef hello():\n pass\n```\n\n## Next Section\n\nDone.\n";
        let symbols = extract(Lang::Markdown, source);
        let of_kind =
            |kind: &str| -> Vec<&Symbol> { symbols.iter().filter(|s| s.kind == kind).collect() };
        let headings = of_kind("heading");
        let code_blocks = of_kind("code_block");
        let list_items = of_kind("list_item");
        assert_eq!(headings.len(), 4, "expected 4 headings, got {:?}", headings);
        assert_eq!(
            code_blocks.len(),
            2,
            "expected 2 code blocks, got {:?}",
            code_blocks
        );
        assert_eq!(
            list_items.len(),
            2,
            "expected 2 list items, got {:?}",
            list_items
        );

        let title = headings.iter().find(|s| s.name == "Title").unwrap();
        let section = headings.iter().find(|s| s.name == "Section One").unwrap();
        let next = headings.iter().find(|s| s.name == "Next Section").unwrap();
        assert_eq!(title.node_kind, "atx_heading");
        assert!(title.end_byte > next.start_byte);
        assert_eq!(section.end_byte, next.start_byte);
        assert!(
            section.body_start_byte.unwrap() > section.start_byte,
            "heading body should begin after the marker line"
        );
        assert!(
            code_blocks.iter().any(|s| s.name == "rust"),
            "missing rust block, got {:?}",
            code_blocks
        );
        assert!(
            code_blocks.iter().any(|s| s.name == "python"),
            "missing python block, got {:?}",
            code_blocks
        );
        assert!(
            list_items.iter().any(|s| s.name == "Run setup"),
            "missing top-level list item, got {:?}",
            list_items
        );
    }

    #[cfg(feature = "lang-python")]
    #[test]
    fn test_python_async_def() {
        let source = b"async def fetch_data():\n await get()\n\ndef sync_fn():\n pass\n";
        let symbols = extract(Lang::Python, source);
        assert_names(&symbols, &["fetch_data", "sync_fn"]);
    }

    #[cfg(feature = "lang-python")]
    #[test]
    fn test_python_decorated_function() {
        let source = b"@staticmethod\ndef helper():\n pass\n\n@property\ndef name(self):\n return self._name\n";
        let symbols = extract(Lang::Python, source);
        assert_names(&symbols, &["helper", "name"]);
    }

    #[cfg(feature = "lang-typescript")]
    #[test]
    fn test_typescript_arrow_exports() {
        let source = b"export const Foo = () => { return 42; };\nexport const Bar = (x: number): number => x + 1;\nconst local = () => {};\nfunction regular() {}\n";
        let symbols = extract(Lang::TypeScript, source);
        assert_names(&symbols, &["Foo", "Bar", "local", "regular"]);
    }

    #[cfg(feature = "lang-typescript")]
    #[test]
    fn test_tsx_arrow_component() {
        let source = b"export const MyComponent = () => <div>hello</div>;\nfunction Other() { return <span/>; }\n";
        let symbols = extract(Lang::Tsx, source);
        assert_names(&symbols, &["MyComponent", "Other"]);
    }

    #[cfg(feature = "lang-javascript")]
    #[test]
    fn test_javascript_arrow_exports() {
        let source = b"export const handler = () => { return 'ok'; };\nconst helper = (x) => x * 2;\nfunction regular() {}\n";
        let symbols = extract(Lang::JavaScript, source);
        assert_names(&symbols, &["handler", "helper", "regular"]);
    }

    #[cfg(feature = "lang-javascript")]
    #[test]
    fn test_jsx_arrow_component() {
        let source = b"const App = () => <div>hi</div>;\nfunction Page() { return <main/>; }\n";
        let symbols = extract(Lang::Jsx, source);
        assert_names(&symbols, &["App", "Page"]);
    }
}