1use std::path::Path;
9use tree_sitter::Node;
10
11use super::error::{AstQueryError, Result};
12use super::types::{Context, ContextItem, ContextKind, ContextualMatch, ContextualMatchLocation};
13use crate::graph::unified::build::StagingGraph;
14use crate::graph::unified::concurrent::CodeGraph;
15use crate::plugin::PluginManager;
16
/// Extracts contextual matches (a definition plus its enclosing named scopes)
/// from source files.
///
/// Language support is delegated to the plugins registered in the held
/// [`PluginManager`].
pub struct ContextExtractor {
    /// Looked up per file path to find the language plugin to use.
    plugin_manager: PluginManager,
}
36
37impl ContextExtractor {
38 #[must_use]
45 pub fn new() -> Self {
46 Self::with_plugin_manager(PluginManager::new())
47 }
48
49 #[must_use]
69 pub fn with_plugin_manager(plugin_manager: PluginManager) -> Self {
70 Self { plugin_manager }
71 }
72
73 #[allow(clippy::too_many_lines)]
108 pub fn extract_from_file(&self, path: &Path) -> Result<Vec<ContextualMatch>> {
109 let plugin = self.plugin_manager.plugin_for_path(path).ok_or_else(|| {
111 AstQueryError::ContextExtraction(format!(
112 "No plugin found for path: {}",
113 path.display()
114 ))
115 })?;
116
117 let raw_content = std::fs::read(path)?;
119
120 let lang_name = plugin.metadata().id;
122
123 let (prepared_content, tree) = plugin
125 .prepare_ast(&raw_content)
126 .map_err(|e| AstQueryError::ContextExtraction(format!("Failed to parse AST: {e:?}")))?;
127 let parse_content = prepared_content.as_ref();
128
129 let builder = plugin.graph_builder().ok_or_else(|| {
130 AstQueryError::ContextExtraction(format!("No graph builder registered for {lang_name}"))
131 })?;
132
133 let mut staging = StagingGraph::new();
135 builder
136 .build_graph(&tree, parse_content, path, &mut staging)
137 .map_err(|e| {
138 AstQueryError::ContextExtraction(format!(
139 "Failed to build graph for {}: {e}",
140 path.display()
141 ))
142 })?;
143
144 staging.attach_body_hashes(&raw_content);
145
146 let mut graph = CodeGraph::new();
147 let file_id = graph
148 .files_mut()
149 .register_with_language(path, Some(builder.language()))
150 .map_err(|e| {
151 AstQueryError::ContextExtraction(format!(
152 "Failed to register file {}: {e}",
153 path.display()
154 ))
155 })?;
156 staging.apply_file_id(file_id);
157
158 let string_remap = staging.commit_strings(graph.strings_mut()).map_err(|e| {
159 AstQueryError::ContextExtraction(format!(
160 "Failed to commit strings for {}: {e}",
161 path.display()
162 ))
163 })?;
164 staging.apply_string_remap(&string_remap).map_err(|e| {
165 AstQueryError::ContextExtraction(format!(
166 "Failed to remap strings for {}: {e}",
167 path.display()
168 ))
169 })?;
170 let _node_id_map = staging.commit_nodes(graph.nodes_mut()).map_err(|e| {
171 AstQueryError::ContextExtraction(format!(
172 "Failed to commit nodes for {}: {e}",
173 path.display()
174 ))
175 })?;
176
177 let content_str = String::from_utf8_lossy(&raw_content);
179 let root_node = tree.root_node();
180
181 let mut contextual_matches = Vec::new();
187 for (_, entry) in graph.nodes().iter() {
188 if entry.is_unified_loser() {
189 continue;
190 }
191 if ContextKind::from_node_kind(entry.kind).is_none() {
192 continue;
193 }
194 if entry.start_line == 0 {
195 continue;
196 }
197 let start_line = entry.start_line;
198 let start_column = entry.start_column;
199 let mut node = Self::find_defining_node(root_node, start_line, start_column, lang_name);
200
201 if node.is_none()
202 && Self::looks_like_byte_span(
203 entry.start_line,
204 entry.end_line,
205 entry.start_column,
206 entry.end_column,
207 &content_str,
208 )
209 {
210 node = Self::find_defining_node_by_bytes(
211 root_node,
212 entry.start_column as usize,
213 entry.end_column as usize,
214 lang_name,
215 );
216 }
217
218 if let Some(node) = node {
219 let semantic_context = Self::build_context(&node, &content_str, lang_name);
221 let match_name = semantic_context.immediate.name.clone();
222
223 let location = ContextualMatchLocation::new(
224 path.to_path_buf(),
225 entry.start_line,
226 entry.start_column,
227 entry.end_line,
228 entry.end_column,
229 );
230 contextual_matches.push(ContextualMatch::new(
231 match_name,
232 location,
233 semantic_context,
234 lang_name.to_string(),
235 ));
236 }
237 }
238
239 Ok(contextual_matches)
240 }
241
242 fn find_defining_node<'a>(
246 root: Node<'a>,
247 line: u32,
248 column: u32,
249 lang_name: &str,
250 ) -> Option<Node<'a>> {
251 let mut cursor = root.walk();
252 Self::find_defining_node_recursive(root, line, column, lang_name, &mut cursor)
253 }
254
255 fn find_defining_node_by_bytes<'a>(
256 root: Node<'a>,
257 start: usize,
258 end: usize,
259 lang_name: &str,
260 ) -> Option<Node<'a>> {
261 let target = root.descendant_for_byte_range(start, end)?;
262 let mut current = Some(target);
263
264 while let Some(node) = current {
265 if Self::is_named_scope(&node, lang_name) {
266 return Some(node);
267 }
268 current = node.parent();
269 }
270
271 None
272 }
273
274 fn looks_like_byte_span(
275 start_line: u32,
276 end_line: u32,
277 start_column: u32,
278 end_column: u32,
279 source: &str,
280 ) -> bool {
281 if start_line != 1 || end_line != 1 {
282 return false;
283 }
284 let first_line_len = source.lines().next().map_or(0, str::len);
285 let start = start_column as usize;
286 let end = end_column as usize;
287 start > first_line_len || end > first_line_len
288 }
289
290 fn find_defining_node_recursive<'a>(
292 node: Node<'a>,
293 line: u32,
294 column: u32,
295 lang_name: &str,
296 cursor: &mut tree_sitter::TreeCursor<'a>,
297 ) -> Option<Node<'a>> {
298 let start_pos = node.start_position();
301 let end_pos = node.end_position();
302
303 let node_start_line = start_pos
305 .row
306 .try_into()
307 .unwrap_or(u32::MAX)
308 .saturating_add(1);
309 let node_end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
310
311 let line_in_range = line >= node_start_line && line <= node_end_line;
313
314 let start_col: u32 = start_pos.column.try_into().unwrap_or(u32::MAX);
316 let end_col: u32 = end_pos.column.try_into().unwrap_or(u32::MAX);
317
318 let col_in_range = if line == node_start_line && line == node_end_line {
320 column >= start_col && column <= end_col
322 } else if line == node_start_line {
323 column >= start_col
325 } else if line == node_end_line {
326 column <= end_col
328 } else {
329 true
331 };
332
333 if !line_in_range || !col_in_range {
335 return None;
336 }
337
338 let children: Vec<Node<'a>> = node.children(cursor).collect();
340
341 for child in children {
343 let child_end = child.end_position();
344 let child_end_line: u32 = child_end
346 .row
347 .try_into()
348 .unwrap_or(u32::MAX)
349 .saturating_add(1);
350 if child_end_line >= line
351 && let Some(found) =
352 Self::find_defining_node_recursive(child, line, column, lang_name, cursor)
353 {
354 return Some(found);
355 }
356 }
357
358 if Self::is_named_scope(&node, lang_name) {
360 return Some(node);
361 }
362
363 None
364 }
365
366 fn build_context(node: &Node, source: &str, lang_name: &str) -> Context {
370 let source_bytes = source.as_bytes();
371
372 let immediate = Self::node_to_context_item(node, source_bytes, lang_name);
374
375 let mut parent = None;
377 let mut ancestors = Vec::new();
378 let mut current = node.parent();
379
380 while let Some(node) = current {
381 if Self::is_named_scope(&node, lang_name) {
383 let item = Self::node_to_context_item(&node, source_bytes, lang_name);
384
385 if parent.is_none() {
386 parent = Some(item);
387 } else {
388 ancestors.push(item);
389 }
390 }
391
392 current = node.parent();
393 }
394
395 Context::new(immediate, parent, ancestors)
396 }
397
398 fn node_to_context_item(node: &Node, source_bytes: &[u8], lang_name: &str) -> ContextItem {
400 let name = Self::extract_name(node, source_bytes, lang_name)
402 .unwrap_or_else(|| "<anonymous>".to_string());
403
404 let kind = Self::node_to_context_kind(node, lang_name);
406
407 let start_pos = node.start_position();
409 let end_pos = node.end_position();
410
411 let start_line = start_pos
413 .row
414 .try_into()
415 .unwrap_or(u32::MAX)
416 .saturating_add(1);
417 let end_line = end_pos.row.try_into().unwrap_or(u32::MAX).saturating_add(1);
418
419 ContextItem::new(
420 name,
421 kind,
422 start_line,
423 end_line,
424 node.start_byte(),
425 node.end_byte(),
426 )
427 }
428
429 fn is_named_scope(node: &Node, lang_name: &str) -> bool {
431 let kind = node.kind();
432
433 if matches!(kind, "source_file" | "program" | "module") {
435 return false;
436 }
437
438 match lang_name {
439 "rust" => matches!(
440 kind,
441 "function_item"
442 | "impl_item"
443 | "trait_item"
444 | "struct_item"
445 | "enum_item"
446 | "mod_item"
447 ),
448 "javascript" | "typescript" => matches!(
449 kind,
450 "function_declaration"
451 | "method_definition"
452 | "class_declaration"
453 | "lexical_declaration"
454 ),
455 "python" => matches!(kind, "function_definition" | "class_definition"),
456 "go" => matches!(
457 kind,
458 "function_declaration" | "method_declaration" | "type_declaration"
459 ),
460 _ => false,
461 }
462 }
463
464 fn identifier_kinds(lang_name: &str) -> &'static [&'static str] {
466 match lang_name {
467 "rust" => &["identifier", "type_identifier"],
468 "javascript" | "typescript" => &["identifier", "property_identifier"],
469 "python" | "go" => &["identifier"],
470 _ => &[],
471 }
472 }
473
474 fn extract_name(node: &Node, source_bytes: &[u8], lang_name: &str) -> Option<String> {
476 let kinds = Self::identifier_kinds(lang_name);
477 if kinds.is_empty() {
478 return None;
479 }
480
481 let mut cursor = node.walk();
482 node.children(&mut cursor)
483 .find(|child| kinds.contains(&child.kind()))
484 .and_then(|child| child.utf8_text(source_bytes).ok())
485 .map(std::string::ToString::to_string)
486 }
487
488 fn node_to_context_kind(node: &Node, lang_name: &str) -> ContextKind {
490 if lang_name == "rust" && node.kind() == "function_item" {
491 let mut current = node.parent();
492 while let Some(parent) = current {
493 if matches!(parent.kind(), "impl_item" | "trait_item") {
494 return ContextKind::Method;
495 }
496 current = parent.parent();
497 }
498 }
499
500 Self::node_kind_to_context_kind(node.kind(), lang_name)
501 }
502
503 fn node_kind_to_context_kind(node_kind: &str, lang_name: &str) -> ContextKind {
504 match lang_name {
505 "rust" => match node_kind {
506 "impl_item" => ContextKind::Impl,
507 "trait_item" => ContextKind::Trait,
508 "struct_item" => ContextKind::Struct,
509 "enum_item" => ContextKind::Enum,
510 "mod_item" => ContextKind::Module,
511 "const_item" => ContextKind::Constant,
512 "static_item" => ContextKind::Variable,
513 "type_item" => ContextKind::TypeAlias,
514 _ => ContextKind::Function,
515 },
516 "javascript" | "typescript" => match node_kind {
517 "method_definition" => ContextKind::Method,
518 "class_declaration" => ContextKind::Class,
519 "lexical_declaration" | "variable_declaration" => ContextKind::Variable,
520 _ => ContextKind::Function,
521 },
522 "python" => match node_kind {
523 "class_definition" => ContextKind::Class,
524 _ => ContextKind::Function,
525 },
526 "go" => match node_kind {
527 "method_declaration" => ContextKind::Method,
528 "type_declaration" => ContextKind::Struct,
529 _ => ContextKind::Function,
530 },
531 _ => ContextKind::Function,
532 }
533 }
534
535 pub fn extract_from_directory(&self, root: &Path) -> Result<Vec<ContextualMatch>> {
542 let mut all_matches = Vec::new();
543
544 for entry in walkdir::WalkDir::new(root)
545 .follow_links(false)
546 .into_iter()
547 .filter_map(std::result::Result::ok)
548 {
549 let path = entry.path();
550 if path.is_file() {
551 if let Ok(mut matches) = self.extract_from_file(path) {
553 all_matches.append(&mut matches);
554 }
555 }
556 }
557
558 Ok(all_matches)
559 }
560}
561
562impl Default for ContextExtractor {
563 fn default() -> Self {
564 Self::new()
565 }
566}
567
// Only compiled for test builds with the `context-tests` feature enabled.
#[cfg(all(test, feature = "context-tests"))]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Builds the plugin manager shared by all tests in this module via the
    /// crate's test-support factory.
    fn create_test_plugin_manager() -> crate::plugin::PluginManager {
        crate::test_support::plugin_factory::with_builtin_plugins()
    }

    /// A top-level Rust function has depth 1; a method inside an `impl` has a parent.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_extract_rust_function_context() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r#"
