1use ahash::AHashMap;
7
8use crate::Error;
9use crate::node::{NodeInfo, node_info_from_node};
10
/// Reports whether `map` holds no entries.
///
/// Only compiled when the `serde` feature is enabled.
// NOTE(review): presumably referenced from a serde `skip_serializing_if` attribute — confirm at the use site.
#[cfg(feature = "serde")]
pub(crate) fn ahashmap_is_empty<K, V>(map: &AHashMap<K, V>) -> bool {
    map.is_empty()
}
16
/// Selects which pieces of data are recorded for every capture of a pattern.
///
/// The default is [`CaptureOutput::Full`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum CaptureOutput {
    /// Record the captured text only; no node information is kept.
    Text,
    /// Record the node information only; no text is extracted.
    Node,
    /// Record both the node information and the captured text.
    #[default]
    Full,
}
29
30#[derive(Debug, Clone)]
32#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
33pub struct ExtractionPattern {
34 pub query: String,
36 #[cfg_attr(feature = "serde", serde(default))]
38 pub capture_output: CaptureOutput,
39 #[cfg_attr(feature = "serde", serde(default))]
42 pub child_fields: Vec<String>,
43 #[cfg_attr(feature = "serde", serde(default))]
45 pub max_results: Option<usize>,
46 #[cfg_attr(feature = "serde", serde(default))]
48 pub byte_range: Option<(usize, usize)>,
49}
50
51#[derive(Debug, Clone)]
53#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
54pub struct ExtractionConfig {
55 pub language: String,
57 pub patterns: AHashMap<String, ExtractionPattern>,
59}
60
61#[derive(Debug, Clone)]
63#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
64pub struct CaptureResult {
65 pub name: String,
67 pub node: Option<NodeInfo>,
69 pub text: Option<String>,
71 pub child_fields: AHashMap<String, Option<String>>,
73 pub start_byte: usize,
75}
76
77#[derive(Debug, Clone)]
79#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
80pub struct MatchResult {
81 pub pattern_index: usize,
83 pub captures: Vec<CaptureResult>,
85}
86
87#[derive(Debug, Clone)]
89#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
90pub struct PatternResult {
91 pub matches: Vec<MatchResult>,
93 pub total_count: usize,
95}
96
97#[derive(Debug, Clone)]
99#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
100pub struct ExtractionResult {
101 pub language: String,
103 pub results: AHashMap<String, PatternResult>,
105}
106
/// Validation report for a single pattern.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PatternValidation {
    /// `true` when the pattern's query compiled.
    pub valid: bool,
    /// Capture names declared by the query; empty when the query is invalid.
    pub capture_names: Vec<String>,
    /// Pattern count reported by the compiled query; `0` when the query is invalid.
    pub pattern_count: usize,
    /// Non-fatal findings, such as an empty child field name.
    pub warnings: Vec<String>,
    /// Query compilation errors; empty when the query is valid.
    pub errors: Vec<String>,
}
122
123#[derive(Debug, Clone)]
125#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
126pub struct ValidationResult {
127 pub valid: bool,
129 pub patterns: AHashMap<String, PatternValidation>,
131}
132
133pub struct CompiledExtraction {
139 language: tree_sitter::Language,
140 language_name: String,
141 patterns: Vec<(String, tree_sitter::Query, ExtractionPattern)>,
143}
144
145unsafe impl Send for CompiledExtraction {}
148unsafe impl Sync for CompiledExtraction {}
150
151impl std::fmt::Debug for CompiledExtraction {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct("CompiledExtraction")
154 .field("language_name", &self.language_name)
155 .field("pattern_count", &self.patterns.len())
156 .finish()
157 }
158}
159
160pub fn extract(source: &str, config: &ExtractionConfig) -> Result<ExtractionResult, Error> {
170 let compiled = CompiledExtraction::compile(config)?;
171 compiled.extract(source)
172}
173
174pub fn validate_extraction(config: &ExtractionConfig) -> Result<ValidationResult, Error> {
183 let lang = crate::get_language(&config.language)?;
184 let mut all_valid = true;
185 let mut patterns = AHashMap::new();
186
187 for (name, pat) in &config.patterns {
188 match tree_sitter::Query::new(&lang, &pat.query) {
189 Ok(query) => {
190 let capture_names: Vec<String> = query.capture_names().iter().map(|s| s.to_string()).collect();
191 let pattern_count = query.pattern_count();
192
193 let mut warnings = Vec::new();
194 for field in &pat.child_fields {
196 if field.is_empty() {
197 warnings.push(format!("empty child field name in pattern '{name}'"));
198 }
199 }
200
201 patterns.insert(
202 name.clone(),
203 PatternValidation {
204 valid: true,
205 capture_names,
206 pattern_count,
207 warnings,
208 errors: Vec::new(),
209 },
210 );
211 }
212 Err(e) => {
213 all_valid = false;
214 patterns.insert(
215 name.clone(),
216 PatternValidation {
217 valid: false,
218 capture_names: Vec::new(),
219 pattern_count: 0,
220 warnings: Vec::new(),
221 errors: vec![e.to_string()],
222 },
223 );
224 }
225 }
226 }
227
228 Ok(ValidationResult {
229 valid: all_valid,
230 patterns,
231 })
232}
233
234impl CompiledExtraction {
235 pub fn compile(config: &ExtractionConfig) -> Result<Self, Error> {
241 let language = crate::get_language(&config.language)?;
242 Self::compile_with_language(language, &config.language, &config.patterns)
243 }
244
245 pub fn compile_with_language(
254 language: tree_sitter::Language,
255 language_name: &str,
256 extraction_patterns: &AHashMap<String, ExtractionPattern>,
257 ) -> Result<Self, Error> {
258 let mut patterns = Vec::with_capacity(extraction_patterns.len());
259
260 for (name, pat) in extraction_patterns {
261 let query = tree_sitter::Query::new(&language, &pat.query)
262 .map_err(|e| Error::QueryError(format!("pattern '{name}': {e}")))?;
263 patterns.push((name.clone(), query, pat.clone()));
264 }
265
266 Ok(Self {
267 language,
268 language_name: language_name.to_string(),
269 patterns,
270 })
271 }
272
273 pub fn extract(&self, source: &str) -> Result<ExtractionResult, Error> {
279 let mut parser = tree_sitter::Parser::new();
280 parser
281 .set_language(&self.language)
282 .map_err(|e| Error::ParserSetup(format!("{e}")))?;
283 let tree = parser.parse(source.as_bytes(), None).ok_or(Error::ParseFailed)?;
284 self.extract_from_tree(&tree, source.as_bytes())
285 }
286
287 pub fn extract_from_tree(&self, tree: &tree_sitter::Tree, source: &[u8]) -> Result<ExtractionResult, Error> {
293 use tree_sitter::StreamingIterator;
294
295 let mut results = AHashMap::new();
296
297 for (name, query, pat) in &self.patterns {
298 let mut cursor = tree_sitter::QueryCursor::new();
299
300 if let Some((start, end)) = pat.byte_range {
302 cursor.set_byte_range(start..end);
303 }
304
305 let capture_names: Vec<String> = query.capture_names().iter().map(|s| s.to_string()).collect();
306
307 let mut matches_iter = cursor.matches(query, tree.root_node(), source);
308 let mut match_results = Vec::new();
309 let mut total_count: usize = 0;
310
311 while let Some(m) = matches_iter.next() {
312 total_count += 1;
313
314 if let Some(max) = pat.max_results
316 && match_results.len() >= max
317 {
318 continue;
319 }
320
321 let mut captures = Vec::with_capacity(m.captures.len());
322 for cap in m.captures {
323 let cap_name = capture_names
324 .get(cap.index as usize)
325 .ok_or_else(|| Error::QueryError(format!("invalid capture index {}", cap.index)))?;
326 let ts_node = cap.node;
327 let info = node_info_from_node(ts_node);
328 let capture_start_byte = info.start_byte;
329
330 let text = match pat.capture_output {
331 CaptureOutput::Text | CaptureOutput::Full => {
332 crate::node::extract_text(source, &info).ok().map(String::from)
333 }
334 CaptureOutput::Node => None,
335 };
336
337 let node = match pat.capture_output {
338 CaptureOutput::Node | CaptureOutput::Full => Some(info),
339 CaptureOutput::Text => None,
340 };
341
342 let child_field_values = if pat.child_fields.is_empty() {
344 AHashMap::new()
345 } else {
346 let mut fields = AHashMap::with_capacity(pat.child_fields.len());
347 for field_name in &pat.child_fields {
348 let value = ts_node.child_by_field_name(field_name.as_str()).and_then(|child| {
349 let child_info = node_info_from_node(child);
350 crate::node::extract_text(source, &child_info).ok().map(String::from)
351 });
352 fields.insert(field_name.clone(), value);
353 }
354 fields
355 };
356
357 captures.push(CaptureResult {
358 name: cap_name.clone(),
359 node,
360 text,
361 child_fields: child_field_values,
362 start_byte: capture_start_byte,
363 });
364 }
365
366 match_results.push(MatchResult {
367 pattern_index: m.pattern_index,
368 captures,
369 });
370 }
371
372 match_results.sort_by_key(|m| m.captures.first().map_or(0, |c| c.start_byte));
374
375 results.insert(
376 name.clone(),
377 PatternResult {
378 matches: match_results,
379 total_count,
380 },
381 );
382 }
383
384 Ok(ExtractionResult {
385 language: self.language_name.clone(),
386 results,
387 })
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Query capturing the name identifier of every Python function definition.
    const FN_NAME_QUERY: &str = "(function_definition name: (identifier) @fn_name)";

    /// Returns `true` when the Python grammar is unavailable and the test should bail out.
    fn skip_if_no_python() -> bool {
        !crate::has_language("python")
    }

    /// Wraps `patterns` in a configuration targeting Python.
    fn python_config(patterns: AHashMap<String, ExtractionPattern>) -> ExtractionConfig {
        ExtractionConfig {
            language: "python".to_string(),
            patterns,
        }
    }

    /// Builds a pattern with the given query and output mode and no other options.
    fn pattern(query: &str, capture_output: CaptureOutput) -> ExtractionPattern {
        ExtractionPattern {
            query: query.to_string(),
            capture_output,
            child_fields: Vec::new(),
            max_results: None,
            byte_range: None,
        }
    }

    /// Builds a one-entry pattern map.
    fn named(name: &str, pat: ExtractionPattern) -> AHashMap<String, ExtractionPattern> {
        let mut map = AHashMap::new();
        map.insert(name.to_string(), pat);
        map
    }

    /// Builds a one-entry pattern map using the default output mode.
    fn single_pattern(name: &str, query: &str) -> AHashMap<String, ExtractionPattern> {
        named(name, pattern(query, CaptureOutput::default()))
    }

    #[test]
    fn test_basic_extraction() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern(
            "functions",
            "(function_definition name: (identifier) @fn_name) @fn_def",
        ));
        let result = extract("def hello():\n pass\n\ndef world():\n pass\n", &config).unwrap();
        assert_eq!(result.language, "python");

        let fns = &result.results["functions"];
        assert_eq!(fns.total_count, 2);
        assert_eq!(fns.matches.len(), 2);

        // Every match carries both the name capture and the definition capture.
        for found in &fns.matches {
            assert_eq!(found.captures.len(), 2);
        }
    }

    #[test]
    fn test_capture_output_text_only() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(named("names", pattern(FN_NAME_QUERY, CaptureOutput::Text)));
        let result = extract("def foo():\n pass\n", &config).unwrap();
        let names = &result.results["names"];
        assert_eq!(names.matches.len(), 1);

        let cap = &names.matches[0].captures[0];
        assert_eq!(cap.name, "fn_name");
        assert!(cap.text.is_some());
        assert_eq!(cap.text.as_deref(), Some("foo"));
        assert!(cap.node.is_none(), "Text mode should not include NodeInfo");
    }

    #[test]
    fn test_capture_output_node_only() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(named("names", pattern(FN_NAME_QUERY, CaptureOutput::Node)));
        let result = extract("def foo():\n pass\n", &config).unwrap();
        let cap = &result.results["names"].matches[0].captures[0];
        assert!(cap.node.is_some(), "Node mode should include NodeInfo");
        assert!(cap.text.is_none(), "Node mode should not include text");
    }

    #[test]
    fn test_capture_output_full() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(named("names", pattern(FN_NAME_QUERY, CaptureOutput::Full)));
        let result = extract("def foo():\n pass\n", &config).unwrap();
        let cap = &result.results["names"].matches[0].captures[0];
        assert!(cap.node.is_some(), "Full mode should include NodeInfo");
        assert!(cap.text.is_some(), "Full mode should include text");
        assert_eq!(cap.text.as_deref(), Some("foo"));
    }

    #[test]
    fn test_child_fields_extraction() {
        if skip_if_no_python() {
            return;
        }
        let with_fields = ExtractionPattern {
            child_fields: vec!["name".to_string(), "parameters".to_string()],
            ..pattern("(function_definition) @fn_def", CaptureOutput::Full)
        };
        let config = python_config(named("functions", with_fields));
        let result = extract("def greet(name):\n pass\n", &config).unwrap();
        let fns = &result.results["functions"];
        assert_eq!(fns.matches.len(), 1);

        let cap = &fns.matches[0].captures[0];
        assert!(cap.child_fields.contains_key("name"));
        assert_eq!(cap.child_fields["name"].as_deref(), Some("greet"));
        assert!(cap.child_fields.contains_key("parameters"));
        assert!(cap.child_fields["parameters"].is_some());
    }

    #[test]
    fn test_validation_valid_query() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("fns", FN_NAME_QUERY));
        let validation = validate_extraction(&config).unwrap();
        assert!(validation.valid);
        let report = &validation.patterns["fns"];
        assert!(report.valid);
        assert!(report.capture_names.contains(&"fn_name".to_string()));
        assert!(report.errors.is_empty());
    }

    #[test]
    fn test_validation_invalid_query() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("bad", "((((not valid syntax"));
        let validation = validate_extraction(&config).unwrap();
        assert!(!validation.valid);
        let report = &validation.patterns["bad"];
        assert!(!report.valid);
        assert!(!report.errors.is_empty());
    }

    #[test]
    fn test_validation_unknown_language() {
        let config = ExtractionConfig {
            language: "nonexistent_xyz_lang".to_string(),
            patterns: AHashMap::new(),
        };
        assert!(validate_extraction(&config).is_err());
    }

    #[test]
    fn test_max_results_truncation() {
        if skip_if_no_python() {
            return;
        }
        let capped = ExtractionPattern {
            max_results: Some(1),
            ..pattern(FN_NAME_QUERY, CaptureOutput::Text)
        };
        let config = python_config(named("fns", capped));
        let result = extract("def a():\n pass\ndef b():\n pass\ndef c():\n pass\n", &config).unwrap();
        let fns = &result.results["fns"];
        assert_eq!(fns.matches.len(), 1, "should be truncated to max_results=1");
        assert_eq!(fns.total_count, 3, "total_count should reflect all matches");
    }

    #[test]
    fn test_compiled_extraction_reuse() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("fns", FN_NAME_QUERY));
        let compiled = CompiledExtraction::compile(&config).unwrap();

        let first = compiled.extract("def a():\n pass\n").unwrap();
        let second = compiled.extract("def x():\n pass\ndef y():\n pass\n").unwrap();

        assert_eq!(first.results["fns"].total_count, 1);
        assert_eq!(second.results["fns"].total_count, 2);
    }

    #[test]
    fn test_empty_results() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern(
            "classes",
            "(class_definition name: (identifier) @cls_name)",
        ));
        let result = extract("x = 1\n", &config).unwrap();
        // The pattern key is present even though nothing matched.
        let classes = &result.results["classes"];
        assert!(classes.matches.is_empty());
        assert_eq!(classes.total_count, 0);
    }

    #[test]
    fn test_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<CompiledExtraction>();
        assert_sync::<CompiledExtraction>();
        assert_send::<ExtractionResult>();
        assert_sync::<ExtractionResult>();
        assert_send::<ExtractionConfig>();
        assert_sync::<ExtractionConfig>();
        assert_send::<CaptureOutput>();
        assert_sync::<CaptureOutput>();
    }

    #[test]
    fn test_byte_range_restriction() {
        if skip_if_no_python() {
            return;
        }
        let source = "def a():\n pass\ndef b():\n pass\ndef c():\n pass\n";
        let second_fn_start = source.find("def b").unwrap();
        // The range ends where the third function begins (or at end of input).
        let second_fn_end = source[second_fn_start..]
            .find("def c")
            .map_or(source.len(), |offset| second_fn_start + offset);
        let ranged = ExtractionPattern {
            byte_range: Some((second_fn_start, second_fn_end)),
            ..pattern(FN_NAME_QUERY, CaptureOutput::Text)
        };
        let config = python_config(named("fns", ranged));
        let result = extract(source, &config).unwrap();
        let fns = &result.results["fns"];
        assert_eq!(fns.matches.len(), 1, "byte_range should restrict to one function");
        assert_eq!(
            fns.matches[0].captures[0].text.as_deref(),
            Some("b"),
            "should capture function 'b' within the byte range"
        );
    }

    #[test]
    fn test_result_ordering() {
        if skip_if_no_python() {
            return;
        }
        // Ordering must hold in every output mode, including those without NodeInfo.
        for mode in [CaptureOutput::Text, CaptureOutput::Node, CaptureOutput::Full] {
            let config = python_config(named("fns", pattern(FN_NAME_QUERY, mode.clone())));
            let result = extract(
                "def alpha():\n pass\ndef beta():\n pass\ndef gamma():\n pass\n",
                &config,
            )
            .unwrap();
            let fns = &result.results["fns"];
            assert_eq!(fns.matches.len(), 3);

            let start_bytes: Vec<usize> = fns.matches.iter().map(|m| m.captures[0].start_byte).collect();
            for pair in start_bytes.windows(2) {
                assert!(
                    pair[0] < pair[1],
                    "results should be sorted by position, got {start_bytes:?} in mode {mode:?}"
                );
            }
        }
    }

    #[test]
    fn test_extract_from_tree() {
        if skip_if_no_python() {
            return;
        }
        let config = python_config(single_pattern("fns", FN_NAME_QUERY));
        let compiled = CompiledExtraction::compile(&config).unwrap();

        let source = "def hello():\n pass\n";
        let tree = crate::parse::parse_string("python", source.as_bytes()).unwrap();
        let result = compiled.extract_from_tree(&tree, source.as_bytes()).unwrap();

        let fns = &result.results["fns"];
        assert_eq!(fns.total_count, 1);
        assert_eq!(fns.matches[0].captures[0].text.as_deref(), Some("hello"));
    }
}