thread_ast_engine/tree_sitter/
mod.rs1pub mod traversal;
69
70use crate::node::Root;
71
72use crate::AstGrep;
73#[cfg(feature = "matching")]
74use crate::Matcher;
75#[cfg(feature = "matching")]
76use crate::replacer::Replacer;
77use crate::source::{Content, Doc, Edit, SgNode};
78use crate::{Language, Position, node::KindId};
79use std::borrow::Cow;
80use std::num::NonZero;
81use thiserror::Error;
82#[cfg(feature = "matching")]
83use thread_utilities::RapidMap;
84pub use traversal::{TsPre, Visitor};
85pub use tree_sitter::Language as TSLanguage;
86use tree_sitter::{InputEdit, LanguageError, Node, Parser, Point, Tree};
87pub use tree_sitter::{Point as TSPoint, Range as TSRange};
88
/// Errors produced while turning source text into a tree-sitter [`Tree`].
#[derive(Debug, Error)]
pub enum TSParseError {
    /// The parser refused the supplied language; wraps the [`LanguageError`]
    /// reported when the language is assigned to the `Parser`.
    #[error("incompatible `Language` is assigned to a `Parser`.")]
    Language(#[from] LanguageError),

    /// The parse callback completed without producing a tree.
    #[error("general error when tree-sitter fails to parse.")]
    TreeUnavailable,
}
117
118#[inline]
119fn parse_lang(
120 parse_fn: impl Fn(&mut Parser) -> Option<Tree>,
121 ts_lang: &TSLanguage,
122) -> Result<Tree, TSParseError> {
123 let mut parser = Parser::new();
124 parser.set_language(ts_lang)?;
125 if let Some(tree) = parse_fn(&mut parser) {
126 Ok(tree)
127 } else {
128 Err(TSParseError::TreeUnavailable)
129 }
130}
131
/// A document backed by an owned `String`, paired with the language used to
/// parse it and the tree-sitter tree produced from that text.
#[derive(Clone, Debug)]
pub struct StrDoc<L: LanguageExt> {
    /// The source text the tree was parsed from.
    pub src: String,
    /// The language implementation that supplies the tree-sitter grammar.
    pub lang: L,
    /// The tree-sitter syntax tree for `src`.
    pub tree: Tree,
}
169
170impl<L: LanguageExt> StrDoc<L> {
171 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
172 let src = src.to_string();
173 let ts_lang = lang.get_ts_language();
174 let tree =
175 parse_lang(|p| p.parse(src.as_bytes(), None), &ts_lang).map_err(|e| e.to_string())?;
176 Ok(Self { src, lang, tree })
177 }
178 pub fn new(src: &str, lang: L) -> Self {
179 Self::try_new(src, lang).expect("Parser tree error")
180 }
181 fn parse(&self, old_tree: Option<&Tree>) -> Result<Tree, TSParseError> {
182 let source = self.get_source();
183 let lang = self.get_lang().get_ts_language();
184 parse_lang(|p| p.parse(source.as_bytes(), old_tree), &lang)
185 }
186}
187
188impl<L: LanguageExt> Doc for StrDoc<L> {
189 type Source = String;
190 type Lang = L;
191 type Node<'r> = Node<'r>;
192 fn get_lang(&self) -> &Self::Lang {
193 &self.lang
194 }
195 fn get_source(&self) -> &Self::Source {
196 &self.src
197 }
198 fn do_edit(&mut self, edit: &Edit<Self::Source>) -> Result<(), String> {
199 let source = &mut self.src;
200 perform_edit(&mut self.tree, source, edit);
201 self.tree = self.parse(Some(&self.tree)).map_err(|e| e.to_string())?;
202 Ok(())
203 }
204 fn root_node(&self) -> Node<'_> {
205 self.tree.root_node()
206 }
207 fn get_node_text<'a>(&'a self, node: &Self::Node<'a>) -> Cow<'a, str> {
208 Cow::Borrowed(
209 node.utf8_text(self.src.as_bytes())
210 .expect("invalid source text encoding"),
211 )
212 }
213}
214
215struct NodeWalker<'tree> {
216 cursor: tree_sitter::TreeCursor<'tree>,
217 count: usize,
218}
219
220impl<'tree> Iterator for NodeWalker<'tree> {
221 type Item = Node<'tree>;
222 fn next(&mut self) -> Option<Self::Item> {
223 if self.count == 0 {
224 return None;
225 }
226 let ret = Some(self.cursor.node());
227 self.cursor.goto_next_sibling();
228 self.count -= 1;
229 ret
230 }
231}
232
233impl ExactSizeIterator for NodeWalker<'_> {
234 fn len(&self) -> usize {
235 self.count
236 }
237}
238
239impl<'r> SgNode<'r> for Node<'r> {
240 fn parent(&self) -> Option<Self> {
241 Node::parent(self)
242 }
243 fn ancestors(&self, root: Self) -> impl Iterator<Item = Self> {
244 let mut ancestor = Some(root);
245 let self_id = self.id();
246 std::iter::from_fn(move || {
247 let inner = ancestor.take()?;
248 if inner.id() == self_id {
249 return None;
250 }
251 ancestor = inner.child_with_descendant(*self);
252 Some(inner)
253 })
254 .collect::<Vec<_>>()
256 .into_iter()
257 .rev()
258 }
259 fn dfs(&self) -> impl Iterator<Item = Self> {
260 TsPre::new(self)
261 }
262 fn child(&self, nth: usize) -> Option<Self> {
263 Node::child(self, nth)
265 }
266 fn children(&self) -> impl ExactSizeIterator<Item = Self> {
267 let mut cursor = self.walk();
268 cursor.goto_first_child();
269 NodeWalker {
270 cursor,
271 count: self.child_count(),
272 }
273 }
274 fn child_by_field_id(&self, field_id: u16) -> Option<Self> {
275 Node::child_by_field_id(self, field_id)
276 }
277 fn next(&self) -> Option<Self> {
278 self.next_sibling()
279 }
280 fn prev(&self) -> Option<Self> {
281 self.prev_sibling()
282 }
283 fn next_all(&self) -> impl Iterator<Item = Self> {
284 let node = self.parent().unwrap_or(*self);
286 let mut cursor = node.walk();
287 cursor.goto_first_child_for_byte(self.start_byte());
288 std::iter::from_fn(move || {
289 if cursor.goto_next_sibling() {
290 Some(cursor.node())
291 } else {
292 None
293 }
294 })
295 }
296 fn prev_all(&self) -> impl Iterator<Item = Self> {
297 let node = self.parent().unwrap_or(*self);
299 let mut cursor = node.walk();
300 cursor.goto_first_child_for_byte(self.start_byte());
301 std::iter::from_fn(move || {
302 if cursor.goto_previous_sibling() {
303 Some(cursor.node())
304 } else {
305 None
306 }
307 })
308 }
309 fn is_named(&self) -> bool {
310 Node::is_named(self)
311 }
312 fn is_named_leaf(&self) -> bool {
315 self.named_child_count() == 0
316 }
317 fn is_leaf(&self) -> bool {
318 self.child_count() == 0
319 }
320 fn kind(&self) -> Cow<'_, str> {
321 Cow::Borrowed(Node::kind(self))
322 }
323 fn kind_id(&self) -> KindId {
324 Node::kind_id(self)
325 }
326 fn node_id(&self) -> usize {
327 self.id()
328 }
329 fn range(&self) -> std::ops::Range<usize> {
330 self.start_byte()..self.end_byte()
331 }
332 fn start_pos(&self) -> Position {
333 let pos = self.start_position();
334 let byte = self.start_byte();
335 Position::new(pos.row, pos.column, byte)
336 }
337 fn end_pos(&self) -> Position {
338 let pos = self.end_position();
339 let byte = self.end_byte();
340 Position::new(pos.row, pos.column, byte)
341 }
342 fn is_missing(&self) -> bool {
344 Node::is_missing(self)
345 }
346 fn is_error(&self) -> bool {
347 Node::is_error(self)
348 }
349
350 fn field(&self, name: &str) -> Option<Self> {
351 self.child_by_field_name(name)
352 }
353 fn field_children(&self, field_id: Option<u16>) -> impl Iterator<Item = Self> {
354 let field_id = field_id.and_then(NonZero::new);
355 let mut cursor = self.walk();
356 cursor.goto_first_child();
357 let mut done = field_id.is_none();
359
360 std::iter::from_fn(move || {
361 if done {
362 return None;
363 }
364 while cursor.field_id() != field_id {
365 if !cursor.goto_next_sibling() {
366 return None;
367 }
368 }
369 let ret = cursor.node();
370 if !cursor.goto_next_sibling() {
371 done = true;
372 }
373 Some(ret)
374 })
375 }
376}
377
378pub fn perform_edit<S: ContentExt>(tree: &mut Tree, input: &mut S, edit: &Edit<S>) -> InputEdit {
379 let edit = input.accept_edit(edit);
380 tree.edit(&edit);
381 edit
382}
383
/// Extension of [`Language`] for languages backed by a tree-sitter grammar.
pub trait LanguageExt: Language {
    /// Parses `source` with a clone of this language and returns the
    /// resulting [`AstGrep`] root.
    ///
    /// # Panics
    ///
    /// Panics when the source cannot be parsed (see `AstGrep::new`).
    fn ast_grep<S: AsRef<str>>(&self, source: S) -> AstGrep<StrDoc<Self>> {
        AstGrep::new(source, self.clone())
    }