fn top_level() {
    println!("hello");
}

struct MyStruct {
    value: i32,
}

impl MyStruct {
    fn method(&self) -> i32 {
        self.value
    }
}
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert!(
            matches.len() >= 2,
            "Expected at least 2 matches, found {}",
            matches.len()
        );

        // The top-level function must be present and have no enclosing scope.
        let top_level = matches.iter().find(|m| m.name == "top_level");
        assert!(top_level.is_some(), "Should find top_level function");
        if let Some(m) = top_level {
            assert_eq!(m.context.depth(), 1, "top_level should be at depth 1");
            assert_eq!(m.context.path(), "top_level");
        }

        // The method is only checked if it was reported at all.
        let method = matches.iter().find(|m| m.name == "method");
        if let Some(m) = method {
            assert!(m.context.depth() >= 1, "method should have depth >= 1");
            assert!(m.context.parent.is_some(), "method should have a parent");
        }
    }

    /// A method nested in `mod` -> `impl` reports a parent scope.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_extract_nested_context() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
mod outer {
    struct Inner {
        value: i32,
    }

    impl Inner {
        fn deeply_nested(&self) {
            // nested function
        }
    }
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        // Assertions apply only when the method was reported.
        let method = matches.iter().find(|m| m.name == "deeply_nested");
        if let Some(m) = method {
            assert!(m.context.parent.is_some(), "Should have parent");
            assert!(m.context.depth() >= 1, "Should have depth >= 1");
        }
    }

    /// JavaScript: a top-level function and a class are both extracted.
    #[test]
    #[ignore = "JavaScript plugin not registered in test helper"]
    fn test_extract_javascript_class() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.js");
        fs::write(
            &file_path,
            r#"
function topLevel() {
    console.log("hello");
}

class MyClass {
    constructor(name) {
        this.name = name;
    }

    greet() {
        console.log("Hello " + this.name);
    }
}
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert!(matches.len() >= 2, "Should find at least 2 matches");

        // Depth is only checked if the function was reported.
        let top_fn = matches.iter().find(|m| m.name == "topLevel");
        if let Some(m) = top_fn {
            assert_eq!(m.context.depth(), 1, "topLevel should be at depth 1");
        }

        let class = matches.iter().find(|m| m.name == "MyClass");
        assert!(class.is_some(), "Should find MyClass");
    }

    /// Python: a top-level function has depth 1; a class method has depth >= 2.
    #[test]
    #[ignore = "Python plugin not registered in test helper"]
    fn test_extract_python_context() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.py");
        fs::write(
            &file_path,
            r#"
def top_level():
    print("hello")

class MyClass:
    def method(self):
        return 42
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert!(matches.len() >= 2, "Should find at least 2 matches");

        // Both lookups are optional; assertions run only when found.
        let top_fn = matches.iter().find(|m| m.name == "top_level");
        if let Some(m) = top_fn {
            assert_eq!(m.context.depth(), 1);
        }

        let method = matches.iter().find(|m| m.name == "method");
        if let Some(m) = method {
            assert!(m.context.depth() >= 2);
            assert!(m.context.parent.is_some());
        }
    }

    /// An empty source file yields no matches and no error.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_empty_file() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("empty.rs");
        fs::write(&file_path, "").unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        assert_eq!(matches.len(), 0);
    }

    /// Position matching: a function whose whole definition sits on one line.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_single_line_function() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r#"
