1use ahash::AHashMap;
7
8use crate::Error;
9use crate::node::{NodeInfo, node_info_from_node};
10
#[cfg(feature = "serde")]
/// Returns `true` when `map` holds no entries.
///
/// Only compiled with the `serde` feature. Presumably intended as a serde
/// `skip_serializing_if` predicate for `AHashMap` fields — NOTE(review): no such attribute is
/// visible in this chunk; confirm it is still referenced.
pub(crate) fn ahashmap_is_empty<K, V>(map: &AHashMap<K, V>) -> bool {
    map.is_empty()
}
16
/// Selects which data is recorded for each capture in a [`CaptureResult`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum CaptureOutput {
    /// Record only the captured source text (`CaptureResult::text`); `node` is left `None`.
    Text,
    /// Record only the node metadata (`CaptureResult::node`); `text` is left `None`.
    Node,
    /// Record both the source text and the node metadata. This is the default mode.
    #[default]
    Full,
}
29
/// One tree-sitter query plus the options that control how its matches are reported.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtractionPattern {
    /// Tree-sitter query source, e.g. `(function_definition name: (identifier) @fn_name)`.
    pub query: String,
    /// What to record for each capture; defaults to [`CaptureOutput::Full`].
    #[cfg_attr(feature = "serde", serde(default))]
    pub capture_output: CaptureOutput,
    /// Field names looked up with `child_by_field_name` on every captured node. The text of each
    /// such child is stored in `CaptureResult::child_fields` (`None` when the child is absent).
    #[cfg_attr(feature = "serde", serde(default))]
    pub child_fields: Vec<String>,
    /// Maximum number of matches kept in `PatternResult::matches`; `None` keeps all of them.
    /// Matches beyond the limit still count towards `PatternResult::total_count`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_results: Option<usize>,
    /// Optional `(start, end)` byte offsets, handed to the query cursor as `start..end` to limit
    /// where matches are searched; `None` searches the whole tree. `start` should not exceed `end`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub byte_range: Option<(usize, usize)>,
}
50
/// Input to [`extract`] and [`CompiledExtraction::compile`]: a language and its named patterns.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtractionConfig {
    /// Language name, resolved through `crate::get_language`.
    pub language: String,
    /// Patterns keyed by a caller-chosen name; each key reappears in `ExtractionResult::results`
    /// and `ValidationResult::patterns`.
    pub patterns: AHashMap<String, ExtractionPattern>,
}
60
/// One captured node within a query match.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CaptureResult {
    /// Capture name as written in the query, without the leading `@` (e.g. `fn_name`).
    pub name: String,
    /// Node metadata; `Some` in [`CaptureOutput::Node`] and [`CaptureOutput::Full`] modes.
    pub node: Option<NodeInfo>,
    /// Captured source text; populated in [`CaptureOutput::Text`] and [`CaptureOutput::Full`]
    /// modes, and `None` if the text could not be extracted.
    pub text: Option<String>,
    /// Text of each requested child field, keyed by field name; `None` when the captured node
    /// has no child with that field name or its text could not be extracted.
    pub child_fields: AHashMap<String, Option<String>>,
    /// Start byte of the captured node; used to order matches by position.
    pub start_byte: usize,
}
76
/// A single match of a pattern's query.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MatchResult {
    /// Index of the matching pattern within the query string, as reported by tree-sitter
    /// (one query string may contain several patterns).
    pub pattern_index: usize,
    /// Captures of this match, in the order tree-sitter reported them.
    pub captures: Vec<CaptureResult>,
}
86
/// All matches produced by one named pattern.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PatternResult {
    /// Kept matches (at most `max_results`), sorted by the start byte of each match's first capture.
    pub matches: Vec<MatchResult>,
    /// Number of matches found, including any dropped because of `max_results`.
    pub total_count: usize,
}
96
/// Output of an extraction run over one source.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtractionResult {
    /// Name of the language the extraction was compiled for.
    pub language: String,
    /// Per-pattern results, keyed by the pattern names from the configuration.
    pub results: AHashMap<String, PatternResult>,
}
106
/// Validation outcome for one named pattern, produced by [`validate_extraction`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PatternValidation {
    /// `true` when no errors were found for this pattern.
    pub valid: bool,
    /// Capture names declared by the query (empty if the query failed to compile).
    pub capture_names: Vec<String>,
    /// Number of patterns in the query as reported by tree-sitter (0 if it failed to compile).
    pub pattern_count: usize,
    /// Non-fatal problems; they do not affect `valid`.
    pub warnings: Vec<String>,
    /// Fatal problems, such as query compilation errors.
    pub errors: Vec<String>,
}
122
/// Validation outcome for a whole [`ExtractionConfig`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ValidationResult {
    /// `true` only if every pattern is valid.
    pub valid: bool,
    /// Per-pattern validation, keyed by the pattern names from the configuration.
    pub patterns: AHashMap<String, PatternValidation>,
}
132
/// A named pattern whose query has already been compiled.
struct CompiledPattern {
    /// Pattern name from the configuration; used as the key in `ExtractionResult::results`.
    name: String,
    query: tree_sitter::Query,
    /// `query.capture_names()` copied once at compile time and indexed by capture index.
    capture_names: Vec<String>,
    /// Copy of the original pattern options (output mode, child fields, limits, range).
    config: ExtractionPattern,
}
140
/// A set of pre-compiled extraction patterns for one language.
///
/// Compile once with [`CompiledExtraction::compile`] or
/// [`CompiledExtraction::compile_with_language`], then call `extract` / `extract_from_tree`
/// on as many sources as needed without recompiling the queries.
pub struct CompiledExtraction {
    language: tree_sitter::Language,
    /// Name passed to the parser and copied into every `ExtractionResult`.
    language_name: String,
    patterns: Vec<CompiledPattern>,
}
151
// SAFETY: once constructed, a `CompiledExtraction` is only read. `extract` and
// `extract_from_tree` take `&self` and create their own `QueryCursor`, so the stored
// `tree_sitter::Query` and `tree_sitter::Language` values are never mutated after construction.
// NOTE(review): recent tree-sitter releases already implement `Send`/`Sync` for `Query` and
// `Language`. If the pinned version does too, these impls are redundant and could be removed so
// the compiler checks thread-safety instead — confirm against the tree-sitter version in use.
unsafe impl Send for CompiledExtraction {}
// SAFETY: same justification as the `Send` impl directly above.
unsafe impl Sync for CompiledExtraction {}
157
158impl std::fmt::Debug for CompiledExtraction {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("CompiledExtraction")
161 .field("language_name", &self.language_name)
162 .field("pattern_count", &self.patterns.len())
163 .finish()
164 }
165}
166
167pub fn extract(source: &str, config: &ExtractionConfig) -> Result<ExtractionResult, Error> {
177 let compiled = CompiledExtraction::compile(config)?;
178 compiled.extract(source)
179}
180
181pub fn validate_extraction(config: &ExtractionConfig) -> Result<ValidationResult, Error> {
190 let lang = crate::get_language(&config.language)?;
191 let mut all_valid = true;
192 let mut patterns = AHashMap::new();
193
194 for (name, pat) in &config.patterns {
195 match tree_sitter::Query::new(&lang, &pat.query) {
196 Ok(query) => {
197 let capture_names: Vec<String> = query.capture_names().iter().map(|s| s.to_string()).collect();
198 let pattern_count = query.pattern_count();
199
200 let mut warnings = Vec::new();
201 for field in &pat.child_fields {
203 if field.is_empty() {
204 warnings.push(format!("empty child field name in pattern '{name}'"));
205 }
206 }
207
208 patterns.insert(
209 name.clone(),
210 PatternValidation {
211 valid: true,
212 capture_names,
213 pattern_count,
214 warnings,
215 errors: Vec::new(),
216 },
217 );
218 }
219 Err(e) => {
220 all_valid = false;
221 patterns.insert(
222 name.clone(),
223 PatternValidation {
224 valid: false,
225 capture_names: Vec::new(),
226 pattern_count: 0,
227 warnings: Vec::new(),
228 errors: vec![e.to_string()],
229 },
230 );
231 }
232 }
233 }
234
235 Ok(ValidationResult {
236 valid: all_valid,
237 patterns,
238 })
239}
240
241impl CompiledExtraction {
242 pub fn compile(config: &ExtractionConfig) -> Result<Self, Error> {
248 let language = crate::get_language(&config.language)?;
249 Self::compile_with_language(language, &config.language, &config.patterns)
250 }
251
252 pub fn compile_with_language(
261 language: tree_sitter::Language,
262 language_name: &str,
263 extraction_patterns: &AHashMap<String, ExtractionPattern>,
264 ) -> Result<Self, Error> {
265 let mut patterns = Vec::with_capacity(extraction_patterns.len());
266
267 for (name, pat) in extraction_patterns {
268 let query = tree_sitter::Query::new(&language, &pat.query)
269 .map_err(|e| Error::QueryError(format!("pattern '{name}': {e}")))?;
270 let capture_names = query.capture_names().iter().map(|s| s.to_string()).collect();
271 patterns.push(CompiledPattern {
272 name: name.clone(),
273 query,
274 capture_names,
275 config: pat.clone(),
276 });
277 }
278
279 Ok(Self {
280 language,
281 language_name: language_name.to_string(),
282 patterns,
283 })
284 }
285
286 pub fn extract(&self, source: &str) -> Result<ExtractionResult, Error> {
295 let tree = crate::parse::parse_with_language(&self.language_name, &self.language, source.as_bytes())?;
296 self.extract_from_tree(&tree, source.as_bytes())
297 }
298
299 pub fn extract_from_tree(&self, tree: &tree_sitter::Tree, source: &[u8]) -> Result<ExtractionResult, Error> {
305 use tree_sitter::StreamingIterator;
306
307 let mut results = AHashMap::with_capacity(self.patterns.len());
308 let mut cursor = tree_sitter::QueryCursor::new();
309
310 for cp in &self.patterns {
311 if let Some((start, end)) = cp.config.byte_range {
313 cursor.set_byte_range(start..end);
314 } else {
315 cursor.set_byte_range(0..usize::MAX);
316 }
317
318 let mut matches_iter = cursor.matches(&cp.query, tree.root_node(), source);
319 let mut match_results = Vec::new();
320 let mut total_count: usize = 0;
321
322 while let Some(m) = matches_iter.next() {
323 total_count += 1;
324
325 if let Some(max) = cp.config.max_results
327 && match_results.len() >= max
328 {
329 continue;
330 }
331
332 let mut captures = Vec::with_capacity(m.captures.len());
333 for cap in m.captures {
334 let cap_name = cp
335 .capture_names
336 .get(cap.index as usize)
337 .ok_or_else(|| Error::QueryError(format!("invalid capture index {}", cap.index)))?;
338 let ts_node = cap.node;
339 let info = node_info_from_node(ts_node);
340 let capture_start_byte = info.start_byte;
341
342 let text = match cp.config.capture_output {
343 CaptureOutput::Text | CaptureOutput::Full => {
344 crate::node::extract_text(source, &info).ok().map(String::from)
345 }
346 CaptureOutput::Node => None,
347 };
348
349 let node = match cp.config.capture_output {
350 CaptureOutput::Node | CaptureOutput::Full => Some(info),
351 CaptureOutput::Text => None,
352 };
353
354 let child_field_values = if cp.config.child_fields.is_empty() {
356 AHashMap::new()
357 } else {
358 let mut fields = AHashMap::with_capacity(cp.config.child_fields.len());
359 for field_name in &cp.config.child_fields {
360 let value = ts_node.child_by_field_name(field_name.as_str()).and_then(|child| {
361 let child_info = node_info_from_node(child);
362 crate::node::extract_text(source, &child_info).ok().map(String::from)
363 });
364 fields.insert(field_name.clone(), value);
365 }
366 fields
367 };
368
369 captures.push(CaptureResult {
370 name: cap_name.clone(),
371 node,
372 text,
373 child_fields: child_field_values,
374 start_byte: capture_start_byte,
375 });
376 }
377
378 match_results.push(MatchResult {
379 pattern_index: m.pattern_index,
380 captures,
381 });
382 }
383
384 match_results.sort_by_key(|m| m.captures.first().map_or(0, |c| c.start_byte));
386
387 results.insert(
388 cp.name.clone(),
389 PatternResult {
390 matches: match_results,
391 total_count,
392 },
393 );
394 }
395
396 Ok(ExtractionResult {
397 language: self.language_name.clone(),
398 results,
399 })
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Query capturing the name identifier of every Python function definition as `fn_name`.
    const FN_NAME_QUERY: &str = "(function_definition name: (identifier) @fn_name)";

    /// Three trivial Python functions named `a`, `b` and `c`, in that order.
    const THREE_FNS: &str = "def a():\n pass\ndef b():\n pass\ndef c():\n pass\n";

    /// Returns `true` when the Python grammar is unavailable, so grammar-dependent tests
    /// return early instead of failing.
    fn skip_if_no_python() -> bool {
        !crate::has_language("python")
    }

    /// Wraps `patterns` in an [`ExtractionConfig`] targeting Python.
    fn python_config(patterns: AHashMap<String, ExtractionPattern>) -> ExtractionConfig {
        ExtractionConfig {
            language: "python".to_string(),
            patterns,
        }
    }

    /// Builds a pattern with the given query and output mode and every other option unset.
    fn pattern(query: &str, capture_output: CaptureOutput) -> ExtractionPattern {
        ExtractionPattern {
            query: query.to_string(),
            capture_output,
            child_fields: Vec::new(),
            max_results: None,
            byte_range: None,
        }
    }

    /// Collects `(name, pattern)` pairs into the map shape expected by [`ExtractionConfig`].
    fn patterns_map(entries: Vec<(&str, ExtractionPattern)>) -> AHashMap<String, ExtractionPattern> {
        let mut map = AHashMap::with_capacity(entries.len());
        for (name, pat) in entries {
            map.insert(name.to_string(), pat);
        }
        map
    }

    /// A map holding one pattern called `name`, using the default capture output.
    fn single_pattern(name: &str, query: &str) -> AHashMap<String, ExtractionPattern> {
        patterns_map(vec![(name, pattern(query, CaptureOutput::default()))])
    }

    /// Byte span from the start of `def b` up to the start of `def c` (or the end of `source`).
    fn second_fn_span(source: &str) -> (usize, usize) {
        let start = source.find("def b").unwrap();
        let end = source[start..]
            .find("def c")
            .map_or(source.len(), |offset| start + offset);
        (start, end)
    }

    #[test]
    fn test_basic_extraction() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern(
            "functions",
            "(function_definition name: (identifier) @fn_name) @fn_def",
        ));
        let result = extract("def hello():\n pass\n\ndef world():\n pass\n", &config).unwrap();
        assert_eq!(result.language, "python");

        let functions = &result.results["functions"];
        assert_eq!(functions.total_count, 2);
        assert_eq!(functions.matches.len(), 2);

        // Every match reports both the name capture and the whole-definition capture.
        for matched in &functions.matches {
            assert_eq!(matched.captures.len(), 2);
        }
    }

    #[test]
    fn test_capture_output_text_only() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(patterns_map(vec![(
            "names",
            pattern(FN_NAME_QUERY, CaptureOutput::Text),
        )]));
        let result = extract("def foo():\n pass\n", &config).unwrap();
        let names = &result.results["names"];
        assert_eq!(names.matches.len(), 1);

        let capture = &names.matches[0].captures[0];
        assert_eq!(capture.name, "fn_name");
        assert!(capture.text.is_some());
        assert_eq!(capture.text.as_deref(), Some("foo"));
        assert!(capture.node.is_none(), "Text mode should not include NodeInfo");
    }

    #[test]
    fn test_capture_output_node_only() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(patterns_map(vec![(
            "names",
            pattern(FN_NAME_QUERY, CaptureOutput::Node),
        )]));
        let result = extract("def foo():\n pass\n", &config).unwrap();

        let capture = &result.results["names"].matches[0].captures[0];
        assert!(capture.node.is_some(), "Node mode should include NodeInfo");
        assert!(capture.text.is_none(), "Node mode should not include text");
    }

    #[test]
    fn test_capture_output_full() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(patterns_map(vec![(
            "names",
            pattern(FN_NAME_QUERY, CaptureOutput::Full),
        )]));
        let result = extract("def foo():\n pass\n", &config).unwrap();

        let capture = &result.results["names"].matches[0].captures[0];
        assert!(capture.node.is_some(), "Full mode should include NodeInfo");
        assert!(capture.text.is_some(), "Full mode should include text");
        assert_eq!(capture.text.as_deref(), Some("foo"));
    }

    #[test]
    fn test_child_fields_extraction() {
        if skip_if_no_python() {
            return;
        }
        let with_fields = ExtractionPattern {
            child_fields: vec!["name".to_string(), "parameters".to_string()],
            ..pattern("(function_definition) @fn_def", CaptureOutput::Full)
        };
        let config = python_config(patterns_map(vec![("functions", with_fields)]));
        let result = extract("def greet(name):\n pass\n", &config).unwrap();
        let functions = &result.results["functions"];
        assert_eq!(functions.matches.len(), 1);

        let capture = &functions.matches[0].captures[0];
        assert!(capture.child_fields.contains_key("name"));
        assert_eq!(capture.child_fields["name"].as_deref(), Some("greet"));
        assert!(capture.child_fields.contains_key("parameters"));
        assert!(capture.child_fields["parameters"].is_some());
    }

    #[test]
    fn test_validation_valid_query() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("fns", FN_NAME_QUERY));
        let validation = validate_extraction(&config).unwrap();
        assert!(validation.valid);

        let fns = &validation.patterns["fns"];
        assert!(fns.valid);
        assert!(fns.capture_names.iter().any(|capture| capture == "fn_name"));
        assert!(fns.errors.is_empty());
    }

    #[test]
    fn test_validation_invalid_query() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("bad", "((((not valid syntax"));
        let validation = validate_extraction(&config).unwrap();
        assert!(!validation.valid);

        let bad = &validation.patterns["bad"];
        assert!(!bad.valid);
        assert!(!bad.errors.is_empty());
    }

    #[test]
    fn test_validation_unknown_language() {
        let config = ExtractionConfig {
            language: "nonexistent_xyz_lang".to_string(),
            patterns: AHashMap::new(),
        };
        assert!(validate_extraction(&config).is_err());
    }

    #[test]
    fn test_max_results_truncation() {
        if skip_if_no_python() {
            return;
        }
        let limited = ExtractionPattern {
            max_results: Some(1),
            ..pattern(FN_NAME_QUERY, CaptureOutput::Text)
        };
        let config = python_config(patterns_map(vec![("fns", limited)]));
        let result = extract(THREE_FNS, &config).unwrap();

        let fns = &result.results["fns"];
        assert_eq!(fns.matches.len(), 1, "should be truncated to max_results=1");
        assert_eq!(fns.total_count, 3, "total_count should reflect all matches");
    }

    #[test]
    fn test_compiled_extraction_reuse() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("fns", FN_NAME_QUERY));
        let compiled = CompiledExtraction::compile(&config).unwrap();

        let first = compiled.extract("def a():\n pass\n").unwrap();
        let second = compiled.extract("def x():\n pass\ndef y():\n pass\n").unwrap();

        assert_eq!(first.results["fns"].total_count, 1);
        assert_eq!(second.results["fns"].total_count, 2);
    }

    #[test]
    fn test_empty_results() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern(
            "classes",
            "(class_definition name: (identifier) @cls_name)",
        ));
        let result = extract("x = 1\n", &config).unwrap();

        let classes = &result.results["classes"];
        assert!(classes.matches.is_empty());
        assert_eq!(classes.total_count, 0);
    }

    #[test]
    fn test_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<CompiledExtraction>();
        assert_sync::<CompiledExtraction>();
        assert_send::<ExtractionResult>();
        assert_sync::<ExtractionResult>();
        assert_send::<ExtractionConfig>();
        assert_sync::<ExtractionConfig>();
        assert_send::<CaptureOutput>();
        assert_sync::<CaptureOutput>();
    }

    #[test]
    fn test_byte_range_restriction() {
        if skip_if_no_python() {
            return;
        }
        let (start, end) = second_fn_span(THREE_FNS);
        let restricted = ExtractionPattern {
            byte_range: Some((start, end)),
            ..pattern(FN_NAME_QUERY, CaptureOutput::Text)
        };
        let config = python_config(patterns_map(vec![("fns", restricted)]));
        let result = extract(THREE_FNS, &config).unwrap();

        let fns = &result.results["fns"];
        assert_eq!(fns.matches.len(), 1, "byte_range should restrict to one function");
        assert_eq!(
            fns.matches[0].captures[0].text.as_deref(),
            Some("b"),
            "should capture function 'b' within the byte range"
        );
    }

    #[test]
    fn test_result_ordering() {
        if skip_if_no_python() {
            return;
        }
        let source = "def alpha():\n pass\ndef beta():\n pass\ndef gamma():\n pass\n";
        for mode in [CaptureOutput::Text, CaptureOutput::Node, CaptureOutput::Full] {
            let config = python_config(patterns_map(vec![("fns", pattern(FN_NAME_QUERY, mode.clone()))]));
            let result = extract(source, &config).unwrap();

            let fns = &result.results["fns"];
            assert_eq!(fns.matches.len(), 3);

            let start_bytes: Vec<usize> = fns.matches.iter().map(|m| m.captures[0].start_byte).collect();
            for pair in start_bytes.windows(2) {
                assert!(
                    pair[0] < pair[1],
                    "results should be sorted by position, got {start_bytes:?} in mode {mode:?}"
                );
            }
        }
    }

    #[test]
    fn test_extract_from_tree() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("fns", FN_NAME_QUERY));
        let compiled = CompiledExtraction::compile(&config).unwrap();

        let source = "def hello():\n pass\n";
        let tree = crate::parse::parse_string("python", source.as_bytes()).unwrap();
        let result = compiled.extract_from_tree(&tree, source.as_bytes()).unwrap();

        assert_eq!(result.results["fns"].total_count, 1);
        let capture = &result.results["fns"].matches[0].captures[0];
        assert_eq!(capture.text.as_deref(), Some("hello"));
    }

    #[test]
    fn test_byte_range_does_not_leak_between_patterns() {
        if skip_if_no_python() {
            return;
        }
        let (start, end) = second_fn_span(THREE_FNS);

        // Both patterns run through one `CompiledExtraction`, which shares a single query cursor.
        let restricted = ExtractionPattern {
            byte_range: Some((start, end)),
            ..pattern(FN_NAME_QUERY, CaptureOutput::Text)
        };
        let unrestricted = pattern(FN_NAME_QUERY, CaptureOutput::Text);
        let config = python_config(patterns_map(vec![
            ("restricted", restricted),
            ("unrestricted", unrestricted),
        ]));
        let compiled = CompiledExtraction::compile(&config).unwrap();
        let result = compiled.extract(THREE_FNS).unwrap();

        assert_eq!(
            result.results["restricted"].matches.len(),
            1,
            "restricted pattern should find 1 function"
        );
        assert_eq!(
            result.results["unrestricted"].matches.len(),
            3,
            "unrestricted pattern should find all 3 functions, not be limited by previous byte_range"
        );
    }

    #[test]
    fn test_compiled_extraction_capture_names_precomputed() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(patterns_map(vec![(
            "fns",
            pattern(
                "(function_definition name: (identifier) @fn_name) @fn_def",
                CaptureOutput::Full,
            ),
        )]));
        let result = extract("def hello():\n pass\n", &config).unwrap();

        let fns = &result.results["fns"];
        assert_eq!(fns.matches.len(), 1);
        let names: Vec<&str> = fns.matches[0].captures.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"fn_name"), "should have fn_name capture");
        assert!(names.contains(&"fn_def"), "should have fn_def capture");
    }
}