use std::collections::{hash_map::Entry, HashMap};

use heck::ToUpperCamelCase;
use proc_macro2::{Delimiter, Literal, TokenStream, TokenTree};
use proc_macro2_diagnostics::SpanDiagnosticExt;
use quote::{quote, quote_spanned, ToTokens};
use syn::{parse::Parse, punctuated::Punctuated, *};

#[proc_macro]
pub fn query(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
    // Parse the invocation and expand it, defaulting to the Python grammar
    // when no `#[lang = "..."]` attribute is given.
    let def = syn::parse_macro_input!(tokens as QueryDefinition);
    def.into_tokens(TsLang::Python).into()
}

#[proc_macro]
#[deprecated = "Use `query` with #[lang = \"..\"] instead"]
pub fn query_js(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let def = syn::parse_macro_input!(tokens as QueryDefinition);
    def.into_tokens(TsLang::Javascript).into()
}
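
// A minimal invocation sketch (hypothetical names, inferred from the
// `QueryDefinition` parser below): attributes other than `#[lang = "..."]` are
// forwarded to the generated enum, the parenthesized list names the captures
// to expose, and everything after the `;` is the raw tree-sitter query
// pattern.
//
//     query! {
//         #[derive(Debug)]
//         #[lang = "tree_sitter_python"]
//         FunctionDefs(Name);
//         (function_definition name: (identifier) @name)
//     }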

// Parsed form of a `query!` / `query_js!` invocation.
struct QueryDefinition {
    meta: Vec<TokenStream>,
    lang: Option<syn::LitStr>,
    name: syn::Ident,
    captures: Punctuated<Ident, Token![,]>,
    query: TokenStream,
}

mod kw {
    syn::custom_keyword!(lang);
}

impl Parse for QueryDefinition {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut meta = vec![];
        let mut lang = None;
        // Leading `#[...]` attributes: `#[lang = "..."]` is consumed here,
        // everything else is forwarded to the generated enum.
        while input.peek(Token![#]) {
            input.parse::<Token![#]>()?;
            let content;
            bracketed!(content in input);
            if content.peek(kw::lang) {
                content.parse::<kw::lang>()?;
                content.parse::<Token![=]>()?;
                lang = Some(content.parse()?);
                continue;
            }
            meta.push(content.parse()?);
        }
        let name = input.parse()?;
        let contents;
        parenthesized!(contents in input);
        let captures = Punctuated::parse_terminated(&contents)?;
        input.parse::<Token![;]>()?;
        // Everything after the `;` is the tree-sitter query template.
        let template = input.parse()?;
        Ok(Self {
            name,
            meta,
            captures,
            query: template,
            lang,
        })
    }
}

// The grammar crate the generated `query()` constructor is built against.
enum TsLang {
    Python,
    Javascript,
    Custom(syn::LitStr),
}

impl ToTokens for TsLang {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match self {
            Self::Python => tokens.extend(quote!(::tree_sitter_python)),
            Self::Javascript => tokens.extend(quote!(::tree_sitter_javascript)),
            Self::Custom(lang) => match syn::parse_str::<syn::Path>(&lang.value()) {
                Ok(lang) => tokens.extend(quote!(#lang)),
                Err(err) => {
                    let report = err.to_compile_error();
                    tokens.extend(quote_spanned!(lang.span() => #report))
                }
            },
        }
    }
}
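
// Note on `#[lang = "..."]`: the string is re-parsed as a `syn::Path`, so
// (hypothetically) `#[lang = "tree_sitter_rust"]` or
// `#[lang = "::my_crate::grammar"]` both resolve, provided that path is in
// scope at the call site and exposes a `LANGUAGE` item convertible into
// `tree_sitter::Language`.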

// Flatten the query template back into a single string of tree-sitter query
// text. Ordinary tokens are emitted with a leading space; `@capture` and
// `#predicate` names (including `-`-joined segments such as `not-eq` and a
// trailing `?`/`!`) are written without internal spaces.
fn tokens_to_string(tokens: TokenStream, output: &mut String) {
    let mut tokens = tokens.into_iter().peekable();
    while let Some(token) = tokens.next() {
        match token {
            TokenTree::Group(group) => {
                let (lhs, rhs) = match group.delimiter() {
                    Delimiter::Parenthesis => ("(", ")"),
                    Delimiter::Brace => ("{", "}"),
                    Delimiter::Bracket => ("[", "]"),
                    Delimiter::None => (" ", " "),
                };
                output.push_str(lhs);
                tokens_to_string(group.stream(), output);
                output.push_str(rhs);
            }
            TokenTree::Punct(punct) if matches!(punct.as_char(), '@' | '#') => {
                output.push(' ');
                output.push(punct.as_char());
                let Some(TokenTree::Ident(ident)) = tokens.peek() else {
                    continue;
                };
                output.push_str(&ident.to_string());
                tokens.next();
                // Glue trailing punctuation and `-`-joined identifier segments
                // onto the capture/predicate name without inserting spaces.
                let mut ident_allowed = false;
                loop {
                    match tokens.peek() {
                        Some(TokenTree::Punct(punct)) => {
                            let punct = punct.as_char();
                            output.push(punct);
                            tokens.next();
                            if punct != '-' {
                                break;
                            }
                            ident_allowed = true;
                        }
                        Some(TokenTree::Ident(ident)) if ident_allowed => {
                            output.push_str(&ident.to_string());
                            tokens.next();
                            ident_allowed = false;
                        }
                        _ => break,
                    }
                }
            }
            _ => {
                output.push(' ');
                output.push_str(&token.to_string());
            }
        }
    }
}
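
// Rough example of the flattening above (hypothetical input; the spacing is
// what this code produces, not canonical tree-sitter formatting):
//
//     (function_definition name: (identifier) @name)
//
// is rendered as roughly
//
//     "( function_definition name :( identifier) @name)"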

impl QueryDefinition {
    fn into_tokens(self, language: TsLang) -> TokenStream {
        let language = self.lang.map(TsLang::Custom).unwrap_or(language);
        let mut captures = HashMap::new();
        let mut diagnostics = Vec::new();
        let mut index = 0u32;
        let mut tokens = self.query.clone().into_iter();
        let mut expect_capture = false;
        // Walk the query template and assign each distinct capture an index in
        // order of first appearance, mirroring tree-sitter's capture numbering.
        // Captures starting with `_` keep their name; everything else is mapped
        // to UpperCamelCase so it can be matched against the listed variants.
        while let Some(token) = tokens.next() {
            match token {
                TokenTree::Punct(punct) if punct.as_char() == '@' => {
                    expect_capture = true;
                }
                TokenTree::Ident(capture) if expect_capture => {
                    expect_capture = false;
                    let capture = quote!(#capture).to_string();
                    let key = if capture.starts_with('_') {
                        capture
                    } else {
                        capture.to_upper_camel_case()
                    };
                    if let Entry::Vacant(entry) = captures.entry(key) {
                        entry.insert(index);
                        index += 1;
                    }
                }
                TokenTree::Group(group) => {
                    // Splice nested groups into the front of the iterator so the
                    // scan stays in source order without recursion.
                    tokens = group
                        .stream()
                        .into_iter()
                        .chain(tokens)
                        .collect::<TokenStream>()
                        .into_iter();
                }
                _ => {}
            }
        }
        let mut cases = vec![];
        let mut variants = vec![];
        for capture in self.captures.iter() {
            if let Some(index) = captures.get(capture.to_string().as_str()) {
                let index = Literal::usize_unsuffixed(*index as _);
                cases.push(quote_spanned!(capture.span() => #index => Some(Self::#capture),));
                variants.push(quote_spanned!(capture.span()=> #capture = #index,));
            } else {
                diagnostics.push(
                    capture
                        .span()
                        .error(format!("No capture '{capture}' found in query")),
                );
            }
        }
        let name = self.name;
        let mut query = String::new();
        tokens_to_string(self.query, &mut query);
        let meta = self.meta;
        let diagnostics = diagnostics.into_iter().map(|diag| diag.emit_as_item_tokens());
        quote_spanned!(name.span()=>
            #(#[#meta])*
            pub enum #name {
                #(#variants)*
            }

            #[allow(dead_code)]
            impl #name {
                #[inline]
                pub fn from(raw: u32) -> Option<Self> {
                    match raw {
                        #(#cases)*
                        _ => None,
                    }
                }
                pub fn query() -> &'static ::tree_sitter::Query {
                    use ::std::sync::OnceLock as _OnceLock;
                    static QUERY: _OnceLock<::tree_sitter::Query> = _OnceLock::new();
                    QUERY.get_or_init(|| {
                        ::tree_sitter::Query::new(&#language::LANGUAGE.into(), #query).unwrap()
                    })
                }
            }
            #(#diagnostics)*
        )
    }
}
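
// For orientation, a hedged sketch of roughly what the example invocation near
// the top of this file would expand to (hypothetical names carried over from
// that sketch; spans and diagnostics omitted):
//
//     #[derive(Debug)]
//     pub enum FunctionDefs {
//         Name = 0,
//     }
//
//     #[allow(dead_code)]
//     impl FunctionDefs {
//         pub fn from(raw: u32) -> Option<Self> { /* 0 => Some(Self::Name), _ => None */ }
//         pub fn query() -> &'static ::tree_sitter::Query { /* built once, then cached */ }
//     }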