fn single_line() { println!("hello"); }
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let func = matches.iter().find(|m| m.name == "single_line");
        assert!(func.is_some(), "Should find single-line function");

        if let Some(m) = func {
            assert_eq!(m.context.depth(), 1);
            assert_eq!(m.context.path(), "single_line");
        }
    }

    /// Position matching: a function spanning several lines.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_multiline_function() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
fn multiline() {
    let x = 1;
    let y = 2;
    x + y
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let func = matches.iter().find(|m| m.name == "multiline");
        assert!(func.is_some(), "Should find multi-line function");

        if let Some(m) = func {
            assert_eq!(m.context.depth(), 1);
            // The reported span must cover more than two lines.
            assert!(m.end_line > m.start_line + 1);
        }
    }

    /// Position matching: method inside `impl` inside `mod` has depth 3 and
    /// its parent is named `Inner`.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_nested_structures() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
mod outer {
    struct Inner {
        field: i32,
    }

    impl Inner {
        fn method(&self) -> i32 {
            self.field
        }
    }
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let method = matches.iter().find(|m| m.name == "method");
        assert!(method.is_some(), "Should find nested method");

        if let Some(m) = method {
            assert_eq!(m.context.depth(), 3, "Method should have depth 3");
            assert!(m.context.parent.is_some(), "Method should have parent");
            if let Some(parent) = &m.context.parent {
                assert_eq!(parent.name, "Inner", "Method parent should be Inner impl");
            }
        }
    }

    /// Position matching: comments before and inside a function do not
    /// prevent it from being found at depth 1.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_with_comments() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r#"