    /// Returns the tree-sitter language used to configure the parser.
    fn get_ts_language(&self) -> TSLanguage;

    /// Returns a static list of names, or `None` by default.
    // NOTE(review): presumably the names of languages that can be embedded in
    // this one — confirm against implementors.
    fn injectable_languages(&self) -> Option<&'static [&'static str]> {
        None
    }

    /// Returns a map from a language name to the source ranges to parse with
    /// that language. The default implementation returns an empty map.
    #[cfg(feature = "matching")]
    fn extract_injections<L: LanguageExt>(
        &self,
        _root: crate::Node<StrDoc<L>>,
    ) -> RapidMap<String, Vec<TSRange>> {
        RapidMap::default()
    }
}
488
489fn position_for_offset(input: &[u8], offset: usize) -> Point {
490 debug_assert!(offset <= input.len());
491 let (mut row, mut col) = (0, 0);
492 for c in &input[0..offset] {
493 if *c as char == '\n' {
494 row += 1;
495 col = 0;
496 } else {
497 col += 1;
498 }
499 }
500 Point::new(row, col)
501}
502
503impl<L: LanguageExt> AstGrep<StrDoc<L>> {
504 pub fn new<S: AsRef<str>>(src: S, lang: L) -> Self {
505 Self::str(src.as_ref(), lang)
506 }
507
508 pub fn source(&self) -> &str {
509 self.doc.get_source().as_str()
510 }
511
512 pub fn generate(self) -> String {
513 self.doc.src
514 }
515}
516
/// Extension of [`Content`] for sources that can apply an [`Edit`] to
/// themselves and describe it to tree-sitter.
pub trait ContentExt: Content {
    /// Applies `edit` to `self` in place and returns the matching
    /// [`InputEdit`].
    fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit;
}
520impl ContentExt for String {
521 fn accept_edit(&mut self, edit: &Edit<Self>) -> InputEdit {
522 let start_byte = edit.position;
523 let old_end_byte = edit.position + edit.deleted_length;
524 let new_end_byte = edit.position + edit.inserted_text.len();
525 let input = unsafe { self.as_mut_vec() };
526 let start_position = position_for_offset(input, start_byte);
527 let old_end_position = position_for_offset(input, old_end_byte);
528 input.splice(start_byte..old_end_byte, edit.inserted_text.clone());
529 let new_end_position = position_for_offset(input, new_end_byte);
530 InputEdit {
531 start_byte,
532 old_end_byte,
533 new_end_byte,
534 start_position,
535 old_end_position,
536 new_end_position,
537 }
538 }
539}
540
541impl<L: LanguageExt> Root<StrDoc<L>> {
542 pub fn str(src: &str, lang: L) -> Self {
543 Self::try_new(src, lang).expect("should parse")
544 }
545 pub fn try_new(src: &str, lang: L) -> Result<Self, String> {
546 let doc = StrDoc::try_new(src, lang)?;
547 Ok(Self { doc })
548 }
549 pub fn get_text(&self) -> &str {
550 &self.doc.src
551 }
552 #[cfg(feature = "matching")]
553 pub fn get_injections<F: Fn(&str) -> Option<L>>(&self, get_lang: F) -> Vec<Self> {
554 let root = self.root();
555 let range = self.lang().extract_injections(root);
556 range
557 .into_iter()
558 .filter_map(|(lang, ranges)| {
559 let lang = get_lang(&lang)?;
560 let source = self.doc.get_source();
561 let mut parser = Parser::new();
562 parser.set_included_ranges(&ranges).ok()?;
563 parser.set_language(&lang.get_ts_language()).ok()?;
564 let tree = parser.parse(source, None)?;
565 Some(Self {
566 doc: StrDoc {
567 src: self.doc.src.clone(),
568 lang,
569 tree,
570 },
571 })
572 })
573 .collect()
574 }
575}
576
/// A matched node's text together with the surrounding source text.
pub struct DisplayContext<'r> {
    /// Text of the matched node.
    pub matched: Cow<'r, str>,
    /// Source text immediately preceding the match.
    pub leading: &'r str,
    /// Source text immediately following the match.
    pub trailing: &'r str,
    /// Line number of the first line included in `leading`.
    // NOTE(review): uses the same numbering base as `Position::line` — confirm
    // whether that is zero- or one-based.
    pub start_line: usize,
}
587
588impl<'r, L: LanguageExt> crate::Node<'r, StrDoc<L>> {
590 #[doc(hidden)]
591 #[must_use]
592 pub fn display_context(&self, before: usize, after: usize) -> DisplayContext<'r> {
593 let source = self.root.doc.get_source().as_str();
594 let bytes = source.as_bytes();
595 let start = self.inner.start_byte();
596 let end = self.inner.end_byte();
597 let (mut leading, mut trailing) = (start, end);
598 let mut lines_before = before + 1;
599 while leading > 0 {
600 if bytes[leading - 1] == b'\n' {
601 lines_before -= 1;
602 if lines_before == 0 {
603 break;
604 }
605 }
606 leading -= 1;
607 }
608 let mut lines_after = after + 1;
609 trailing = trailing.min(bytes.len());
611 while trailing < bytes.len() {
612 if bytes[trailing] == b'\n' {
613 lines_after -= 1;
614 if lines_after == 0 {
615 break;
616 }
617 }
618 trailing += 1;
619 }
620 let offset = if lines_before == 0 {
622 before
623 } else {
624 before + 1 - lines_before
626 };
627 DisplayContext {
628 matched: self.text(),
629 leading: &source[leading..start],
630 trailing: &source[end..trailing],
631 start_line: self.start_pos().line() - offset,
632 }
633 }
634
635 #[cfg(feature = "matching")]
636 pub fn replace_all<M: Matcher, R: Replacer<StrDoc<L>>>(
637 &self,
638 matcher: M,
639 replacer: R,
640 ) -> Vec<Edit<String>> {
641 Visitor::new(&matcher)
643 .reentrant(false)
644 .visit(self.clone())
645 .map(|matched| matched.make_edit(&matcher, &replacer))
646 .collect()
647 }
648}
649
#[cfg(test)]
mod test {
    use super::*;
    use crate::language::Tsx;
    use tree_sitter::Point;

