1use std::path::Path;
9use tree_sitter::Node;
10
11use super::error::{AstQueryError, Result};
12use super::types::{Context, ContextItem, ContextKind, ContextualMatch, ContextualMatchLocation};
13use crate::graph::unified::build::StagingGraph;
14use crate::graph::unified::concurrent::CodeGraph;
15use crate::plugin::PluginManager;
16
/// Extracts definitions from source files together with their enclosing
/// semantic context (the chain of named scopes that surround each one).
///
/// Language support is resolved per file through the wrapped
/// [`PluginManager`].
pub struct ContextExtractor {
    /// Resolves the language plugin (parser + graph builder) for each path.
    plugin_manager: PluginManager,
}
36
37impl ContextExtractor {
38 #[must_use]
45 pub fn new() -> Self {
46 Self::with_plugin_manager(PluginManager::new())
47 }
48
49 #[must_use]
69 pub fn with_plugin_manager(plugin_manager: PluginManager) -> Self {
70 Self { plugin_manager }
71 }
72
73 #[allow(clippy::too_many_lines)]
108 pub fn extract_from_file(&self, path: &Path) -> Result<Vec<ContextualMatch>> {
109 let plugin = self.plugin_manager.plugin_for_path(path).ok_or_else(|| {
111 AstQueryError::ContextExtraction(format!(
112 "No plugin found for path: {}",
113 path.display()
114 ))
115 })?;
116
117 let raw_content = std::fs::read(path)?;
119
120 let lang_name = plugin.metadata().id;
122
123 let (prepared_content, tree) = plugin
125 .prepare_ast(&raw_content)
126 .map_err(|e| AstQueryError::ContextExtraction(format!("Failed to parse AST: {e:?}")))?;
127 let parse_content = prepared_content.as_ref();
128
129 let builder = plugin.graph_builder().ok_or_else(|| {
130 AstQueryError::ContextExtraction(format!("No graph builder registered for {lang_name}"))
131 })?;
132
133 let mut staging = StagingGraph::new();
135 builder
136 .build_graph(&tree, parse_content, path, &mut staging)
137 .map_err(|e| {
138 AstQueryError::ContextExtraction(format!(
139 "Failed to build graph for {}: {e}",
140 path.display()
141 ))
142 })?;
143
144 staging.attach_body_hashes(&raw_content, None);
147
148 let mut graph = CodeGraph::new();
149 let file_id = graph
150 .files_mut()
151 .register_with_language(path, Some(builder.language()))
152 .map_err(|e| {
153 AstQueryError::ContextExtraction(format!(
154 "Failed to register file {}: {e}",
155 path.display()
156 ))
157 })?;
158 staging.apply_file_id(file_id);
159
160 let string_remap = staging.commit_strings(graph.strings_mut()).map_err(|e| {
161 AstQueryError::ContextExtraction(format!(
162 "Failed to commit strings for {}: {e}",
163 path.display()
164 ))
165 })?;
166 staging.apply_string_remap(&string_remap).map_err(|e| {
167 AstQueryError::ContextExtraction(format!(
168 "Failed to remap strings for {}: {e}",
169 path.display()
170 ))
171 })?;
172 let _node_id_map = staging.commit_nodes(graph.nodes_mut()).map_err(|e| {
173 AstQueryError::ContextExtraction(format!(
174 "Failed to commit nodes for {}: {e}",
175 path.display()
176 ))
177 })?;
178
179 let content_str = String::from_utf8_lossy(&raw_content);
181 let root_node = tree.root_node();
182
183 let mut contextual_matches = Vec::new();
189 for (_, entry) in graph.nodes().iter() {
190 if entry.is_unified_loser() {
191 continue;
192 }
193 if ContextKind::from_node_kind(entry.kind).is_none() {
194 continue;
195 }
196 if entry.start_line == 0 {
197 continue;
198 }
199 let start_line = entry.start_line;
200 let start_column = entry.start_column;
201 let mut node = Self::find_defining_node(root_node, start_line, start_column, lang_name);
202
203 if node.is_none()
204 && Self::looks_like_byte_span(
205 entry.start_line,
206 entry.end_line,
207 entry.start_column,
208 entry.end_column,
209 &content_str,
210 )
211 {
212 node = Self::find_defining_node_by_bytes(
213 root_node,
214 entry.start_column as usize,
215 entry.end_column as usize,
216 lang_name,
217 );
218 }
219
220 if let Some(node) = node {
221 let semantic_context = Self::build_context(&node, &content_str, lang_name);
223 let match_name = semantic_context.immediate.name.clone();
224
225 let location = ContextualMatchLocation::new(
226 path.to_path_buf(),
227 entry.start_line,
228 entry.start_column,
229 entry.end_line,
230 entry.end_column,
231 );
232 contextual_matches.push(ContextualMatch::new(
233 match_name,
234 location,
235 semantic_context,
236 lang_name.to_string(),
237 ));
238 }
239 }
240
241 Ok(contextual_matches)
242 }
243
244 fn find_defining_node<'a>(
248 root: Node<'a>,
249 line: u32,
250 column: u32,
251 lang_name: &str,
252 ) -> Option<Node<'a>> {
253 let mut cursor = root.walk();
254 Self::find_defining_node_recursive(root, line, column, lang_name, &mut cursor)
255 }
256
257 fn find_defining_node_by_bytes<'a>(
258 root: Node<'a>,
259 start: usize,
260 end: usize,
261 lang_name: &str,
262 ) -> Option<Node<'a>> {
263 let target = root.descendant_for_byte_range(start, end)?;
264 let mut current = Some(target);
265
266 while let Some(node) = current {
267 if Self::is_named_scope(&node, lang_name) {
268 return Some(node);
269 }
270 current = node.parent();
271 }
272
273 None
274 }
275
276 fn looks_like_byte_span(
277 start_line: u32,
278 end_line: u32,
279 start_column: u32,
280 end_column: u32,
281 source: &str,
282 ) -> bool {
283 if start_line != 1 || end_line != 1 {
284 return false;
285 }
286 let first_line_len = source.lines().next().map_or(0, str::len);
287 let start = start_column as usize;
288 let end = end_column as usize;
289 start > first_line_len || end > first_line_len
290 }
291
292 fn find_defining_node_recursive<'a>(
294 node: Node<'a>,
295 line: u32,
296 column: u32,
297 lang_name: &str,
298 cursor: &mut tree_sitter::TreeCursor<'a>,
299 ) -> Option<Node<'a>> {
300 let start_pos = node.start_position();
303 let end_pos = node.end_position();
304
305 let node_start_line = start_pos
307 .row
308 .try_into()
309 .unwrap_or(u32::MAX)
310 .saturating_add(1);
311 let node_end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
312
313 let line_in_range = line >= node_start_line && line <= node_end_line;
315
316 let start_col: u32 = start_pos.column.try_into().unwrap_or(u32::MAX);
318 let end_col: u32 = end_pos.column.try_into().unwrap_or(u32::MAX);
319
320 let col_in_range = if line == node_start_line && line == node_end_line {
322 column >= start_col && column <= end_col
324 } else if line == node_start_line {
325 column >= start_col
327 } else if line == node_end_line {
328 column <= end_col
330 } else {
331 true
333 };
334
335 if !line_in_range || !col_in_range {
337 return None;
338 }
339
340 let children: Vec<Node<'a>> = node.children(cursor).collect();
342
343 for child in children {
345 let child_end = child.end_position();
346 let child_end_line: u32 = child_end
348 .row
349 .try_into()
350 .unwrap_or(u32::MAX)
351 .saturating_add(1);
352 if child_end_line >= line
353 && let Some(found) =
354 Self::find_defining_node_recursive(child, line, column, lang_name, cursor)
355 {
356 return Some(found);
357 }
358 }
359
360 if Self::is_named_scope(&node, lang_name) {
362 return Some(node);
363 }
364
365 None
366 }
367
368 fn build_context(node: &Node, source: &str, lang_name: &str) -> Context {
372 let source_bytes = source.as_bytes();
373
374 let immediate = Self::node_to_context_item(node, source_bytes, lang_name);
376
377 let mut parent = None;
379 let mut ancestors = Vec::new();
380 let mut current = node.parent();
381
382 while let Some(node) = current {
383 if Self::is_named_scope(&node, lang_name) {
385 let item = Self::node_to_context_item(&node, source_bytes, lang_name);
386
387 if parent.is_none() {
388 parent = Some(item);
389 } else {
390 ancestors.push(item);
391 }
392 }
393
394 current = node.parent();
395 }
396
397 Context::new(immediate, parent, ancestors)
398 }
399
400 fn node_to_context_item(node: &Node, source_bytes: &[u8], lang_name: &str) -> ContextItem {
402 let name = Self::extract_name(node, source_bytes, lang_name)
404 .unwrap_or_else(|| "<anonymous>".to_string());
405
406 let kind = Self::node_to_context_kind(node, lang_name);
408
409 let start_pos = node.start_position();
411 let end_pos = node.end_position();
412
413 let start_line = start_pos
415 .row
416 .try_into()
417 .unwrap_or(u32::MAX)
418 .saturating_add(1);
419 let end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
420
421 ContextItem::new(
422 name,
423 kind,
424 start_line,
425 end_line,
426 node.start_byte(),
427 node.end_byte(),
428 )
429 }
430
431 fn is_named_scope(node: &Node, lang_name: &str) -> bool {
433 let kind = node.kind();
434
435 if matches!(kind, "source_file" | "program" | "module") {
437 return false;
438 }
439
440 match lang_name {
441 "rust" => matches!(
442 kind,
443 "function_item"
444 | "impl_item"
445 | "trait_item"
446 | "struct_item"
447 | "enum_item"
448 | "mod_item"
449 ),
450 "javascript" | "typescript" => matches!(
451 kind,
452 "function_declaration"
453 | "method_definition"
454 | "class_declaration"
455 | "lexical_declaration"
456 ),
457 "python" => matches!(kind, "function_definition" | "class_definition"),
458 "go" => matches!(
459 kind,
460 "function_declaration" | "method_declaration" | "type_declaration"
461 ),
462 _ => false,
463 }
464 }
465
466 fn identifier_kinds(lang_name: &str) -> &'static [&'static str] {
468 match lang_name {
469 "rust" => &["identifier", "type_identifier"],
470 "javascript" | "typescript" => &["identifier", "property_identifier"],
471 "python" | "go" => &["identifier"],
472 _ => &[],
473 }
474 }
475
476 fn extract_name(node: &Node, source_bytes: &[u8], lang_name: &str) -> Option<String> {
478 let kinds = Self::identifier_kinds(lang_name);
479 if kinds.is_empty() {
480 return None;
481 }
482
483 let mut cursor = node.walk();
484 node.children(&mut cursor)
485 .find(|child| kinds.contains(&child.kind()))
486 .and_then(|child| child.utf8_text(source_bytes).ok())
487 .map(std::string::ToString::to_string)
488 }
489
490 fn node_to_context_kind(node: &Node, lang_name: &str) -> ContextKind {
492 if lang_name == "rust" && node.kind() == "function_item" {
493 let mut current = node.parent();
494 while let Some(parent) = current {
495 if matches!(parent.kind(), "impl_item" | "trait_item") {
496 return ContextKind::Method;
497 }
498 current = parent.parent();
499 }
500 }
501
502 Self::node_kind_to_context_kind(node.kind(), lang_name)
503 }
504
505 fn node_kind_to_context_kind(node_kind: &str, lang_name: &str) -> ContextKind {
506 match lang_name {
507 "rust" => match node_kind {
508 "impl_item" => ContextKind::Impl,
509 "trait_item" => ContextKind::Trait,
510 "struct_item" => ContextKind::Struct,
511 "enum_item" => ContextKind::Enum,
512 "mod_item" => ContextKind::Module,
513 "const_item" => ContextKind::Constant,
514 "static_item" => ContextKind::Variable,
515 "type_item" => ContextKind::TypeAlias,
516 _ => ContextKind::Function,
517 },
518 "javascript" | "typescript" => match node_kind {
519 "method_definition" => ContextKind::Method,
520 "class_declaration" => ContextKind::Class,
521 "lexical_declaration" | "variable_declaration" => ContextKind::Variable,
522 _ => ContextKind::Function,
523 },
524 "python" => match node_kind {
525 "class_definition" => ContextKind::Class,
526 _ => ContextKind::Function,
527 },
528 "go" => match node_kind {
529 "method_declaration" => ContextKind::Method,
530 "type_declaration" => ContextKind::Struct,
531 _ => ContextKind::Function,
532 },
533 _ => ContextKind::Function,
534 }
535 }
536
537 pub fn extract_from_directory(&self, root: &Path) -> Result<Vec<ContextualMatch>> {
544 let mut all_matches = Vec::new();
545
546 for entry in walkdir::WalkDir::new(root)
547 .follow_links(false)
548 .into_iter()
549 .filter_map(std::result::Result::ok)
550 {
551 let path = entry.path();
552 if path.is_file() {
553 if let Ok(mut matches) = self.extract_from_file(path) {
555 all_matches.append(&mut matches);
556 }
557 }
558 }
559
560 Ok(all_matches)
561 }
562}
563
564impl Default for ContextExtractor {
565 fn default() -> Self {
566 Self::new()
567 }
568}
569
#[cfg(all(test, feature = "context-tests"))]
mod tests {
    //! Extraction tests.
    //!
    //! Lookups use `expect` so that a missing definition fails the test
    //! instead of silently skipping its assertions.

    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Plugin manager with the crate's built-in plugins registered.
    fn create_test_plugin_manager() -> crate::plugin::PluginManager {
        crate::test_support::plugin_factory::with_builtin_plugins()
    }

