1use std::path::Path;
9use tree_sitter::Node;
10
11use super::error::{AstQueryError, Result};
12use super::types::{Context, ContextItem, ContextKind, ContextualMatch, ContextualMatchLocation};
13use crate::graph::unified::build::StagingGraph;
14use crate::graph::unified::concurrent::CodeGraph;
15use crate::plugin::PluginManager;
16
/// Extracts the chain of enclosing named scopes (function, impl, module, ...) for the
/// definitions found in source files.
///
/// Parsing and graph building are delegated to whichever plugin the wrapped
/// [`PluginManager`] resolves for a given path.
pub struct ContextExtractor {
    /// Registry consulted to find the language plugin responsible for a file path.
    plugin_manager: PluginManager,
}
36
37impl ContextExtractor {
38 #[must_use]
45 pub fn new() -> Self {
46 Self::with_plugin_manager(PluginManager::new())
47 }
48
49 #[must_use]
69 pub fn with_plugin_manager(plugin_manager: PluginManager) -> Self {
70 Self { plugin_manager }
71 }
72
73 #[allow(clippy::too_many_lines)]
108 pub fn extract_from_file(&self, path: &Path) -> Result<Vec<ContextualMatch>> {
109 let plugin = self.plugin_manager.plugin_for_path(path).ok_or_else(|| {
111 AstQueryError::ContextExtraction(format!(
112 "No plugin found for path: {}",
113 path.display()
114 ))
115 })?;
116
117 let raw_content = std::fs::read(path)?;
119
120 let lang_name = plugin.metadata().id;
122
123 let (prepared_content, tree) = plugin
125 .prepare_ast(&raw_content)
126 .map_err(|e| AstQueryError::ContextExtraction(format!("Failed to parse AST: {e:?}")))?;
127 let parse_content = prepared_content.as_ref();
128
129 let builder = plugin.graph_builder().ok_or_else(|| {
130 AstQueryError::ContextExtraction(format!("No graph builder registered for {lang_name}"))
131 })?;
132
133 let mut staging = StagingGraph::new();
135 builder
136 .build_graph(&tree, parse_content, path, &mut staging)
137 .map_err(|e| {
138 AstQueryError::ContextExtraction(format!(
139 "Failed to build graph for {}: {e}",
140 path.display()
141 ))
142 })?;
143
144 staging.attach_body_hashes(&raw_content);
145
146 let mut graph = CodeGraph::new();
147 let file_id = graph
148 .files_mut()
149 .register_with_language(path, Some(builder.language()))
150 .map_err(|e| {
151 AstQueryError::ContextExtraction(format!(
152 "Failed to register file {}: {e}",
153 path.display()
154 ))
155 })?;
156 staging.apply_file_id(file_id);
157
158 let string_remap = staging.commit_strings(graph.strings_mut()).map_err(|e| {
159 AstQueryError::ContextExtraction(format!(
160 "Failed to commit strings for {}: {e}",
161 path.display()
162 ))
163 })?;
164 staging.apply_string_remap(&string_remap).map_err(|e| {
165 AstQueryError::ContextExtraction(format!(
166 "Failed to remap strings for {}: {e}",
167 path.display()
168 ))
169 })?;
170 let _node_id_map = staging.commit_nodes(graph.nodes_mut()).map_err(|e| {
171 AstQueryError::ContextExtraction(format!(
172 "Failed to commit nodes for {}: {e}",
173 path.display()
174 ))
175 })?;
176
177 let content_str = String::from_utf8_lossy(&raw_content);
179 let root_node = tree.root_node();
180
181 let mut contextual_matches = Vec::new();
183 for (_, entry) in graph.nodes().iter() {
184 if ContextKind::from_node_kind(entry.kind).is_none() {
185 continue;
186 }
187 if entry.start_line == 0 {
188 continue;
189 }
190 let start_line = entry.start_line;
191 let start_column = entry.start_column;
192 let mut node = Self::find_defining_node(root_node, start_line, start_column, lang_name);
193
194 if node.is_none()
195 && Self::looks_like_byte_span(
196 entry.start_line,
197 entry.end_line,
198 entry.start_column,
199 entry.end_column,
200 &content_str,
201 )
202 {
203 node = Self::find_defining_node_by_bytes(
204 root_node,
205 entry.start_column as usize,
206 entry.end_column as usize,
207 lang_name,
208 );
209 }
210
211 if let Some(node) = node {
212 let semantic_context = Self::build_context(&node, &content_str, lang_name);
214 let match_name = semantic_context.immediate.name.clone();
215
216 let location = ContextualMatchLocation::new(
217 path.to_path_buf(),
218 entry.start_line,
219 entry.start_column,
220 entry.end_line,
221 entry.end_column,
222 );
223 contextual_matches.push(ContextualMatch::new(
224 match_name,
225 location,
226 semantic_context,
227 lang_name.to_string(),
228 ));
229 }
230 }
231
232 Ok(contextual_matches)
233 }
234
235 fn find_defining_node<'a>(
239 root: Node<'a>,
240 line: u32,
241 column: u32,
242 lang_name: &str,
243 ) -> Option<Node<'a>> {
244 let mut cursor = root.walk();
245 Self::find_defining_node_recursive(root, line, column, lang_name, &mut cursor)
246 }
247
248 fn find_defining_node_by_bytes<'a>(
249 root: Node<'a>,
250 start: usize,
251 end: usize,
252 lang_name: &str,
253 ) -> Option<Node<'a>> {
254 let target = root.descendant_for_byte_range(start, end)?;
255 let mut current = Some(target);
256
257 while let Some(node) = current {
258 if Self::is_named_scope(&node, lang_name) {
259 return Some(node);
260 }
261 current = node.parent();
262 }
263
264 None
265 }
266
/// Heuristically detects a span whose "columns" are really byte offsets.
///
/// Returns `true` only when the span claims to start and end on line 1 while its start or end
/// column lies beyond the byte length of the first line of `source`. Such a pair cannot be a
/// genuine first-line position; the caller then reinterprets the columns as byte offsets.
fn looks_like_byte_span(
    start_line: u32,
    end_line: u32,
    start_column: u32,
    end_column: u32,
    source: &str,
) -> bool {
    if (start_line, end_line) != (1, 1) {
        return false;
    }

    // An empty source has no first line; treat its width as zero.
    let first_line_width = source.lines().next().unwrap_or_default().len();
    let furthest_column = start_column.max(end_column) as usize;
    furthest_column > first_line_width
}
282
283 fn find_defining_node_recursive<'a>(
285 node: Node<'a>,
286 line: u32,
287 column: u32,
288 lang_name: &str,
289 cursor: &mut tree_sitter::TreeCursor<'a>,
290 ) -> Option<Node<'a>> {
291 let start_pos = node.start_position();
294 let end_pos = node.end_position();
295
296 let node_start_line = start_pos
298 .row
299 .try_into()
300 .unwrap_or(u32::MAX)
301 .saturating_add(1);
302 let node_end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
303
304 let line_in_range = line >= node_start_line && line <= node_end_line;
306
307 let start_col: u32 = start_pos.column.try_into().unwrap_or(u32::MAX);
309 let end_col: u32 = end_pos.column.try_into().unwrap_or(u32::MAX);
310
311 let col_in_range = if line == node_start_line && line == node_end_line {
313 column >= start_col && column <= end_col
315 } else if line == node_start_line {
316 column >= start_col
318 } else if line == node_end_line {
319 column <= end_col
321 } else {
322 true
324 };
325
326 if !line_in_range || !col_in_range {
328 return None;
329 }
330
331 let children: Vec<Node<'a>> = node.children(cursor).collect();
333
334 for child in children {
336 let child_end = child.end_position();
337 let child_end_line: u32 = child_end
339 .row
340 .try_into()
341 .unwrap_or(u32::MAX)
342 .saturating_add(1);
343 if child_end_line >= line
344 && let Some(found) =
345 Self::find_defining_node_recursive(child, line, column, lang_name, cursor)
346 {
347 return Some(found);
348 }
349 }
350
351 if Self::is_named_scope(&node, lang_name) {
353 return Some(node);
354 }
355
356 None
357 }
358
359 fn build_context(node: &Node, source: &str, lang_name: &str) -> Context {
363 let source_bytes = source.as_bytes();
364
365 let immediate = Self::node_to_context_item(node, source_bytes, lang_name);
367
368 let mut parent = None;
370 let mut ancestors = Vec::new();
371 let mut current = node.parent();
372
373 while let Some(node) = current {
374 if Self::is_named_scope(&node, lang_name) {
376 let item = Self::node_to_context_item(&node, source_bytes, lang_name);
377
378 if parent.is_none() {
379 parent = Some(item);
380 } else {
381 ancestors.push(item);
382 }
383 }
384
385 current = node.parent();
386 }
387
388 Context::new(immediate, parent, ancestors)
389 }
390
391 fn node_to_context_item(node: &Node, source_bytes: &[u8], lang_name: &str) -> ContextItem {
393 let name = Self::extract_name(node, source_bytes, lang_name)
395 .unwrap_or_else(|| "<anonymous>".to_string());
396
397 let kind = Self::node_to_context_kind(node, lang_name);
399
400 let start_pos = node.start_position();
402 let end_pos = node.end_position();
403
404 let start_line = start_pos
406 .row
407 .try_into()
408 .unwrap_or(u32::MAX)
409 .saturating_add(1);
410 let end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
411
412 ContextItem::new(
413 name,
414 kind,
415 start_line,
416 end_line,
417 node.start_byte(),
418 node.end_byte(),
419 )
420 }
421
422 fn is_named_scope(node: &Node, lang_name: &str) -> bool {
424 let kind = node.kind();
425
426 if matches!(kind, "source_file" | "program" | "module") {
428 return false;
429 }
430
431 match lang_name {
432 "rust" => matches!(
433 kind,
434 "function_item"
435 | "impl_item"
436 | "trait_item"
437 | "struct_item"
438 | "enum_item"
439 | "mod_item"
440 ),
441 "javascript" | "typescript" => matches!(
442 kind,
443 "function_declaration"
444 | "method_definition"
445 | "class_declaration"
446 | "lexical_declaration"
447 ),
448 "python" => matches!(kind, "function_definition" | "class_definition"),
449 "go" => matches!(
450 kind,
451 "function_declaration" | "method_declaration" | "type_declaration"
452 ),
453 _ => false,
454 }
455 }
456
/// Returns the node kinds that can carry the name of a definition in `lang_name`.
///
/// `extract_name` picks the first direct child whose kind is in this list. Unknown languages
/// yield an empty slice.
///
/// Beyond plain `identifier`, some grammars use dedicated kinds for names:
/// - Rust type names are `type_identifier`.
/// - JavaScript/TypeScript method names are `property_identifier`; TypeScript class names are
///   `type_identifier`.
/// - Go method names are `field_identifier`.
///
/// Without the TypeScript and Go entries those scopes would always be reported as
/// `<anonymous>`.
fn identifier_kinds(lang_name: &str) -> &'static [&'static str] {
    match lang_name {
        "rust" => &["identifier", "type_identifier"],
        "javascript" | "typescript" => {
            &["identifier", "property_identifier", "type_identifier"]
        }
        "python" => &["identifier"],
        "go" => &["identifier", "field_identifier"],
        _ => &[],
    }
}
466
467 fn extract_name(node: &Node, source_bytes: &[u8], lang_name: &str) -> Option<String> {
469 let kinds = Self::identifier_kinds(lang_name);
470 if kinds.is_empty() {
471 return None;
472 }
473
474 let mut cursor = node.walk();
475 node.children(&mut cursor)
476 .find(|child| kinds.contains(&child.kind()))
477 .and_then(|child| child.utf8_text(source_bytes).ok())
478 .map(std::string::ToString::to_string)
479 }
480
481 fn node_to_context_kind(node: &Node, lang_name: &str) -> ContextKind {
483 if lang_name == "rust" && node.kind() == "function_item" {
484 let mut current = node.parent();
485 while let Some(parent) = current {
486 if matches!(parent.kind(), "impl_item" | "trait_item") {
487 return ContextKind::Method;
488 }
489 current = parent.parent();
490 }
491 }
492
493 Self::node_kind_to_context_kind(node.kind(), lang_name)
494 }
495
496 fn node_kind_to_context_kind(node_kind: &str, lang_name: &str) -> ContextKind {
497 match lang_name {
498 "rust" => match node_kind {
499 "impl_item" => ContextKind::Impl,
500 "trait_item" => ContextKind::Trait,
501 "struct_item" => ContextKind::Struct,
502 "enum_item" => ContextKind::Enum,
503 "mod_item" => ContextKind::Module,
504 "const_item" => ContextKind::Constant,
505 "static_item" => ContextKind::Variable,
506 "type_item" => ContextKind::TypeAlias,
507 _ => ContextKind::Function,
508 },
509 "javascript" | "typescript" => match node_kind {
510 "method_definition" => ContextKind::Method,
511 "class_declaration" => ContextKind::Class,
512 "lexical_declaration" | "variable_declaration" => ContextKind::Variable,
513 _ => ContextKind::Function,
514 },
515 "python" => match node_kind {
516 "class_definition" => ContextKind::Class,
517 _ => ContextKind::Function,
518 },
519 "go" => match node_kind {
520 "method_declaration" => ContextKind::Method,
521 "type_declaration" => ContextKind::Struct,
522 _ => ContextKind::Function,
523 },
524 _ => ContextKind::Function,
525 }
526 }
527
528 pub fn extract_from_directory(&self, root: &Path) -> Result<Vec<ContextualMatch>> {
535 let mut all_matches = Vec::new();
536
537 for entry in walkdir::WalkDir::new(root)
538 .follow_links(false)
539 .into_iter()
540 .filter_map(std::result::Result::ok)
541 {
542 let path = entry.path();
543 if path.is_file() {
544 if let Ok(mut matches) = self.extract_from_file(path) {
546 all_matches.append(&mut matches);
547 }
548 }
549 }
550
551 Ok(all_matches)
552 }
553}
554
555impl Default for ContextExtractor {
556 fn default() -> Self {
557 Self::new()
558 }
559}
560
561#[cfg(all(test, feature = "context-tests"))]
580mod tests {
581 use super::*;
582 use std::fs;
583 use tempfile::TempDir;
584
/// Builds the plugin manager used by the tests in this module by delegating to the crate's
/// `test_support::plugin_factory::with_builtin_plugins` helper.
fn create_test_plugin_manager() -> crate::plugin::PluginManager {
    crate::test_support::plugin_factory::with_builtin_plugins()
}
597
/// Extracts contexts from a Rust file containing a free function, a struct and an impl with
/// one method.
///
/// Requires at least two matches and a `top_level` match with depth 1 and path `"top_level"`.
/// The `method` assertions only run when such a match is found.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_extract_rust_function_context() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.rs");
    fs::write(
        &file_path,
        r#"
fn top_level() {
 println!("hello");
}

struct MyStruct {
 value: i32,
}

impl MyStruct {
 fn method(&self) -> i32 {
 self.value
 }
}
"#,
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    assert!(
        matches.len() >= 2,
        "Expected at least 2 matches, found {}",
        matches.len()
    );

    // The free function must be found and must have no enclosing scope.
    let top_level = matches.iter().find(|m| m.name == "top_level");
    assert!(top_level.is_some(), "Should find top_level function");
    if let Some(m) = top_level {
        assert_eq!(m.context.depth(), 1, "top_level should be at depth 1");
        assert_eq!(m.context.path(), "top_level");
    }

    // Lenient: a missing `method` match does not fail the test.
    let method = matches.iter().find(|m| m.name == "method");
    if let Some(m) = method {
        assert!(m.context.depth() >= 1, "method should have depth >= 1");
        assert!(m.context.parent.is_some(), "method should have a parent");
    }
}
649
/// Extracts contexts from a method nested inside an impl inside a module.
///
/// The assertions (a parent exists, depth is at least 1) only run when a `deeply_nested`
/// match is found; a missing match does not fail the test.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_extract_nested_context() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.rs");
    fs::write(
        &file_path,
        r"
mod outer {
 struct Inner {
 value: i32,
 }

 impl Inner {
 fn deeply_nested(&self) {
 // nested function
 }
 }
}
",
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    let method = matches.iter().find(|m| m.name == "deeply_nested");
    if let Some(m) = method {
        assert!(m.context.parent.is_some(), "Should have parent");
        assert!(m.context.depth() >= 1, "Should have depth >= 1");
    }
}
686
/// Extracts contexts from a JavaScript file with a top-level function and a class with two
/// methods.
///
/// Requires at least two matches and a `MyClass` match. The depth check on `topLevel` only
/// runs when that match is found.
#[test]
#[ignore = "JavaScript plugin not registered in test helper"]
fn test_extract_javascript_class() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.js");
    fs::write(
        &file_path,
        r#"
function topLevel() {
 console.log("hello");
}

class MyClass {
 constructor(name) {
 this.name = name;
 }

 greet() {
 console.log("Hello " + this.name);
 }
}
"#,
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    assert!(matches.len() >= 2, "Should find at least 2 matches");

    let top_fn = matches.iter().find(|m| m.name == "topLevel");
    if let Some(m) = top_fn {
        assert_eq!(m.context.depth(), 1, "topLevel should be at depth 1");
    }

    let class = matches.iter().find(|m| m.name == "MyClass");
    assert!(class.is_some(), "Should find MyClass");
}
732
/// Extracts contexts from a Python file with a top-level function and a class with one method.
///
/// Requires at least two matches. When found, `top_level` must have depth 1 and `method` must
/// have depth of at least 2 and a parent; missing matches do not fail the test.
///
/// Indentation is significant in Python, so the fixture nests each body one level (four
/// spaces) below its `def`/`class` header; otherwise the source would not parse as intended.
#[test]
#[ignore = "Python plugin not registered in test helper"]
fn test_extract_python_context() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.py");
    fs::write(
        &file_path,
        r#"
def top_level():
    print("hello")

class MyClass:
    def method(self):
        return 42
"#,
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    assert!(matches.len() >= 2, "Should find at least 2 matches");

    let top_fn = matches.iter().find(|m| m.name == "top_level");
    if let Some(m) = top_fn {
        assert_eq!(m.context.depth(), 1);
    }

    let method = matches.iter().find(|m| m.name == "method");
    if let Some(m) = method {
        assert!(m.context.depth() >= 2);
        assert!(m.context.parent.is_some());
    }
}
770
/// An empty Rust file must extract successfully and yield no matches.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_empty_file() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("empty.rs");
    fs::write(&file_path, "").unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    assert_eq!(matches.len(), 0);
}
784
/// A function whose whole definition sits on one line must be found, with depth 1 and a
/// context path equal to its own name.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_single_line_function() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.rs");
    fs::write(
        &file_path,
        r#"
fn single_line() { println!("hello"); }
"#,
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    let func = matches.iter().find(|m| m.name == "single_line");
    assert!(func.is_some(), "Should find single-line function");

    if let Some(m) = func {
        assert_eq!(m.context.depth(), 1);
        assert_eq!(m.context.path(), "single_line");
    }
}
814
/// A function spanning several lines must be found with depth 1, and its match must end more
/// than one line after it starts.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_multiline_function() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.rs");
    fs::write(
        &file_path,
        r"
fn multiline() {
 let x = 1;
 let y = 2;
 x + y
}
",
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    let func = matches.iter().find(|m| m.name == "multiline");
    assert!(func.is_some(), "Should find multi-line function");

    if let Some(m) = func {
        assert_eq!(m.context.depth(), 1);
        // The recorded span must cover more than two lines.
        assert!(m.end_line > m.start_line + 1);
    }
}
847
/// A method inside an impl inside a module must be found with depth 3 and with a parent named
/// `Inner`.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_nested_structures() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.rs");
    fs::write(
        &file_path,
        r"
mod outer {
 struct Inner {
 field: i32,
 }

 impl Inner {
 fn method(&self) -> i32 {
 self.field
 }
 }
}
",
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    let method = matches.iter().find(|m| m.name == "method");
    assert!(method.is_some(), "Should find nested method");

    if let Some(m) = method {
        // NOTE(review): presumably method + impl + module = 3 levels; confirm against `Context::depth`.
        assert_eq!(m.context.depth(), 3, "Method should have depth 3");
        assert!(m.context.parent.is_some(), "Method should have parent");
        if let Some(parent) = &m.context.parent {
            assert_eq!(parent.name, "Inner", "Method parent should be Inner impl");
        }
    }
}
892
/// A function preceded by line and doc comments (and containing one) must still be found,
/// with depth 1.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_with_comments() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.rs");
    fs::write(
        &file_path,
        r#"
// This is a comment
/// Documentation comment
fn documented_function() {
 // Internal comment
 println!("test");
}
"#,
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    let func = matches.iter().find(|m| m.name == "documented_function");
    assert!(func.is_some(), "Should find function with comments");

    if let Some(m) = func {
        assert_eq!(m.context.depth(), 1);
    }
}
924
/// A struct and a method of its impl must both be found; the method must have depth 2 and,
/// when a parent is present, that parent must be named `Container`.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_edge_positions() {
    let dir = TempDir::new().unwrap();
    let file_path = dir.path().join("test.rs");
    fs::write(
        &file_path,
        r"
struct Container {
 value: i32,
}

impl Container {
 fn new(val: i32) -> Self {
 Self { value: val }
 }
}
",
    )
    .unwrap();

    let manager = create_test_plugin_manager();
    let extractor = ContextExtractor::with_plugin_manager(manager);
    let matches = extractor.extract_from_file(&file_path).unwrap();

    let container = matches.iter().find(|m| m.name == "Container");
    let new_method = matches.iter().find(|m| m.name == "new");

    assert!(container.is_some(), "Should find Container struct");
    assert!(new_method.is_some(), "Should find new method");

    if let Some(m) = new_method {
        assert_eq!(m.context.depth(), 2, "Method should have depth 2");
        // The parent name is only checked when a parent exists.
        if let Some(parent) = &m.context.parent {
            assert_eq!(
                parent.name, "Container",
                "Method parent should be Container impl"
            );
        }
    }
}
971}