slotted_egraphs_derive/
lib.rs1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{quote, ToTokens};
4use syn::*;
5
6#[proc_macro]
10pub fn define_language(input: TokenStream1) -> TokenStream1 {
11 let mut ie: ItemEnum = parse(input).unwrap();
12
13 let name = ie.ident.clone();
14 let str_names: Vec<Option<Expr>> = ie
15 .variants
16 .iter_mut()
17 .map(|x| x.discriminant.take().map(|(_, e)| e))
18 .collect();
19
20 let all_slot_occurrences_mut_arms: Vec<TokenStream2> = ie
21 .variants
22 .iter()
23 .map(|x| produce_all_slot_occurrences_mut(&name, x))
24 .collect();
25 let public_slot_occurrences_mut_arms: Vec<TokenStream2> = ie
26 .variants
27 .iter()
28 .map(|x| produce_public_slot_occurrences_mut(&name, x))
29 .collect();
30 let applied_id_occurrences_mut_arms: Vec<TokenStream2> = ie
31 .variants
32 .iter()
33 .map(|x| produce_applied_id_occurrences_mut(&name, x))
34 .collect();
35
36 let all_slot_occurrences_arms: Vec<TokenStream2> = ie
37 .variants
38 .iter()
39 .map(|x| produce_all_slot_occurrences(&name, x))
40 .collect();
41 let public_slot_occurrences_arms: Vec<TokenStream2> = ie
42 .variants
43 .iter()
44 .map(|x| produce_public_slot_occurrences(&name, x))
45 .collect();
46 let applied_id_occurrences_arms: Vec<TokenStream2> = ie
47 .variants
48 .iter()
49 .map(|x| produce_applied_id_occurrences(&name, x))
50 .collect();
51
52 let to_syntax_arms: Vec<TokenStream2> = ie
53 .variants
54 .iter()
55 .zip(&str_names)
56 .map(|(x, n)| produce_to_syntax(&name, &n, x))
57 .collect();
58 let from_syntax_arms1: Vec<TokenStream2> = ie
59 .variants
60 .iter()
61 .zip(&str_names)
62 .filter_map(|(x, n)| produce_from_syntax1(&name, &n, x))
63 .collect();
64 let from_syntax_arms2: Vec<TokenStream2> = ie
65 .variants
66 .iter()
67 .zip(&str_names)
68 .filter_map(|(x, n)| produce_from_syntax2(&name, &n, x))
69 .collect();
70
71 let slots_arms: Vec<TokenStream2> = ie
72 .variants
73 .iter()
74 .map(|x| produce_slots(&name, x))
75 .collect();
76 let weak_shape_inplace_arms: Vec<TokenStream2> = ie
77 .variants
78 .iter()
79 .map(|x| produce_weak_shape_inplace(&name, x))
80 .collect();
81
82 quote! {
83 #[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
84 #ie
85
86 impl Language for #name {
87 fn all_slot_occurrences_mut(&mut self) -> Vec<&mut Slot> {
89 match self {
90 #(#all_slot_occurrences_mut_arms),*
91 }
92 }
93
94 fn public_slot_occurrences_mut(&mut self) -> Vec<&mut Slot> {
95 match self {
96 #(#public_slot_occurrences_mut_arms),*
97 }
98 }
99
100 fn applied_id_occurrences_mut(&mut self) -> Vec<&mut AppliedId> {
101 match self {
102 #(#applied_id_occurrences_mut_arms),*
103 }
104 }
105
106
107 fn all_slot_occurrences(&self) -> Vec<Slot> {
109 match self {
110 #(#all_slot_occurrences_arms),*
111 }
112 }
113
114 fn public_slot_occurrences(&self) -> Vec<Slot> {
115 match self {
116 #(#public_slot_occurrences_arms),*
117 }
118 }
119
120 fn applied_id_occurrences(&self) -> Vec<&AppliedId> {
121 match self {
122 #(#applied_id_occurrences_arms),*
123 }
124 }
125
126 fn to_syntax(&self) -> Vec<SyntaxElem> {
128 match self {
129 #(#to_syntax_arms),*
130 }
131 }
132
133 fn from_syntax(elems: &[SyntaxElem]) -> Option<Self> {
134 let SyntaxElem::String(op) = elems.get(0)? else { return None };
135 match &**op {
136 #(#from_syntax_arms1),*
137 _ => {
138 #(#from_syntax_arms2)*
139
140 None
141 },
142 }
143 }
144
145 fn slots(&self) -> slotted_egraphs::SmallHashSet<Slot> {
146 match self {
147 #(#slots_arms),*
148 }
149 }
150
151 fn weak_shape_inplace(&mut self) -> slotted_egraphs::SlotMap {
152 let m = &mut (slotted_egraphs::SlotMap::new(), 0);
153 match self {
154 #(#weak_shape_inplace_arms),*
155 }
156
157 m.0.inverse()
158 }
159 }
160 }
161 .to_token_stream()
162 .into()
163}
164
165fn produce_all_slot_occurrences_mut(name: &Ident, v: &Variant) -> TokenStream2 {
166 let variant_name = &v.ident;
167 let n = v.fields.len();
168 let fields: Vec<Ident> = (0..n)
169 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
170 .collect();
171 quote! {
172 #name::#variant_name(#(#fields),*) => {
173 let out = std::iter::empty();
174 #(
175 let out = out.chain(#fields .all_slot_occurrences_iter_mut());
176 )*
177 out.collect()
178 }
179 }
180}
181
182fn produce_public_slot_occurrences_mut(name: &Ident, v: &Variant) -> TokenStream2 {
183 let variant_name = &v.ident;
184 let n = v.fields.len();
185 let fields: Vec<Ident> = (0..n)
186 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
187 .collect();
188 quote! {
189 #name::#variant_name(#(#fields),*) => {
190 let out = std::iter::empty();
191 #(
192 let out = out.chain(#fields .public_slot_occurrences_iter_mut());
193 )*
194 out.collect()
195 }
196 }
197}
198
199fn produce_applied_id_occurrences_mut(name: &Ident, v: &Variant) -> TokenStream2 {
200 let variant_name = &v.ident;
201 let n = v.fields.len();
202 let fields: Vec<Ident> = (0..n)
203 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
204 .collect();
205 quote! {
206 #name::#variant_name(#(#fields),*) => {
207 let out = std::iter::empty();
208 #(
209 let out = out.chain(#fields .applied_id_occurrences_iter_mut());
210 )*
211 out.collect()
212 }
213 }
214}
215
216fn produce_all_slot_occurrences(name: &Ident, v: &Variant) -> TokenStream2 {
218 let variant_name = &v.ident;
219 let n = v.fields.len();
220 let fields: Vec<Ident> = (0..n)
221 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
222 .collect();
223 quote! {
224 #name::#variant_name(#(#fields),*) => {
225 let out = std::iter::empty();
226 #(
227 let out = out.chain(#fields .all_slot_occurrences_iter().copied());
228 )*
229 out.collect()
230 }
231 }
232}
233
234fn produce_public_slot_occurrences(name: &Ident, v: &Variant) -> TokenStream2 {
235 let variant_name = &v.ident;
236 let n = v.fields.len();
237 let fields: Vec<Ident> = (0..n)
238 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
239 .collect();
240 quote! {
241 #name::#variant_name(#(#fields),*) => {
242 let out = std::iter::empty();
243 #(
244 let out = out.chain(#fields .public_slot_occurrences_iter().copied());
245 )*
246 out.collect()
247 }
248 }
249}
250
251fn produce_applied_id_occurrences(name: &Ident, v: &Variant) -> TokenStream2 {
252 let variant_name = &v.ident;
253 let n = v.fields.len();
254 let fields: Vec<Ident> = (0..n)
255 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
256 .collect();
257 quote! {
258 #name::#variant_name(#(#fields),*) => {
259 let out = std::iter::empty();
260 #(
261 let out = out.chain(#fields .applied_id_occurrences_iter());
262 )*
263 out.collect()
264 }
265 }
266}
267
268fn produce_to_syntax(name: &Ident, e: &Option<Expr>, v: &Variant) -> TokenStream2 {
270 let variant_name = &v.ident;
271
272 if e.is_none() {
273 return quote! {
274 #name::#variant_name(a0) => {
275 a0.to_syntax()
276 }
277 };
278 }
279
280 let e = e.as_ref().unwrap();
281 let n = v.fields.len();
282 let fields: Vec<Ident> = (0..n)
283 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
284 .collect();
285 quote! {
286 #name::#variant_name(#(#fields),*) => {
287 let mut out: Vec<SyntaxElem> = vec![SyntaxElem::String(String::from(#e))];
288 #(
289 out.extend(#fields.to_syntax());
290 )*
291 out
292 }
293 }
294}
295
296fn produce_from_syntax1(name: &Ident, e: &Option<Expr>, v: &Variant) -> Option<TokenStream2> {
297 let variant_name = &v.ident;
298
299 let e = e.as_ref()?;
300 let n = v.fields.len();
301 let fields: Vec<Ident> = (0..n)
302 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
303 .collect();
304
305 let types: Vec<Type> = v.fields.iter().map(|x| x.ty.clone()).collect();
306
307 Some(quote! {
308 #e => {
309 let mut children = &elems[1..];
310 let mut rest = children;
311 #(
312 let #fields = (0..=children.len()).filter_map(|n| {
313 let a = &children[..n];
314 rest = &children[n..];
315
316 <#types>::from_syntax(a)
317 }).next()?;
318 children = rest;
319 )*
320 Some(#name::#variant_name(#(#fields),*))
321 }
322 })
323}
324
325fn produce_from_syntax2(name: &Ident, e: &Option<Expr>, v: &Variant) -> Option<TokenStream2> {
326 if e.is_some() {
327 return None;
328 }
329 let variant_name = &v.ident;
330
331 let ty = v.fields.iter().map(|x| x.ty.clone()).next().unwrap();
332 Some(quote! {
333 if let Some(a) = <#ty>::from_syntax(elems) {
334 return Some(#name::#variant_name(a));
335 }
336 })
337}
338
339fn produce_slots(name: &Ident, v: &Variant) -> TokenStream2 {
340 let variant_name = &v.ident;
341 let n = v.fields.len();
342 let fields: Vec<Ident> = (0..n)
343 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
344 .collect();
345 quote! {
346 #name::#variant_name(#(#fields),*) => {
347 let out = std::iter::empty();
348 #(
349 let out = out.chain(#fields .public_slot_occurrences_iter().copied());
350 )*
351 out.collect()
352 }
353 }
354}
355
356fn produce_weak_shape_inplace(name: &Ident, v: &Variant) -> TokenStream2 {
357 let variant_name = &v.ident;
358 let n = v.fields.len();
359 let fields: Vec<Ident> = (0..n)
360 .map(|x| Ident::new(&format!("a{x}"), proc_macro2::Span::call_site()))
361 .collect();
362 quote! {
363 #name::#variant_name(#(#fields),*) => {
364 #(
365 #fields .weak_shape_impl(m);
366 )*
367 }
368 }
369}