1use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident};
6
7#[proc_macro_derive(Alpha)]
9pub fn derive_alpha(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11 let name = &input.ident;
12 let generics = &input.generics;
13 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
14
15 let aeq_impl = generate_aeq_impl(&input.data, name);
16 let aeq_in_impl = generate_aeq_in_impl(&input.data);
17 let fv_in_impl = generate_fv_in_impl(&input.data);
18
19 let expanded = quote! {
20 impl #impl_generics unbound::Alpha for #name #ty_generics #where_clause {
21 fn aeq(&self, other: &Self) -> bool {
22 #aeq_impl
23 }
24
25 fn aeq_in(&self, ctx: &mut unbound::alpha::AlphaCtx, other: &Self) -> bool {
26 #aeq_in_impl
27 }
28
29 fn fv_in(&self, vars: &mut Vec<String>) {
30 #fv_in_impl
31 }
32 }
33 };
34
35 TokenStream::from(expanded)
36}
37
38fn generate_aeq_impl(data: &Data, name: &Ident) -> proc_macro2::TokenStream {
39 match data {
40 Data::Struct(data_struct) => match &data_struct.fields {
41 Fields::Named(fields) => {
42 let field_checks = fields.named.iter().map(|f| {
43 let field_name = &f.ident;
44 quote! {
45 self.#field_name.aeq(&other.#field_name)
46 }
47 });
48 quote! {
49 #(#field_checks)&&*
50 }
51 }
52 Fields::Unnamed(fields) => {
53 let field_checks = fields.unnamed.iter().enumerate().map(|(i, _)| {
54 let index = syn::Index::from(i);
55 quote! {
56 self.#index.aeq(&other.#index)
57 }
58 });
59 quote! {
60 #(#field_checks)&&*
61 }
62 }
63 Fields::Unit => quote! { true },
64 },
65 Data::Enum(data_enum) => {
66 let variant_matches = data_enum.variants.iter().map(|variant| {
67 let variant_name = &variant.ident;
68 match &variant.fields {
69 Fields::Named(fields) => {
70 let field_names: Vec<_> = fields
71 .named
72 .iter()
73 .filter_map(|f| f.ident.as_ref())
74 .collect();
75 let other_field_names: Vec<_> = field_names
76 .iter()
77 .map(|f| quote::format_ident!("other_{}", f))
78 .collect();
79 let field_checks = field_names.iter().zip(other_field_names.iter()).map(
80 |(field_name, other_field_name)| {
81 quote! {
82 #field_name.aeq(#other_field_name)
83 }
84 },
85 );
86 let other_bindings = field_names.iter().zip(other_field_names.iter()).map(
87 |(field_name, other_field_name)| {
88 quote! { #field_name: #other_field_name }
89 },
90 );
91 quote! {
92 (#name::#variant_name { #(#field_names),* },
93 #name::#variant_name { #(#other_bindings),* }) => {
94 #(#field_checks)&&*
95 }
96 }
97 }
98 Fields::Unnamed(fields) => {
99 let field_names: Vec<_> = (0..fields.unnamed.len())
100 .map(|i| quote::format_ident!("f{}", i))
101 .collect();
102 let other_names: Vec<_> = (0..fields.unnamed.len())
103 .map(|i| quote::format_ident!("other_f{}", i))
104 .collect();
105 let field_checks =
106 field_names
107 .iter()
108 .zip(other_names.iter())
109 .map(|(f, other_f)| {
110 quote! {
111 #f.aeq(#other_f)
112 }
113 });
114 quote! {
115 (#name::#variant_name(#(#field_names),*),
116 #name::#variant_name(#(#other_names),*)) => {
117 #(#field_checks)&&*
118 }
119 }
120 }
121 Fields::Unit => {
122 quote! {
123 (#name::#variant_name, #name::#variant_name) => true
124 }
125 }
126 }
127 });
128 quote! {
129 match (self, other) {
130 #(#variant_matches,)*
131 _ => false,
132 }
133 }
134 }
135 Data::Union(_) => panic!("Unions are not supported"),
136 }
137}
138
139fn generate_aeq_in_impl(data: &Data) -> proc_macro2::TokenStream {
140 match data {
141 Data::Struct(data_struct) => match &data_struct.fields {
142 Fields::Named(fields) => {
143 let field_checks = fields.named.iter().map(|f| {
144 let field_name = &f.ident;
145 quote! {
146 self.#field_name.aeq_in(ctx, &other.#field_name)
147 }
148 });
149 quote! {
150 #(#field_checks)&&*
151 }
152 }
153 Fields::Unnamed(fields) => {
154 let field_checks = fields.unnamed.iter().enumerate().map(|(i, _)| {
155 let index = syn::Index::from(i);
156 quote! {
157 self.#index.aeq_in(ctx, &other.#index)
158 }
159 });
160 quote! {
161 #(#field_checks)&&*
162 }
163 }
164 Fields::Unit => quote! { true },
165 },
166 Data::Enum(data_enum) => {
167 let variant_matches = data_enum.variants.iter().map(|variant| {
169 let variant_name = &variant.ident;
170 match &variant.fields {
171 Fields::Named(fields) => {
172 let field_names: Vec<_> = fields
173 .named
174 .iter()
175 .filter_map(|f| f.ident.as_ref())
176 .collect();
177 let other_field_names: Vec<_> = field_names
178 .iter()
179 .map(|f| quote::format_ident!("other_{}", f))
180 .collect();
181 let field_checks = field_names.iter().zip(other_field_names.iter()).map(
182 |(field_name, other_field_name)| {
183 quote! {
184 #field_name.aeq_in(ctx, #other_field_name)
185 }
186 },
187 );
188 let other_bindings = field_names.iter().zip(other_field_names.iter()).map(
189 |(field_name, other_field_name)| {
190 quote! { #field_name: #other_field_name }
191 },
192 );
193 quote! {
194 (Self::#variant_name { #(#field_names),* },
195 Self::#variant_name { #(#other_bindings),* }) => {
196 #(#field_checks)&&*
197 }
198 }
199 }
200 Fields::Unnamed(fields) => {
201 let field_names: Vec<_> = (0..fields.unnamed.len())
202 .map(|i| quote::format_ident!("f{}", i))
203 .collect();
204 let other_names: Vec<_> = (0..fields.unnamed.len())
205 .map(|i| quote::format_ident!("other_f{}", i))
206 .collect();
207 let field_checks =
208 field_names
209 .iter()
210 .zip(other_names.iter())
211 .map(|(f, other_f)| {
212 quote! {
213 #f.aeq_in(ctx, #other_f)
214 }
215 });
216 quote! {
217 (Self::#variant_name(#(#field_names),*),
218 Self::#variant_name(#(#other_names),*)) => {
219 #(#field_checks)&&*
220 }
221 }
222 }
223 Fields::Unit => {
224 quote! {
225 (Self::#variant_name, Self::#variant_name) => true
226 }
227 }
228 }
229 });
230 quote! {
231 match (self, other) {
232 #(#variant_matches,)*
233 _ => false,
234 }
235 }
236 }
237 Data::Union(_) => panic!("Unions are not supported"),
238 }
239}
240
241fn generate_fv_in_impl(data: &Data) -> proc_macro2::TokenStream {
242 match data {
243 Data::Struct(data_struct) => match &data_struct.fields {
244 Fields::Named(fields) => {
245 let field_calls = fields.named.iter().map(|f| {
246 let field_name = &f.ident;
247 quote! {
248 self.#field_name.fv_in(vars);
249 }
250 });
251 quote! {
252 #(#field_calls)*
253 }
254 }
255 Fields::Unnamed(fields) => {
256 let field_calls = fields.unnamed.iter().enumerate().map(|(i, _)| {
257 let index = syn::Index::from(i);
258 quote! {
259 self.#index.fv_in(vars);
260 }
261 });
262 quote! {
263 #(#field_calls)*
264 }
265 }
266 Fields::Unit => quote! {},
267 },
268 Data::Enum(data_enum) => {
269 let variant_matches = data_enum.variants.iter().map(|variant| {
270 let variant_name = &variant.ident;
271 match &variant.fields {
272 Fields::Named(fields) => {
273 let field_names: Vec<_> = fields
274 .named
275 .iter()
276 .filter_map(|f| f.ident.as_ref())
277 .collect();
278 let field_calls = field_names.iter().map(|field_name| {
279 quote! {
280 #field_name.fv_in(vars);
281 }
282 });
283 quote! {
284 Self::#variant_name { #(#field_names),* } => {
285 #(#field_calls)*
286 }
287 }
288 }
289 Fields::Unnamed(fields) => {
290 let field_names: Vec<_> = (0..fields.unnamed.len())
291 .map(|i| quote::format_ident!("f{}", i))
292 .collect();
293 let field_calls = field_names.iter().map(|f| {
294 quote! {
295 #f.fv_in(vars);
296 }
297 });
298 quote! {
299 Self::#variant_name(#(#field_names),*) => {
300 #(#field_calls)*
301 }
302 }
303 }
304 Fields::Unit => {
305 quote! {
306 Self::#variant_name => {}
307 }
308 }
309 }
310 });
311 quote! {
312 match self {
313 #(#variant_matches)*
314 }
315 }
316 }
317 Data::Union(_) => panic!("Unions are not supported"),
318 }
319}
320
321#[proc_macro_derive(Subst, attributes(subst_var))]
323pub fn derive_subst(input: TokenStream) -> TokenStream {
324 let input = parse_macro_input!(input as DeriveInput);
325 let name = &input.ident;
326 let generics = &input.generics;
327 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
328
329 let (is_var_impl, subst_impl) = generate_subst_impl(&input.data, name);
330
331 let expanded = quote! {
332 impl #impl_generics unbound::Subst<#name #ty_generics> for #name #ty_generics #where_clause {
333 fn is_var(&self) -> Option<unbound::SubstName<#name #ty_generics>> {
334 #is_var_impl
335 }
336
337 fn subst(&self, var: &unbound::Name<#name #ty_generics>, value: &#name #ty_generics) -> Self {
338 #subst_impl
339 }
340 }
341 };
342
343 TokenStream::from(expanded)
344}
345
346fn generate_subst_impl(
347 data: &Data,
348 name: &Ident,
349) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
350 match data {
351 Data::Enum(data_enum) => {
352 let var_variant = data_enum
354 .variants
355 .iter()
356 .find(|v| v.ident == "V" || v.ident == "Var" || v.ident == "Variable");
357
358 let is_var_impl = if let Some(var_variant) = var_variant {
359 let variant_name = &var_variant.ident;
360 quote! {
361 match self {
362 #name::#variant_name(x) => Some(unbound::SubstName::Name(x.clone())),
363 _ => None,
364 }
365 }
366 } else {
367 quote! { None }
368 };
369
370 let subst_cases = data_enum.variants.iter().map(|variant| {
371 let variant_name = &variant.ident;
372
373 if Some(&variant.ident) == var_variant.as_ref().map(|v| &v.ident) {
375 quote! {
376 #name::#variant_name(x) => {
377 if x == var {
378 value.clone()
379 } else {
380 self.clone()
381 }
382 }
383 }
384 } else if variant.ident == "Lam" {
385 match &variant.fields {
387 Fields::Unnamed(_) => {
388 quote! {
389 #name::#variant_name(bnd) => {
390 let bound_var = bnd.pattern();
392 if bound_var == var {
393 self.clone()
395 } else {
396 let body_subst = bnd.body().subst(var, value);
398 #name::#variant_name(unbound::bind(bound_var.clone(), body_subst))
399 }
400 }
401 }
402 }
403 _ => {
404 match &variant.fields {
406 Fields::Named(fields) => {
407 let field_names: Vec<_> =
408 fields.named.iter().filter_map(|f| f.ident.as_ref()).collect();
409 let field_substs = field_names.iter().map(|field_name| {
410 quote! {
411 #field_name: #field_name.subst(var, value)
412 }
413 });
414 quote! {
415 #name::#variant_name { #(#field_names),* } => {
416 #name::#variant_name {
417 #(#field_substs),*
418 }
419 }
420 }
421 }
422 Fields::Unnamed(fields) => {
423 let field_names: Vec<_> = (0..fields.unnamed.len())
424 .map(|i| quote::format_ident!("f{}", i))
425 .collect();
426 let field_substs = field_names.iter().map(|f| {
427 quote! {
428 #f.subst(var, value)
429 }
430 });
431 quote! {
432 #name::#variant_name(#(#field_names),*) => {
433 #name::#variant_name(#(#field_substs),*)
434 }
435 }
436 }
437 Fields::Unit => {
438 quote! {
439 #name::#variant_name => #name::#variant_name
440 }
441 }
442 }
443 }
444 }
445 } else {
446 match &variant.fields {
447 Fields::Named(fields) => {
448 let field_names: Vec<_> =
449 fields.named.iter().filter_map(|f| f.ident.as_ref()).collect();
450 let field_substs = field_names.iter().map(|field_name| {
451 quote! {
452 #field_name: #field_name.subst(var, value)
453 }
454 });
455 quote! {
456 #name::#variant_name { #(#field_names),* } => {
457 #name::#variant_name {
458 #(#field_substs),*
459 }
460 }
461 }
462 }
463 Fields::Unnamed(fields) => {
464 let field_names: Vec<_> = (0..fields.unnamed.len())
465 .map(|i| quote::format_ident!("f{}", i))
466 .collect();
467 let field_substs = field_names.iter().map(|f| {
468 quote! {
469 #f.subst(var, value)
470 }
471 });
472 quote! {
473 #name::#variant_name(#(#field_names),*) => {
474 #name::#variant_name(#(#field_substs),*)
475 }
476 }
477 }
478 Fields::Unit => {
479 quote! {
480 #name::#variant_name => #name::#variant_name
481 }
482 }
483 }
484 }
485 });
486
487 let subst_impl = quote! {
488 match self {
489 #(#subst_cases),*
490 }
491 };
492
493 (is_var_impl, subst_impl)
494 }
495 Data::Struct(data_struct) => {
496 let is_var_impl = quote! { None };
497
498 let subst_impl = match &data_struct.fields {
499 Fields::Named(fields) => {
500 let field_names: Vec<_> = fields
501 .named
502 .iter()
503 .filter_map(|f| f.ident.as_ref())
504 .collect();
505 let field_substs = field_names.iter().map(|field_name| {
506 quote! {
507 #field_name: self.#field_name.subst(var, value)
508 }
509 });
510 quote! {
511 #name {
512 #(#field_substs),*
513 }
514 }
515 }
516 Fields::Unnamed(fields) => {
517 let field_substs = (0..fields.unnamed.len()).map(|i| {
518 let index = syn::Index::from(i);
519 quote! {
520 self.#index.subst(var, value)
521 }
522 });
523 quote! {
524 #name(#(#field_substs),*)
525 }
526 }
527 Fields::Unit => quote! { #name },
528 };
529
530 (is_var_impl, subst_impl)
531 }
532 Data::Union(_) => panic!("Unions are not supported"),
533 }
534}