    /// Parses `src` with the TSX grammar.
    fn parse(src: &str) -> Result<Tree, TSParseError> {
        let tsx = Tsx.get_ts_language();
        parse_lang(|parser| parser.parse(src, None), &tsx)
    }

    /// Parses `src` and renders the root node as an s-expression.
    fn sexp_of(src: &str) -> Result<String, TSParseError> {
        let tree = parse(src)?;
        Ok(tree.root_node().to_sexp())
    }

    #[test]
    fn test_tree_sitter() -> Result<(), TSParseError> {
        let tree = parse("var a = 1234")?;
        let program = tree.root_node();
        assert_eq!(program.kind(), "program");
        assert_eq!(program.start_position().column, 0);
        assert_eq!(program.end_position().column, 12);
        assert_eq!(
            program.to_sexp(),
            "(program (variable_declaration (variable_declarator name: (identifier) value: (number))))"
        );
        Ok(())
    }

    #[test]
    fn test_object_literal() -> Result<(), TSParseError> {
        assert_eq!(
            sexp_of("{a: $X}")?,
            "(program (expression_statement (object (pair key: (property_identifier) value: (identifier)))))"
        );
        Ok(())
    }

    #[test]
    fn test_string() -> Result<(), TSParseError> {
        assert_eq!(
            sexp_of("'$A'")?,
            "(program (expression_statement (string (string_fragment))))"
        );
        Ok(())
    }