    /// Writes `source` to `file_name` in a fresh temp dir and extracts it.
    ///
    /// The `TempDir` is returned so the file outlives the extraction.
    fn extract(file_name: &str, source: &str) -> (TempDir, Vec<ContextualMatch>) {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join(file_name);
        fs::write(&file_path, source).unwrap();

        let extractor = ContextExtractor::with_plugin_manager(create_test_plugin_manager());
        let matches = extractor.extract_from_file(&file_path).unwrap();
        (dir, matches)
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_extract_rust_function_context() {
        let (_dir, matches) = extract(
            "test.rs",
            r#"
fn top_level() {
    println!("hello");
}

struct MyStruct {
    value: i32,
}

impl MyStruct {
    fn method(&self) -> i32 {
        self.value
    }
}
"#,
        );

        assert!(
            matches.len() >= 2,
            "Expected at least 2 matches, found {}",
            matches.len()
        );

        let top_level = matches
            .iter()
            .find(|m| m.name == "top_level")
            .expect("Should find top_level function");
        assert_eq!(top_level.context.depth(), 1, "top_level should be at depth 1");
        assert_eq!(top_level.context.path(), "top_level");

        let method = matches
            .iter()
            .find(|m| m.name == "method")
            .expect("Should find method");
        assert!(method.context.depth() >= 1, "method should have depth >= 1");
        assert!(method.context.parent.is_some(), "method should have a parent");
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_extract_nested_context() {
        let (_dir, matches) = extract(
            "test.rs",
            r"
mod outer {
    struct Inner {
        value: i32,
    }

    impl Inner {
        fn deeply_nested(&self) {
            // nested function
        }
    }
}
",
        );

        let method = matches
            .iter()
            .find(|m| m.name == "deeply_nested")
            .expect("Should find deeply_nested method");
        assert!(method.context.parent.is_some(), "Should have parent");
        assert!(method.context.depth() >= 1, "Should have depth >= 1");
    }

    #[test]
    #[ignore = "JavaScript plugin not registered in test helper"]
    fn test_extract_javascript_class() {
        let (_dir, matches) = extract(
            "test.js",
            r#"
function topLevel() {
    console.log("hello");
}

class MyClass {
    constructor(name) {
        this.name = name;
    }

    greet() {
        console.log("Hello " + this.name);
    }
}
"#,
        );

        assert!(matches.len() >= 2, "Should find at least 2 matches");

        let top_fn = matches
            .iter()
            .find(|m| m.name == "topLevel")
            .expect("Should find topLevel");
        assert_eq!(top_fn.context.depth(), 1, "topLevel should be at depth 1");

        assert!(
            matches.iter().any(|m| m.name == "MyClass"),
            "Should find MyClass"
        );
    }

    #[test]
    #[ignore = "Python plugin not registered in test helper"]
    fn test_extract_python_context() {
        let (_dir, matches) = extract(
            "test.py",
            r#"
def top_level():
    print("hello")

class MyClass:
    def method(self):
        return 42
"#,
        );

        assert!(matches.len() >= 2, "Should find at least 2 matches");

        let top_fn = matches
            .iter()
            .find(|m| m.name == "top_level")
            .expect("Should find top_level");
        assert_eq!(top_fn.context.depth(), 1);

        let method = matches
            .iter()
            .find(|m| m.name == "method")
            .expect("Should find method");
        assert!(method.context.depth() >= 2);
        assert!(method.context.parent.is_some());
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_empty_file() {
        let (_dir, matches) = extract("empty.rs", "");
        assert_eq!(matches.len(), 0);
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_single_line_function() {
        let (_dir, matches) = extract(
            "test.rs",
            r#"
fn single_line() { println!("hello"); }
"#,
        );

        let func = matches
            .iter()
            .find(|m| m.name == "single_line")
            .expect("Should find single-line function");
        assert_eq!(func.context.depth(), 1);
        assert_eq!(func.context.path(), "single_line");
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_multiline_function() {
        let (_dir, matches) = extract(
            "test.rs",
            r"
fn multiline() {
    let x = 1;
    let y = 2;
    x + y
}
",
        );

        let func = matches
            .iter()
            .find(|m| m.name == "multiline")
            .expect("Should find multi-line function");
        assert_eq!(func.context.depth(), 1);
        assert!(func.end_line > func.start_line + 1);
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_nested_structures() {
        let (_dir, matches) = extract(
            "test.rs",
            r"
mod outer {
    struct Inner {
        field: i32,
    }

    impl Inner {
        fn method(&self) -> i32 {
            self.field
        }
    }
}
",
        );

        let method = matches
            .iter()
            .find(|m| m.name == "method")
            .expect("Should find nested method");
        // method -> impl Inner -> mod outer
        assert_eq!(method.context.depth(), 3, "Method should have depth 3");
        let parent = method
            .context
            .parent
            .as_ref()
            .expect("Method should have parent");
        assert_eq!(parent.name, "Inner", "Method parent should be Inner impl");
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_with_comments() {
        let (_dir, matches) = extract(
            "test.rs",
            r#"
// This is a comment
/// Documentation comment
fn documented_function() {
    // Internal comment
    println!("test");
}
"#,
        );

        let func = matches
            .iter()
            .find(|m| m.name == "documented_function")
            .expect("Should find function with comments");
        assert_eq!(func.context.depth(), 1);
    }

    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_edge_positions() {
        let (_dir, matches) = extract(
            "test.rs",
            r"
struct Container {
    value: i32,
}

impl Container {
    fn new(val: i32) -> Self {
        Self { value: val }
    }
}
",
        );

        assert!(
            matches.iter().any(|m| m.name == "Container"),
            "Should find Container struct"
        );

        let new_method = matches
            .iter()
            .find(|m| m.name == "new")
            .expect("Should find new method");
        assert_eq!(new_method.context.depth(), 2, "Method should have depth 2");
        let parent = new_method
            .context
            .parent
            .as_ref()
            .expect("Method should have parent");
        assert_eq!(
            parent.name, "Container",
            "Method parent should be Container impl"
        );
    }

    /// A trait impl must be named after its self type, not the trait.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_trait_impl_parent_uses_self_type() {
        let (_dir, matches) = extract(
            "test.rs",
            r#"
use std::fmt::{self, Display};

struct Wrapper;

impl Display for Wrapper {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "wrapper")
    }
}
"#,
        );

        let fmt_method = matches
            .iter()
            .find(|m| m.name == "fmt")
            .expect("Should find fmt method");
        let parent = fmt_method
            .context
            .parent
            .as_ref()
            .expect("fmt should have a parent impl");
        assert_eq!(parent.name, "Wrapper", "Trait impl should be named by self type");
    }
}