1use std::path::Path;
9use tree_sitter::Node;
10
11use super::error::{AstQueryError, Result};
12use super::types::{Context, ContextItem, ContextKind, ContextualMatch, ContextualMatchLocation};
13use crate::graph::unified::build::StagingGraph;
14use crate::graph::unified::concurrent::CodeGraph;
15use crate::plugin::PluginManager;
16
/// Extracts the enclosing-scope context of definitions found in source files.
///
/// Language handling is delegated to plugins: the plugin registered for a
/// path parses the file and supplies the graph builder used during extraction.
pub struct ContextExtractor {
    /// Registry consulted to find the plugin responsible for a given path.
    plugin_manager: PluginManager,
}
36
37impl ContextExtractor {
38 #[must_use]
45 pub fn new() -> Self {
46 Self::with_plugin_manager(PluginManager::new())
47 }
48
49 #[must_use]
69 pub fn with_plugin_manager(plugin_manager: PluginManager) -> Self {
70 Self { plugin_manager }
71 }
72
73 #[allow(clippy::too_many_lines)]
108 pub fn extract_from_file(&self, path: &Path) -> Result<Vec<ContextualMatch>> {
109 let plugin = self.plugin_manager.plugin_for_path(path).ok_or_else(|| {
111 AstQueryError::ContextExtraction(format!(
112 "No plugin found for path: {}",
113 path.display()
114 ))
115 })?;
116
117 let content = std::fs::read(path)?;
119
120 let lang_name = plugin.metadata().id;
122
123 let tree = plugin
125 .parse_ast(&content)
126 .map_err(|e| AstQueryError::ContextExtraction(format!("Failed to parse AST: {e:?}")))?;
127
128 let builder = plugin.graph_builder().ok_or_else(|| {
129 AstQueryError::ContextExtraction(format!("No graph builder registered for {lang_name}"))
130 })?;
131
132 let mut staging = StagingGraph::new();
134 builder
135 .build_graph(&tree, &content, path, &mut staging)
136 .map_err(|e| {
137 AstQueryError::ContextExtraction(format!(
138 "Failed to build graph for {}: {e}",
139 path.display()
140 ))
141 })?;
142
143 staging.attach_body_hashes(&content);
144
145 let mut graph = CodeGraph::new();
146 let file_id = graph
147 .files_mut()
148 .register_with_language(path, Some(builder.language()))
149 .map_err(|e| {
150 AstQueryError::ContextExtraction(format!(
151 "Failed to register file {}: {e}",
152 path.display()
153 ))
154 })?;
155 staging.apply_file_id(file_id);
156
157 let string_remap = staging.commit_strings(graph.strings_mut()).map_err(|e| {
158 AstQueryError::ContextExtraction(format!(
159 "Failed to commit strings for {}: {e}",
160 path.display()
161 ))
162 })?;
163 staging.apply_string_remap(&string_remap).map_err(|e| {
164 AstQueryError::ContextExtraction(format!(
165 "Failed to remap strings for {}: {e}",
166 path.display()
167 ))
168 })?;
169 let _node_id_map = staging.commit_nodes(graph.nodes_mut()).map_err(|e| {
170 AstQueryError::ContextExtraction(format!(
171 "Failed to commit nodes for {}: {e}",
172 path.display()
173 ))
174 })?;
175
176 let content_str = String::from_utf8_lossy(&content);
178 let root_node = tree.root_node();
179
180 let mut contextual_matches = Vec::new();
182 for (_, entry) in graph.nodes().iter() {
183 if ContextKind::from_node_kind(entry.kind).is_none() {
184 continue;
185 }
186 if entry.start_line == 0 {
187 continue;
188 }
189 let start_line = entry.start_line;
190 let start_column = entry.start_column;
191 let mut node = Self::find_defining_node(root_node, start_line, start_column, lang_name);
192
193 if node.is_none()
194 && Self::looks_like_byte_span(
195 entry.start_line,
196 entry.end_line,
197 entry.start_column,
198 entry.end_column,
199 &content_str,
200 )
201 {
202 node = Self::find_defining_node_by_bytes(
203 root_node,
204 entry.start_column as usize,
205 entry.end_column as usize,
206 lang_name,
207 );
208 }
209
210 if let Some(node) = node {
211 let semantic_context = Self::build_context(&node, &content_str, lang_name);
213 let match_name = semantic_context.immediate.name.clone();
214
215 let location = ContextualMatchLocation::new(
216 path.to_path_buf(),
217 entry.start_line,
218 entry.start_column,
219 entry.end_line,
220 entry.end_column,
221 );
222 contextual_matches.push(ContextualMatch::new(
223 match_name,
224 location,
225 semantic_context,
226 lang_name.to_string(),
227 ));
228 }
229 }
230
231 Ok(contextual_matches)
232 }
233
234 fn find_defining_node<'a>(
238 root: Node<'a>,
239 line: u32,
240 column: u32,
241 lang_name: &str,
242 ) -> Option<Node<'a>> {
243 let mut cursor = root.walk();
244 Self::find_defining_node_recursive(root, line, column, lang_name, &mut cursor)
245 }
246
247 fn find_defining_node_by_bytes<'a>(
248 root: Node<'a>,
249 start: usize,
250 end: usize,
251 lang_name: &str,
252 ) -> Option<Node<'a>> {
253 let target = root.descendant_for_byte_range(start, end)?;
254 let mut current = Some(target);
255
256 while let Some(node) = current {
257 if Self::is_named_scope(&node, lang_name) {
258 return Some(node);
259 }
260 current = node.parent();
261 }
262
263 None
264 }
265
266 fn looks_like_byte_span(
267 start_line: u32,
268 end_line: u32,
269 start_column: u32,
270 end_column: u32,
271 source: &str,
272 ) -> bool {
273 if start_line != 1 || end_line != 1 {
274 return false;
275 }
276 let first_line_len = source.lines().next().map_or(0, str::len);
277 let start = start_column as usize;
278 let end = end_column as usize;
279 start > first_line_len || end > first_line_len
280 }
281
282 fn find_defining_node_recursive<'a>(
284 node: Node<'a>,
285 line: u32,
286 column: u32,
287 lang_name: &str,
288 cursor: &mut tree_sitter::TreeCursor<'a>,
289 ) -> Option<Node<'a>> {
290 let start_pos = node.start_position();
293 let end_pos = node.end_position();
294
295 let node_start_line = start_pos
297 .row
298 .try_into()
299 .unwrap_or(u32::MAX)
300 .saturating_add(1);
301 let node_end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
302
303 let line_in_range = line >= node_start_line && line <= node_end_line;
305
306 let start_col: u32 = start_pos.column.try_into().unwrap_or(u32::MAX);
308 let end_col: u32 = end_pos.column.try_into().unwrap_or(u32::MAX);
309
310 let col_in_range = if line == node_start_line && line == node_end_line {
312 column >= start_col && column <= end_col
314 } else if line == node_start_line {
315 column >= start_col
317 } else if line == node_end_line {
318 column <= end_col
320 } else {
321 true
323 };
324
325 if !line_in_range || !col_in_range {
327 return None;
328 }
329
330 let children: Vec<Node<'a>> = node.children(cursor).collect();
332
333 for child in children {
335 let child_end = child.end_position();
336 let child_end_line: u32 = child_end
338 .row
339 .try_into()
340 .unwrap_or(u32::MAX)
341 .saturating_add(1);
342 if child_end_line >= line
343 && let Some(found) =
344 Self::find_defining_node_recursive(child, line, column, lang_name, cursor)
345 {
346 return Some(found);
347 }
348 }
349
350 if Self::is_named_scope(&node, lang_name) {
352 return Some(node);
353 }
354
355 None
356 }
357
358 fn build_context(node: &Node, source: &str, lang_name: &str) -> Context {
362 let source_bytes = source.as_bytes();
363
364 let immediate = Self::node_to_context_item(node, source_bytes, lang_name);
366
367 let mut parent = None;
369 let mut ancestors = Vec::new();
370 let mut current = node.parent();
371
372 while let Some(node) = current {
373 if Self::is_named_scope(&node, lang_name) {
375 let item = Self::node_to_context_item(&node, source_bytes, lang_name);
376
377 if parent.is_none() {
378 parent = Some(item);
379 } else {
380 ancestors.push(item);
381 }
382 }
383
384 current = node.parent();
385 }
386
387 Context::new(immediate, parent, ancestors)
388 }
389
390 fn node_to_context_item(node: &Node, source_bytes: &[u8], lang_name: &str) -> ContextItem {
392 let name = Self::extract_name(node, source_bytes, lang_name)
394 .unwrap_or_else(|| "<anonymous>".to_string());
395
396 let kind = Self::node_to_context_kind(node, lang_name);
398
399 let start_pos = node.start_position();
401 let end_pos = node.end_position();
402
403 let start_line = start_pos
405 .row
406 .try_into()
407 .unwrap_or(u32::MAX)
408 .saturating_add(1);
409 let end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
410
411 ContextItem::new(
412 name,
413 kind,
414 start_line,
415 end_line,
416 node.start_byte(),
417 node.end_byte(),
418 )
419 }
420
421 fn is_named_scope(node: &Node, lang_name: &str) -> bool {
423 let kind = node.kind();
424
425 if matches!(kind, "source_file" | "program" | "module") {
427 return false;
428 }
429
430 match lang_name {
431 "rust" => matches!(
432 kind,
433 "function_item"
434 | "impl_item"
435 | "trait_item"
436 | "struct_item"
437 | "enum_item"
438 | "mod_item"
439 ),
440 "javascript" | "typescript" => matches!(
441 kind,
442 "function_declaration"
443 | "method_definition"
444 | "class_declaration"
445 | "lexical_declaration"
446 ),
447 "python" => matches!(kind, "function_definition" | "class_definition"),
448 "go" => matches!(
449 kind,
450 "function_declaration" | "method_declaration" | "type_declaration"
451 ),
452 _ => false,
453 }
454 }
455
456 fn identifier_kinds(lang_name: &str) -> &'static [&'static str] {
458 match lang_name {
459 "rust" => &["identifier", "type_identifier"],
460 "javascript" | "typescript" => &["identifier", "property_identifier"],
461 "python" | "go" => &["identifier"],
462 _ => &[],
463 }
464 }
465
466 fn extract_name(node: &Node, source_bytes: &[u8], lang_name: &str) -> Option<String> {
468 let kinds = Self::identifier_kinds(lang_name);
469 if kinds.is_empty() {
470 return None;
471 }
472
473 let mut cursor = node.walk();
474 node.children(&mut cursor)
475 .find(|child| kinds.contains(&child.kind()))
476 .and_then(|child| child.utf8_text(source_bytes).ok())
477 .map(std::string::ToString::to_string)
478 }
479
480 fn node_to_context_kind(node: &Node, lang_name: &str) -> ContextKind {
482 if lang_name == "rust" && node.kind() == "function_item" {
483 let mut current = node.parent();
484 while let Some(parent) = current {
485 if matches!(parent.kind(), "impl_item" | "trait_item") {
486 return ContextKind::Method;
487 }
488 current = parent.parent();
489 }
490 }
491
492 Self::node_kind_to_context_kind(node.kind(), lang_name)
493 }
494
495 fn node_kind_to_context_kind(node_kind: &str, lang_name: &str) -> ContextKind {
496 match lang_name {
497 "rust" => match node_kind {
498 "impl_item" => ContextKind::Impl,
499 "trait_item" => ContextKind::Trait,
500 "struct_item" => ContextKind::Struct,
501 "enum_item" => ContextKind::Enum,
502 "mod_item" => ContextKind::Module,
503 "const_item" => ContextKind::Constant,
504 "static_item" => ContextKind::Variable,
505 "type_item" => ContextKind::TypeAlias,
506 _ => ContextKind::Function,
507 },
508 "javascript" | "typescript" => match node_kind {
509 "method_definition" => ContextKind::Method,
510 "class_declaration" => ContextKind::Class,
511 "lexical_declaration" | "variable_declaration" => ContextKind::Variable,
512 _ => ContextKind::Function,
513 },
514 "python" => match node_kind {
515 "class_definition" => ContextKind::Class,
516 _ => ContextKind::Function,
517 },
518 "go" => match node_kind {
519 "method_declaration" => ContextKind::Method,
520 "type_declaration" => ContextKind::Struct,
521 _ => ContextKind::Function,
522 },
523 _ => ContextKind::Function,
524 }
525 }
526
527 pub fn extract_from_directory(&self, root: &Path) -> Result<Vec<ContextualMatch>> {
534 let mut all_matches = Vec::new();
535
536 for entry in walkdir::WalkDir::new(root)
537 .follow_links(false)
538 .into_iter()
539 .filter_map(std::result::Result::ok)
540 {
541 let path = entry.path();
542 if path.is_file() {
543 if let Ok(mut matches) = self.extract_from_file(path) {
545 all_matches.append(&mut matches);
546 }
547 }
548 }
549
550 Ok(all_matches)
551 }
552}
553
554impl Default for ContextExtractor {
555 fn default() -> Self {
556 Self::new()
557 }
558}
559
560#[cfg(all(test, feature = "context-tests"))]
579mod tests {
580 use super::*;
581 use std::fs;
582 use tempfile::TempDir;
583
    /// Builds the plugin manager shared by the tests in this module, using the
    /// crate's test-support plugin factory.
    fn create_test_plugin_manager() -> crate::plugin::PluginManager {
        crate::test_support::plugin_factory::with_builtin_plugins()
    }
596
    /// Checks that a top-level Rust function and a method inside an `impl`
    /// block are extracted with the expected context.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_extract_rust_function_context() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        // The fixture starts with a newline, so the first definition is on line 2.
        fs::write(
            &file_path,
            r#"
fn top_level() {
    println!("hello");
}

struct MyStruct {
    value: i32,
}

impl MyStruct {
    fn method(&self) -> i32 {
        self.value
    }
}
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert!(
            matches.len() >= 2,
            "Expected at least 2 matches, found {}",
            matches.len()
        );

        let top_level = matches.iter().find(|m| m.name == "top_level");
        assert!(top_level.is_some(), "Should find top_level function");
        if let Some(m) = top_level {
            assert_eq!(m.context.depth(), 1, "top_level should be at depth 1");
            assert_eq!(m.context.path(), "top_level");
        }

        // NOTE(review): the method checks only run when a match named "method"
        // exists; its absence does not fail the test.
        let method = matches.iter().find(|m| m.name == "method");
        if let Some(m) = method {
            assert!(m.context.depth() >= 1, "method should have depth >= 1");
            assert!(m.context.parent.is_some(), "method should have a parent");
        }
    }
648
    /// Checks that a method nested in `mod` > `impl` reports a parent scope.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_extract_nested_context() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
mod outer {
    struct Inner {
        value: i32,
    }

    impl Inner {
        fn deeply_nested(&self) {
            // nested function
        }
    }
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        // NOTE(review): nothing asserts that the method is found, so the test
        // passes vacuously when it is missing.
        let method = matches.iter().find(|m| m.name == "deeply_nested");
        if let Some(m) = method {
            assert!(m.context.parent.is_some(), "Should have parent");
            assert!(m.context.depth() >= 1, "Should have depth >= 1");
        }
    }
685
    /// Checks that a JavaScript top-level function and a class are extracted.
    #[test]
    #[ignore = "JavaScript plugin not registered in test helper"]
    fn test_extract_javascript_class() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.js");
        fs::write(
            &file_path,
            r#"
function topLevel() {
    console.log("hello");
}

class MyClass {
    constructor(name) {
        this.name = name;
    }

    greet() {
        console.log("Hello " + this.name);
    }
}
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert!(matches.len() >= 2, "Should find at least 2 matches");

        // The depth check only runs when the function is found.
        let top_fn = matches.iter().find(|m| m.name == "topLevel");
        if let Some(m) = top_fn {
            assert_eq!(m.context.depth(), 1, "topLevel should be at depth 1");
        }

        let class = matches.iter().find(|m| m.name == "MyClass");
        assert!(class.is_some(), "Should find MyClass");
    }
731
    /// Checks that a Python top-level function and a class method are
    /// extracted with the expected depth.
    #[test]
    #[ignore = "Python plugin not registered in test helper"]
    fn test_extract_python_context() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.py");
        fs::write(
            &file_path,
            r#"
def top_level():
    print("hello")

class MyClass:
    def method(self):
        return 42
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert!(matches.len() >= 2, "Should find at least 2 matches");

        // Both lookups below are optional: a missing match skips its assertions.
        let top_fn = matches.iter().find(|m| m.name == "top_level");
        if let Some(m) = top_fn {
            assert_eq!(m.context.depth(), 1);
        }

        let method = matches.iter().find(|m| m.name == "method");
        if let Some(m) = method {
            assert!(m.context.depth() >= 2);
            assert!(m.context.parent.is_some());
        }
    }
769
    /// Checks that an empty Rust file yields no matches and no error.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_empty_file() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("empty.rs");
        fs::write(&file_path, "").unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert_eq!(matches.len(), 0);
    }