    #[test]
    fn test_row_col() -> Result<(), TSParseError> {
        let tree = parse("😄")?;
        let program = tree.root_node();
        assert_eq!(program.start_position(), Point::new(0, 0));
        // Columns count bytes: the emoji occupies four of them.
        assert_eq!(program.end_position(), Point::new(0, 4));
        Ok(())
    }

    #[test]
    fn test_edit() -> Result<(), TSParseError> {
        let mut src = String::from("a + b");
        let mut tree = parse(&src)?;
        let insertion: Edit<String> = Edit {
            position: 1,
            deleted_length: 0,
            inserted_text: " * b".into(),
        };
        let _ = perform_edit(&mut tree, &mut src, &insertion);
        let tsx = Tsx.get_ts_language();
        let reparsed = parse_lang(|parser| parser.parse(&src, Some(&tree)), &tsx)?;
        // The edited tree keeps its old shape until it is re-parsed.
        assert_eq!(
            tree.root_node().to_sexp(),
            "(program (expression_statement (binary_expression left: (identifier) right: (identifier))))"
        );
        assert_eq!(
            reparsed.root_node().to_sexp(),
            "(program (expression_statement (binary_expression left: (binary_expression left: (identifier) right: (identifier)) right: (identifier))))"
        );
        Ok(())
    }
}