ts_macros/
lib.rs

use std::collections::{hash_map::Entry, HashMap};

use heck::ToUpperCamelCase;
use proc_macro2::{Delimiter, Literal, TokenStream, TokenTree};
use proc_macro2_diagnostics::SpanDiagnosticExt;
use quote::{quote, quote_spanned, ToTokens};
use syn::{parse::Parse, punctuated::Punctuated, *};

/// Define a [tree-sitter query], optionally extracting its named captures into an enum.
///
/// *Usage:*
/// ```rust,noplayground
/// ts_macros::query! {
///     MyQuery(Foo, Bar);
/// (function_definition
///  (parameters . (string) @FOO)
///  (block
///    (expression_statement
///      (call
///        (_) @callee
///        (parameters . (string) @BAR)))))
/// };
/// ```
///
/// Generates:
/// ```rust,noplayground
/// pub enum MyQuery {
///     Foo = 0,
///     Bar = 2,
/// }
/// impl MyQuery {
///     pub fn query() -> &'static Query;
///     pub fn from(raw: u32) -> Option<Self>;
/// }
/// ```
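///
/// The raw capture index reported by tree-sitter maps back to a variant through
/// `from`. A minimal sketch, assuming a `capture: tree_sitter::QueryCapture`
/// obtained from a `QueryCursor` elsewhere:
/// ```rust,noplayground
/// match MyQuery::from(capture.index) {
///     Some(MyQuery::Foo) => { /* first string parameter of the definition */ }
///     Some(MyQuery::Bar) => { /* first string parameter of the call */ }
///     _ => {}
/// }
/// ```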
/// [tree-sitter query]: https://tree-sitter.github.io/tree-sitter/using-parsers#query-syntax
#[proc_macro]
pub fn query(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let def = syn::parse_macro_input!(tokens as QueryDefinition);
	def.into_tokens(TsLang::Python).into()
}

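/// Deprecated variant of `query!` that targets the JavaScript grammar instead of Python.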
#[proc_macro]
#[deprecated = "Use `query` with #[lang = \"..\"] instead"]
pub fn query_js(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let def = syn::parse_macro_input!(tokens as QueryDefinition);
	def.into_tokens(TsLang::Javascript).into()
}

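/// Parsed form of the macro input: optional `#[..]` attributes (with `#[lang = ".."]`
/// extracted separately), the enum name, the captures to expose, and the raw query tokens.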
struct QueryDefinition {
	meta: Vec<TokenStream>,
	lang: Option<syn::LitStr>,
	name: syn::Ident,
	captures: Punctuated<Ident, Token![,]>,
	query: TokenStream,
}

mod kw {
	syn::custom_keyword!(lang);
}

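// Accepted shape: `#[meta]* Name(Capture, ...); <query tokens>`, where a
// `#[lang = ".."]` attribute selects the grammar crate and every other
// attribute is forwarded onto the generated enum.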
impl Parse for QueryDefinition {
	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
		let mut meta = vec![];
		let mut lang = None;
		while input.peek(Token![#]) {
			input.parse::<Token![#]>()?;
			let content;
			bracketed!(content in input);
			if content.peek(kw::lang) {
				content.parse::<kw::lang>()?;
				content.parse::<Token![=]>()?;
				lang = Some(content.parse()?);
				continue;
			}
			meta.push(content.parse()?);
		}
		let name = input.parse()?;
		let contents;
		parenthesized!(contents in input);
		let captures = Punctuated::parse_terminated(&contents)?;
		input.parse::<Token![;]>()?;
		let template = input.parse()?;
		Ok(Self {
			name,
			meta,
			captures,
			query: template,
			lang,
		})
	}
}

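/// The grammar crate whose `LANGUAGE` constant the query is compiled against.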
enum TsLang {
	Python,
	Javascript,
	Custom(syn::LitStr),
}

impl ToTokens for TsLang {
	fn to_tokens(&self, tokens: &mut TokenStream) {
		match self {
			Self::Python => tokens.extend(quote!(::tree_sitter_python)),
			Self::Javascript => tokens.extend(quote!(::tree_sitter_javascript)),
			Self::Custom(lang) => match syn::parse_str::<syn::Path>(&lang.value()) {
				Ok(lang) => tokens.extend(quote!(#lang)),
				Err(err) => {
					let report = err.to_compile_error();
					tokens.extend(quote_spanned!(lang.span() => #report))
				}
			},
		}
	}
}

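/// Re-serializes the query's Rust tokens into the S-expression text tree-sitter
/// expects, taking care not to insert spaces inside `@capture` names or
/// `#predicate?` markers.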
fn tokens_to_string(tokens: TokenStream, output: &mut String) {
	let mut tokens = tokens.into_iter().peekable();
	while let Some(token) = tokens.next() {
		match token {
			TokenTree::Group(group) => {
				let (lhs, rhs) = match group.delimiter() {
					Delimiter::Parenthesis => ("(", ")"),
					Delimiter::Brace => ("{", "}"),
					Delimiter::Bracket => ("[", "]"),
					Delimiter::None => (" ", " "),
				};
				output.push_str(lhs);
				tokens_to_string(group.stream(), output);
				output.push_str(rhs);
			}
			// If an identifier follows one of these punctuation marks, it must be appended without a separating space.
			TokenTree::Punct(punct) if matches!(punct.as_char(), '@' | '#') => {
				output.push(' ');
				output.push(punct.as_char());
				let Some(TokenTree::Ident(ident)) = tokens.peek() else {
					continue;
				};
				output.push_str(&ident.to_string());
				tokens.next();
				let mut ident_allowed = false;
				loop {
					// A dash is part of a valid Scheme identifier, so it allows at most one more identifier.
					// Any other punctuation (usually ! or ?) marks the end of the identifier.
					match tokens.peek() {
						Some(TokenTree::Punct(punct)) => {
							let punct = punct.as_char();
							output.push(punct);
							tokens.next();
							if punct != '-' {
								break;
							}
							ident_allowed = true;
						}
						Some(TokenTree::Ident(ident)) if ident_allowed => {
							output.push_str(&ident.to_string());
							tokens.next();
							ident_allowed = false;
						}
						_ => break,
					}
				}
			}
			_ => {
				output.push(' ');
				output.push_str(&token.to_string());
			}
		}
	}
}

impl QueryDefinition {
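	/// Expands the parsed definition into the enum, its `from`/`query` impl, and
	/// compile-time diagnostics for any requested capture missing from the query.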
	fn into_tokens(self, language: TsLang) -> TokenStream {
		let language = self.lang.map(TsLang::Custom).unwrap_or(language);
		let mut captures = HashMap::new();
		let mut diagnostics = Vec::new();
		let mut index = 0u32;
		let mut tokens = self.query.clone().into_iter();
		let mut expect_capture = false;
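		// Scan the query and assign each distinct capture name an index in order of
		// first appearance, mirroring the numbering tree-sitter gives its captures.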
		while let Some(token) = tokens.next() {
			match token {
				TokenTree::Punct(punct) if punct.as_char() == '@' => {
					expect_capture = true;
				}
				TokenTree::Ident(capture) if expect_capture => {
					expect_capture = false;
					let capture = quote!(#capture).to_string();
					let key = if capture.starts_with('_') {
						capture
					} else {
						capture.to_upper_camel_case()
					};
					if let Entry::Vacant(entry) = captures.entry(key) {
						entry.insert(index);
						index += 1;
					}
				}
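				// Descend into groups by splicing their tokens ahead of the rest of the stream.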
				TokenTree::Group(group) => {
					tokens = group
						.stream()
						.into_iter()
						.chain(tokens)
						.collect::<TokenStream>()
						.into_iter();
				}
				_ => {}
			}
		}
		let mut cases = vec![];
		let mut variants = vec![];
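		// One enum variant and one `from` arm per requested capture; names that never
		// appear in the query become span-targeted error diagnostics.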
		for capture in self.captures.iter() {
			if let Some(index) = captures.get(capture.to_string().as_str()) {
				let index = Literal::usize_unsuffixed(*index as _);
				cases.push(quote_spanned!(capture.span() => #index => Some(Self::#capture),));
				variants.push(quote_spanned!(capture.span()=> #capture = #index,));
			} else {
				diagnostics.push(capture.span().error(format!("No capture '{capture}' found in query")));
			}
		}
		let name = self.name;
		let mut query = String::new();
		tokens_to_string(self.query, &mut query);
		let meta = self.meta;
		let diagnostics = diagnostics.into_iter().map(|diag| diag.emit_as_item_tokens());
		quote_spanned!(name.span()=>
			#(#[#meta])*
			pub enum #name {
				#(#variants)*
			}

			#[allow(dead_code)]
			impl #name {
				#[inline]
				pub fn from(raw: u32) -> Option<Self> {
					match raw {
						#(#cases)*
						_ => None,
					}
				}
				pub fn query() -> &'static ::tree_sitter::Query {
					use ::std::sync::OnceLock as _OnceLock;
					static QUERY: _OnceLock<::tree_sitter::Query> = _OnceLock::new();
					QUERY.get_or_init(|| {
						::tree_sitter::Query::new(&#language::LANGUAGE.into(), #query).unwrap()
					})
				}
			}
			#(#diagnostics)*
		)
	}
}