// This is a comment
/// Documentation comment
fn documented_function() {
    // Internal comment
    println!("test");
}
"#,
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let func = matches.iter().find(|m| m.name == "documented_function");
        assert!(func.is_some(), "Should find function with comments");

        if let Some(m) = func {
            assert_eq!(m.context.depth(), 1);
        }
    }

    /// Position matching: a struct and a method in its `impl` are both found;
    /// the method has depth 2 and a parent named `Container`.
    #[test]
    #[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
    fn test_position_matching_edge_positions() {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(
            &file_path,
            r"
struct Container {
    value: i32,
}

impl Container {
    fn new(val: i32) -> Self {
        Self { value: val }
    }
}
",
        )
        .unwrap();

        let manager = create_test_plugin_manager();
        let extractor = ContextExtractor::with_plugin_manager(manager);
        let matches = extractor.extract_from_file(&file_path).unwrap();

        let container = matches.iter().find(|m| m.name == "Container");
        let new_method = matches.iter().find(|m| m.name == "new");

        assert!(container.is_some(), "Should find Container struct");
        assert!(new_method.is_some(), "Should find new method");

        if let Some(m) = new_method {
            assert_eq!(m.context.depth(), 2, "Method should have depth 2");
            if let Some(parent) = &m.context.parent {
                assert_eq!(
                    parent.name, "Container",
                    "Method parent should be Container impl"
                );
            }
        }
    }
}