783
    /// Checks position matching for a function whose definition starts and
    /// ends on the same line.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_single_line_function() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r#"
fn single_line() { println!("hello"); }
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let func = matches.iter().find(|m| m.name == "single_line");
        assert!(func.is_some(), "Should find single-line function");

        if let Some(m) = func {
            assert_eq!(m.context.depth(), 1);
            assert_eq!(m.context.path(), "single_line");
        }
    }
813
    /// Checks position matching for a function spanning several lines.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_multiline_function() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
fn multiline() {
    let x = 1;
    let y = 2;
    x + y
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let func = matches.iter().find(|m| m.name == "multiline");
        assert!(func.is_some(), "Should find multi-line function");

        if let Some(m) = func {
            assert_eq!(m.context.depth(), 1);
            // The match must span more than two lines.
            assert!(m.end_line > m.start_line + 1);
        }
    }
846
    /// Checks that a method nested in `mod` > `impl` is found at depth 3 with
    /// the `impl` block as its parent.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_nested_structures() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
mod outer {
    struct Inner {
        field: i32,
    }

    impl Inner {
        fn method(&self) -> i32 {
            self.field
        }
    }
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let method = matches.iter().find(|m| m.name == "method");
        assert!(method.is_some(), "Should find nested method");

        if let Some(m) = method {
            assert_eq!(m.context.depth(), 3, "Method should have depth 3");
            assert!(m.context.parent.is_some(), "Method should have parent");
            if let Some(parent) = &m.context.parent {
                assert_eq!(parent.name, "Inner", "Method parent should be Inner impl");
            }
        }
    }
891
    /// Checks that line and doc comments before and inside a function do not
    /// prevent it from being matched.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_with_comments() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r#"
// This is a comment
/// Documentation comment
fn documented_function() {
    // Internal comment
    println!("test");
}
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let func = matches.iter().find(|m| m.name == "documented_function");
        assert!(func.is_some(), "Should find function with comments");

        if let Some(m) = func {
            assert_eq!(m.context.depth(), 1);
        }
    }
923
    /// Checks that a struct and a method in its `impl` block are both found,
    /// and that the method sits at depth 2 under the `impl`.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_edge_positions() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
struct Container {
    value: i32,
}

impl Container {
    fn new(val: i32) -> Self {
        Self { value: val }
    }
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let container = matches.iter().find(|m| m.name == "Container");
        let new_method = matches.iter().find(|m| m.name == "new");

        assert!(container.is_some(), "Should find Container struct");
        assert!(new_method.is_some(), "Should find new method");

        if let Some(m) = new_method {
            assert_eq!(m.context.depth(), 2, "Method should have depth 2");
            // The parent name is only checked when a parent is present.
            if let Some(parent) = &m.context.parent {
                assert_eq!(
                    parent.name, "Container",
                    "Method parent should be Container impl"
                );
            }
        }
    }
970}