1use std::cmp::Reverse;
2use std::iter::{self, Peekable};
3use std::mem::take;
4use std::sync::Arc;
5
6use arc_swap::ArcSwap;
7use hashbrown::{HashMap, HashSet};
8use once_cell::sync::Lazy;
9use regex_cursor::engines::meta::Regex;
10use ropey::RopeSlice;
11
12use crate::config::{LanguageConfig, LanguageLoader};
13use crate::highlighter::Highlight;
14use crate::locals::Locals;
15use crate::parse::LayerUpdateFlags;
16use crate::{Injection, Language, Layer, LayerData, Range, Syntax, TREE_SITTER_MATCH_LIMIT};
17use tree_sitter::{
18 query::{self, InvalidPredicateError, UserPredicate},
19 Capture, Grammar, InactiveQueryCursor, MatchedNodeIdx, Node, Pattern, Query, QueryMatch,
20};
21
22const SHEBANG: &str = r"#!\s*(?:\S*[/\\](?:env\s+(?:\-\S+\s+)*)?)?([^\s\.\d]+)";
23static SHEBANG_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(SHEBANG).unwrap());
24
25#[derive(Clone, Default, Debug)]
26pub struct InjectionProperties {
27 include_children: IncludedChildren,
28 language: Option<Box<str>>,
29 combined: bool,
30}
31
32#[derive(Debug, Clone, Copy)]
39pub enum InjectionLanguageMarker<'a> {
40 Name(&'a str),
45 Match(RopeSlice<'a>),
51 Filename(RopeSlice<'a>),
52 Shebang(RopeSlice<'a>),
53}
54
55#[derive(Clone, Debug)]
56pub struct InjectionQueryMatch<'tree> {
57 include_children: IncludedChildren,
58 language: Language,
59 scope: Option<InjectionScope>,
60 node: Node<'tree>,
61 last_match: bool,
62 pattern: Pattern,
63}
64
65#[derive(Clone, Debug, Hash, PartialEq, Eq)]
66enum InjectionScope {
67 Match {
68 id: u32,
69 },
70 Pattern {
71 pattern: Pattern,
72 language: Language,
73 },
74}
75
/// Selects which children of an injection content node stay inside the injected ranges.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum IncludedChildren {
    /// Every child is cut out of the injected range. This is the default.
    #[default]
    None,
    /// No child is cut out; the whole node is injected.
    All,
    /// Named children are cut out; unnamed children stay.
    Unnamed,
}
83
84#[derive(Debug)]
85pub struct InjectionsQuery {
86 injection_query: Query,
87 injection_properties: HashMap<Pattern, InjectionProperties>,
88 injection_content_capture: Option<Capture>,
89 injection_language_capture: Option<Capture>,
90 injection_filename_capture: Option<Capture>,
91 injection_shebang_capture: Option<Capture>,
92 pub(crate) local_query: Query,
94 pub(crate) not_scope_inherits: HashSet<Pattern>,
96 pub(crate) local_scope_capture: Option<Capture>,
97 pub(crate) local_definition_captures: ArcSwap<HashMap<Capture, Highlight>>,
98}
99
100impl InjectionsQuery {
101 pub fn new(
102 grammar: Grammar,
103 injection_query_text: &str,
104 local_query_text: &str,
105 ) -> Result<Self, query::ParseError> {
106 let mut query_source =
107 String::with_capacity(injection_query_text.len() + local_query_text.len());
108 query_source.push_str(injection_query_text);
109 query_source.push_str(local_query_text);
110
111 let mut injection_properties: HashMap<Pattern, InjectionProperties> = HashMap::new();
112 let mut not_scope_inherits = HashSet::new();
113 let injection_query = Query::new(grammar, injection_query_text, |pattern, predicate| {
114 match predicate {
115 UserPredicate::SetProperty {
117 key: "injection.include-unnamed-children",
118 val: None,
119 } => {
120 injection_properties
121 .entry(pattern)
122 .or_default()
123 .include_children = IncludedChildren::Unnamed
124 }
125 UserPredicate::SetProperty {
126 key: "injection.include-children",
127 val: None,
128 } => {
129 injection_properties
130 .entry(pattern)
131 .or_default()
132 .include_children = IncludedChildren::All
133 }
134 UserPredicate::SetProperty {
135 key: "injection.language",
136 val: Some(lang),
137 } => injection_properties.entry(pattern).or_default().language = Some(lang.into()),
138 UserPredicate::SetProperty {
139 key: "injection.combined",
140 val: None,
141 } => injection_properties.entry(pattern).or_default().combined = true,
142 predicate => {
143 return Err(InvalidPredicateError::unknown(predicate));
144 }
145 }
146 Ok(())
147 })?;
148 let mut local_query = Query::new(grammar, local_query_text, |pattern, predicate| {
149 match predicate {
150 UserPredicate::SetProperty {
151 key: "local.scope-inherits",
152 val,
153 } => {
154 if val.is_some_and(|val| val != "true") {
155 not_scope_inherits.insert(pattern);
156 }
157 }
158 predicate => {
159 return Err(InvalidPredicateError::unknown(predicate));
160 }
161 }
162 Ok(())
163 })?;
164
165 local_query.disable_capture("local.reference");
168
169 Ok(InjectionsQuery {
170 injection_properties,
171 injection_content_capture: injection_query.get_capture("injection.content"),
172 injection_language_capture: injection_query.get_capture("injection.language"),
173 injection_filename_capture: injection_query.get_capture("injection.filename"),
174 injection_shebang_capture: injection_query.get_capture("injection.shebang"),
175 injection_query,
176 not_scope_inherits,
177 local_scope_capture: local_query.get_capture("local.scope"),
178 local_definition_captures: ArcSwap::from_pointee(HashMap::new()),
179 local_query,
180 })
181 }
182
183 pub(crate) fn configure(&self, f: &mut impl FnMut(&str) -> Option<Highlight>) {
184 let local_definition_captures = self
185 .local_query
186 .captures()
187 .filter_map(|(capture, name)| {
188 let suffix = name.strip_prefix("local.definition.")?;
189 Some((capture, f(suffix)?))
190 })
191 .collect();
192 self.local_definition_captures
193 .store(Arc::new(local_definition_captures));
194 }
195
196 fn process_match<'a, 'tree>(
197 &self,
198 query_match: &QueryMatch<'a, 'tree>,
199 node_idx: MatchedNodeIdx,
200 source: RopeSlice<'a>,
201 loader: impl LanguageLoader,
202 ) -> Option<InjectionQueryMatch<'tree>> {
203 let properties = self.injection_properties.get(&query_match.pattern());
204
205 let mut marker = None;
206 let mut last_content_node = 0;
207 let mut content_nodes = 0;
208 for (i, matched_node) in query_match.matched_nodes().enumerate() {
209 let capture = Some(matched_node.capture);
210 if capture == self.injection_language_capture {
211 let range = matched_node.node.byte_range();
212 marker = Some(InjectionLanguageMarker::Match(
213 source.byte_slice(range.start as usize..range.end as usize),
214 ));
215 } else if capture == self.injection_filename_capture {
216 let range = matched_node.node.byte_range();
217 marker = Some(InjectionLanguageMarker::Filename(
218 source.byte_slice(range.start as usize..range.end as usize),
219 ));
220 } else if capture == self.injection_shebang_capture {
221 let range = matched_node.node.byte_range();
222 let node_slice = source.byte_slice(range.start as usize..range.end as usize);
223
224 let lines = if let Ok(end) = node_slice.try_line_to_byte(2) {
227 node_slice.byte_slice(..end)
228 } else {
229 node_slice
230 };
231
232 marker = SHEBANG_REGEX
233 .captures_iter(regex_cursor::Input::new(lines))
234 .map(|cap| {
235 let cap = lines.byte_slice(cap.get_group(1).unwrap().range());
236 InjectionLanguageMarker::Shebang(cap)
237 })
238 .next()
239 } else if capture == self.injection_content_capture {
240 content_nodes += 1;
241
242 last_content_node = i as u32;
243 }
244 }
245 let marker = marker.or(properties
246 .and_then(|p| p.language.as_deref())
247 .map(InjectionLanguageMarker::Name))?;
248
249 let language = loader.language_for_marker(marker)?;
250 let scope = if properties.is_some_and(|p| p.combined) {
251 Some(InjectionScope::Pattern {
252 pattern: query_match.pattern(),
253 language,
254 })
255 } else if content_nodes != 1 {
256 Some(InjectionScope::Match {
257 id: query_match.id(),
258 })
259 } else {
260 None
261 };
262
263 Some(InjectionQueryMatch {
264 language,
265 scope,
266 include_children: properties.map(|p| p.include_children).unwrap_or_default(),
267 node: query_match.matched_node(node_idx).node.clone(),
268 last_match: last_content_node == node_idx,
269 pattern: query_match.pattern(),
270 })
271 }
272
273 fn execute<'a>(
288 &'a self,
289 node: &Node<'a>,
290 source: RopeSlice<'a>,
291 loader: &'a impl LanguageLoader,
292 ) -> impl Iterator<Item = InjectionQueryMatch<'a>> + 'a {
293 let mut cursor = InactiveQueryCursor::new(0..u32::MAX, TREE_SITTER_MATCH_LIMIT)
294 .execute_query(&self.injection_query, node, source);
295 let injection_content_capture = self.injection_content_capture.unwrap();
296 let iter = iter::from_fn(move || loop {
297 let (query_match, node_idx) = cursor.next_matched_node()?;
298 if query_match.matched_node(node_idx).capture != injection_content_capture {
299 continue;
300 }
301 let Some(mat) = self.process_match(&query_match, node_idx, source, loader) else {
302 query_match.remove();
303 continue;
304 };
305 let range = query_match.matched_node(node_idx).node.byte_range();
306 if mat.last_match {
307 query_match.remove();
308 }
309 if range.is_empty() {
310 continue;
311 }
312 break Some(mat);
313 });
314 let mut buf = Vec::new();
315 let mut iter = iter.peekable();
316 iter::from_fn(move || {
318 if let Some(mat) = buf.pop() {
319 return Some(mat);
320 }
321 let mut res = iter.next()?;
322 if res.include_children == IncludedChildren::None {
326 let mut fast_return = true;
327 while let Some(overlap) =
328 iter.next_if(|mat| mat.node.byte_range() == res.node.byte_range())
329 {
330 if overlap.include_children != IncludedChildren::None {
331 buf.push(overlap);
332 fast_return = false;
333 break;
334 }
335 res = overlap;
337 }
338 if fast_return {
339 return Some(res);
340 }
341 }
342
343 while let Some(overlap) = iter.next_if(|mat| mat.node.end_byte() <= res.node.end_byte())
346 {
347 buf.push(overlap)
348 }
349 if buf.is_empty() {
350 return Some(res);
351 }
352 buf.push(res);
353 buf.sort_unstable_by_key(|mat| (mat.pattern, Reverse(mat.node.start_byte())));
354 buf.pop()
355 })
356 }
357}
358
359impl Syntax {
360 pub(crate) fn run_injection_query(
361 &mut self,
362 layer: Layer,
363 edits: &[tree_sitter::InputEdit],
364 source: RopeSlice<'_>,
365 loader: &impl LanguageLoader,
366 mut parse_layer: impl FnMut(Layer),
367 ) {
368 self.map_injections(layer, None, edits);
369 let layer_data = &mut self.layer_mut(layer);
370 let Some(LanguageConfig {
371 injection_query: ref injections_query,
372 ..
373 }) = loader.get_config(layer_data.language)
374 else {
375 return;
376 };
377 if injections_query.injection_content_capture.is_none() {
378 return;
379 }
380
381 let parent_ranges = take(&mut layer_data.ranges);
383 let parse_tree = layer_data.parse_tree.take().unwrap();
384 let mut injections: Vec<Injection> = Vec::with_capacity(layer_data.injections.len());
385 let mut old_injections = take(&mut layer_data.injections).into_iter().peekable();
386
387 let injection_query = injections_query.execute(&parse_tree.root_node(), source, loader);
388
389 let mut combined_injections: HashMap<InjectionScope, Layer> = HashMap::with_capacity(32);
390 for mat in injection_query {
391 let matched_node_range = mat.node.byte_range();
392 let mut insert_position = injections.len();
393 if let Some(last_injection) = injections
404 .last()
405 .filter(|injection| ranges_intersect(&injection.range, &matched_node_range))
406 {
407 if last_injection.range.start <= matched_node_range.start {
410 continue;
411 } else {
412 insert_position = injections.partition_point(|injection| {
413 injection.range.end <= matched_node_range.start
414 });
415 if injections[insert_position].range.start < matched_node_range.end {
416 continue;
417 }
418 }
419 }
420
421 let language = mat.language;
422 let reused_injection =
423 self.reuse_injection(language, matched_node_range.clone(), &mut old_injections);
424 let layer = match mat.scope {
425 Some(scope @ InjectionScope::Match { .. }) if mat.last_match => {
426 combined_injections.remove(&scope).unwrap_or_else(|| {
427 self.init_injection(layer, mat.language, reused_injection.clone())
428 })
429 }
430 Some(scope) => *combined_injections.entry(scope).or_insert_with(|| {
431 self.init_injection(layer, mat.language, reused_injection.clone())
432 }),
433 None => self.init_injection(layer, mat.language, reused_injection.clone()),
434 };
435 let mut layer_data = self.layer_mut(layer);
436 if !layer_data.flags.touched {
437 layer_data.flags.touched = true;
438 parse_layer(layer)
439 }
440 if layer_data.flags.reused {
441 layer_data.flags.modified |= reused_injection.as_ref().map_or(true, |injection| {
442 injection.matched_node_range != matched_node_range || injection.layer != layer
443 });
444 } else if let Some(reused_injection) = reused_injection {
445 layer_data.flags.reused = true;
446 layer_data.flags.modified = true;
447 let reused_parse_tree = self.layer(reused_injection.layer).tree().cloned();
448 layer_data = self.layer_mut(layer);
449 layer_data.parse_tree = reused_parse_tree;
450 }
451
452 let old_len = injections.len();
453 intersect_ranges(mat.include_children, mat.node, &parent_ranges, |range| {
454 layer_data.ranges.push(tree_sitter::Range {
455 start_point: tree_sitter::Point::ZERO,
456 end_point: tree_sitter::Point::ZERO,
457 start_byte: range.start,
458 end_byte: range.end,
459 });
460 injections.push(Injection {
461 range,
462 layer,
463 matched_node_range: matched_node_range.clone(),
464 });
465 });
466 if old_len != insert_position {
467 let inserted = injections.len() - old_len;
468 injections[insert_position..].rotate_right(inserted);
469 layer_data.ranges[insert_position..].rotate_right(inserted);
470 }
471 }
472
473 for old_injection in old_injections {
477 self.layer_mut(old_injection.layer).flags.modified = true;
478 }
479
480 let layer_data = &mut self.layer_mut(layer);
481 layer_data.ranges = parent_ranges;
482 layer_data.parse_tree = Some(parse_tree);
483 layer_data.injections = injections;
484 }
485
486 fn map_injections(
488 &mut self,
489 layer: Layer,
490 offset: Option<i32>,
492 mut edits: &[tree_sitter::InputEdit],
493 ) {
494 if edits.is_empty() && offset.unwrap_or(0) == 0 {
495 return;
496 }
497 let layer_data = self.layer_mut(layer);
498 let first_relevant_injection = layer_data
499 .injections
500 .partition_point(|injection| injection.range.end < edits[0].start_byte);
501 if first_relevant_injection == layer_data.injections.len() {
502 return;
503 }
504 let mut offset = if let Some(offset) = offset {
505 let first_relevant_edit = edits.partition_point(|edit| {
506 (edit.old_end_byte as i32) < (layer_data.ranges[0].end_byte as i32 - offset)
507 });
508 edits = &edits[first_relevant_edit..];
509 offset
510 } else {
511 0
512 };
513 let mut edits = edits.iter().peekable();
516 let mut injections = take(&mut layer_data.injections);
517 for injection in &mut injections[first_relevant_injection..] {
518 let injection_range = &mut injection.range;
519 let matched_node_range = &mut injection.matched_node_range;
520 let flags = &mut self.layer_mut(injection.layer).flags;
521
522 debug_assert!(matched_node_range.start <= injection_range.start);
523 debug_assert!(matched_node_range.end >= injection_range.end);
524
525 while let Some(edit) =
526 edits.next_if(|edit| edit.old_end_byte < matched_node_range.start)
527 {
528 offset += edit.offset();
529 }
530 let mut mapped_node_range_start = (matched_node_range.start as i32 + offset) as u32;
531 if let Some(edit) = edits
532 .peek()
533 .filter(|edit| edit.start_byte <= matched_node_range.start)
534 {
535 mapped_node_range_start = (edit.new_end_byte as i32 + offset) as u32;
536 }
537 while let Some(edit) = edits.next_if(|edit| edit.old_end_byte < injection_range.start) {
538 offset += edit.offset();
539 }
540 flags.moved = offset != 0;
541 let mut mapped_start = (injection_range.start as i32 + offset) as u32;
542 if let Some(edit) = edits.next_if(|edit| edit.old_end_byte <= injection_range.end) {
543 if edit.start_byte < injection_range.start {
544 flags.moved = true;
545 mapped_start = (edit.new_end_byte as i32 + offset) as u32;
546 } else {
547 flags.modified = true;
548 }
549 offset += edit.offset();
550 while let Some(edit) =
551 edits.next_if(|edit| edit.old_end_byte <= injection_range.end)
552 {
553 offset += edit.offset();
554 }
555 }
556 let mut mapped_end = (injection_range.end as i32 + offset) as u32;
557 if let Some(edit) = edits
558 .peek()
559 .filter(|edit| edit.start_byte <= injection_range.end)
560 {
561 flags.modified = true;
562
563 if edit.start_byte < injection_range.start {
564 mapped_start = (edit.new_end_byte as i32 + offset) as u32;
565 mapped_end = mapped_start;
566 }
567 }
568 let mut mapped_node_range_end = (matched_node_range.end as i32 + offset) as u32;
569 if let Some(edit) = edits
570 .peek()
571 .filter(|edit| edit.start_byte <= matched_node_range.end)
572 {
573 if edit.start_byte < matched_node_range.start {
574 mapped_node_range_start = (edit.new_end_byte as i32 + offset) as u32;
575 mapped_node_range_end = mapped_node_range_start;
576 }
577 }
578 *injection_range = mapped_start..mapped_end;
579 *matched_node_range = mapped_node_range_start..mapped_node_range_end;
580 }
581 self.layer_mut(layer).injections = injections;
582 }
583
584 fn init_injection(
585 &mut self,
586 parent: Layer,
587 language: Language,
588 reuse: Option<Injection>,
589 ) -> Layer {
590 match reuse {
591 Some(old_injection) => {
592 let layer_data = self.layer_mut(old_injection.layer);
593 debug_assert_eq!(layer_data.parent, Some(parent));
594 layer_data.flags.reused = true;
595 layer_data.ranges.clear();
596 old_injection.layer
597 }
598 None => {
599 let layer = self.layers.insert(LayerData {
600 language,
601 parse_tree: None,
602 ranges: Vec::new(),
603 injections: Vec::new(),
604 flags: LayerUpdateFlags::default(),
605 parent: Some(parent),
606 locals: Locals::default(),
607 });
608 Layer(layer as u32)
609 }
610 }
611 }
612
613 fn reuse_injection(
615 &mut self,
616 language: Language,
617 new_range: Range,
618 injections: &mut Peekable<impl Iterator<Item = Injection>>,
619 ) -> Option<Injection> {
620 while let Some(skipped) =
621 injections.next_if(|injection| injection.range.end <= new_range.start)
622 {
623 self.layer_mut(skipped.layer).flags.modified = true;
629 }
630 injections
631 .next_if(|injection| {
632 injection.range.start < new_range.end
633 && self.layer(injection.layer).language == language
634 && !self.layer(injection.layer).flags.reused
635 })
636 .clone()
637 }
638}
639
640fn intersect_ranges(
641 include_children: IncludedChildren,
642 node: Node,
643 parent_ranges: &[tree_sitter::Range],
644 push_range: impl FnMut(Range),
645) {
646 let range = node.byte_range();
647 let i = parent_ranges.partition_point(|parent_range| parent_range.end_byte <= range.start);
648 let parent_ranges = parent_ranges[i..]
649 .iter()
650 .map(|range| range.start_byte..range.end_byte);
651 match include_children {
652 IncludedChildren::None => intersect_ranges_impl(
653 range,
654 node.children().map(|node| node.byte_range()),
655 parent_ranges,
656 push_range,
657 ),
658 IncludedChildren::All => {
659 intersect_ranges_impl(range, [].into_iter(), parent_ranges, push_range)
660 }
661 IncludedChildren::Unnamed => intersect_ranges_impl(
662 range,
663 node.children()
664 .filter(|node| node.is_named())
665 .map(|node| node.byte_range()),
666 parent_ranges,
667 push_range,
668 ),
669 }
670}
671
/// Walks `range` from its start and calls `push_range` for each non-empty piece that is
/// neither inside one of the `excluded_ranges` nor in a gap between `parent_ranges`.
///
/// Empty excluded ranges are ignored. Both iterators are consumed front to back, so they are
/// expected to be ordered by position.
///
/// # Panics
///
/// Panics if `parent_ranges` yields no item.
fn intersect_ranges_impl(
    range: Range,
    excluded_ranges: impl Iterator<Item = Range>,
    parent_ranges: impl Iterator<Item = Range>,
    mut push_range: impl FnMut(Range),
) {
    let mut holes = excluded_ranges.filter(|hole| !hole.is_empty()).peekable();
    let mut parents = parent_ranges.peekable();
    // Start of the piece that is currently being built.
    let mut cursor = range.start;
    loop {
        let parent = parents.peek().unwrap().clone();
        match holes.next_if(|hole| hole.start <= parent.end) {
            // The next hole starts no later than the end of the current parent range.
            Some(hole) => {
                if hole.start >= range.end {
                    break;
                }
                if cursor != hole.start {
                    push_range(cursor..hole.start)
                }
                cursor = hole.end;
            }
            // No such hole is left: finish the current parent range.
            None => {
                parents.next();
                if parent.end >= range.end {
                    break;
                }
                if cursor != parent.end {
                    push_range(cursor..parent.end)
                }
                match parents.peek() {
                    Some(next_parent) => cursor = next_parent.start,
                    // The parent ranges are exhausted before `range` ends; nothing is left.
                    None => return,
                }
            }
        }
    }
    if cursor != range.end {
        push_range(cursor..range.end)
    }
}
711
/// Returns whether `a` and `b` intersect.
///
/// Two ranges with the same start always count as intersecting, even when one of them (or
/// both) is empty. Otherwise each range has to start before the other one ends.
fn ranges_intersect(a: &Range, b: &Range) -> bool {
    if a.start == b.start {
        return true;
    }
    b.start < a.end && a.start < b.end
}