1#![doc = include_str!("../README.md")]
2
3extern crate alloc;
6use alloc::vec::Vec;
7use alloc::string::String;
8use std::collections::HashMap;
9
10use proc_macro::TokenStream;
11use proc_macro2::{TokenTree, Group};
12use quote::{ToTokens, quote, quote_spanned};
13use heck::{AsUpperCamelCase, AsSnakeCase};
14use syn::parse::{Parse, ParseStream, Result};
15use syn::{parse, parse2, parse_quote, parse_macro_input, parse_str, Attribute, Block, Error, Fields, GenericParam, Generics, Ident, ImplItem, ItemEnum, ItemImpl, FnArg, punctuated::Punctuated, Signature, Token, Type, TypeParam, PathArguments, Variant, Visibility};
16use syn::spanned::Spanned;
17
/// One sum type declared inside `summum!`, written either as
/// `type Name = A | B;` or as `enum Name { ... }`.
///
/// The renderer reads each case's payload type from the first unnamed field
/// of the variant (via `type_from_fields`).
struct SummumType {
    /// Outer attributes written before the declaration; re-emitted on the generated enum.
    attrs: Vec<Attribute>,
    /// Visibility of the generated enum.
    vis: Visibility,
    /// Name of the generated enum.
    name: Ident,
    /// Generic parameters of the generated enum.
    generics: Generics,
    /// The enum's variants, one per case of the sum type.
    cases: Vec<Variant>,
}
25
26impl SummumType {
27 fn parse_haskell_style(input: ParseStream, attrs: Vec<Attribute>, vis: Visibility) -> Result<Self> {
28 let _ = input.parse::<Token![type]>()?;
29 let name = input.parse()?;
30 let generics: Generics = input.parse()?;
31 let _ = input.parse::<Token![=]>()?;
32 let mut cases = vec![];
33
34 loop {
35 let item_type = input.parse()?;
36
37 let item_ident = if input.peek(Token![as]) {
38 let _ = input.parse::<Token![as]>()?;
39 input.parse::<Ident>()?
40 } else {
41 ident_from_type_full(&item_type)
42 };
43
44 let mut variant: Variant = parse(quote!{ #item_ident(#item_type) }.into())?;
45 let sub_type = type_from_fields_mut(&mut variant.fields);
46 cannonicalize_type_path(sub_type);
47
48 cases.push(variant);
49
50 if input.peek(Token![;]) {
51 let _ = input.parse::<Token![;]>()?;
52 break;
53 }
54
55 let _ = input.parse::<Token![|]>()?;
56 }
57
58 Ok(Self {
59 attrs,
60 vis,
61 name,
62 generics,
63 cases,
64 })
65 }
66
67 fn parse_enum_style(input: ParseStream, attrs: Vec<Attribute>, vis: Visibility) -> Result<Self> {
68 let enum_block: ItemEnum = input.parse()?;
69 let name = enum_block.ident;
70 let generics = enum_block.generics;
71 let cases = enum_block.variants.into_iter()
72 .map(|mut variant| {
73 let sub_type = type_from_fields_mut(&mut variant.fields);
74 cannonicalize_type_path(sub_type);
75 variant
76 })
77 .collect();
78
79 Ok(Self {
80 attrs,
81 vis,
82 name,
83 generics,
84 cases,
85 })
86 }
87
88 fn parse(input: ParseStream, attrs: Vec<Attribute>) -> Result<Self> {
89 let vis = input.parse()?;
90
91 if input.peek(Token![type]) {
92 SummumType::parse_haskell_style(input, attrs, vis)
93 } else if input.peek(Token![enum]) {
94 SummumType::parse_enum_style(input, attrs, vis)
95 } else {
96 input.step(|cursor| {
97 Err(cursor.error(format!("expected `enum`, `type`, or `impl`")))
98 })
99 }
100 }
101
102 fn render(&self) -> TokenStream {
103 let Self {
104 attrs,
105 vis,
106 name,
107 generics,
108 cases,
109 } = self;
110
111 let cases_tokens = cases.iter().map(|variant| quote! {
112 #variant
113 }).collect::<Vec<_>>();
114
115 let from_impls = cases.iter().map(|variant| {
116 let ident = &variant.ident;
117 let sub_type = type_from_fields(&variant.fields);
118
119 quote_spanned! {variant.span() =>
120 impl #generics From<#sub_type> for #name #generics {
121 fn from(val: #sub_type) -> Self {
122 #name::#ident(val)
123 }
124 }
125 }
126 }).collect::<Vec<_>>();
127
128 let generic_params = type_params_from_generics(&generics);
129 let try_from_impls = cases.iter().map(|variant| {
130 let ident = &variant.ident;
131 let sub_type = type_from_fields(&variant.fields);
132 if !detect_uncovered_type(&generic_params[..], &sub_type) {
133 quote! {
134 impl #generics core::convert::TryFrom<#name #generics> for #sub_type {
135 type Error = ();
136 fn try_from(val: #name #generics) -> Result<Self, Self::Error> {
137 match val{#name::#ident(val)=>Ok(val), _=>Err(())}
138 }
139 }
140 }
141 } else {
142 quote!{}
143 }
144 }).collect::<Vec<_>>();
145
146 let variants_strs = cases.iter().map(|variant| {
147 let ident_string = &variant.ident.to_string();
148 quote_spanned! {variant.span() =>
149 #ident_string
150 }
151 }).collect::<Vec<_>>();
152 let variant_name_branches = cases.iter().map(|variant| {
153 let ident = &variant.ident;
154 let ident_string = ident.to_string();
155 quote_spanned! {variant.span() =>
156 Self::#ident(_) => #ident_string
157 }
158 }).collect::<Vec<_>>();
159 let variants_impl = quote!{
160 impl #generics #name #generics {
161 pub const fn variants() -> &'static[&'static str] {
162 &[#(#variants_strs),* ]
163 }
164 pub fn variant_name(&self) -> &'static str {
165 match self{
166 #(#variant_name_branches),*
167 }
168 }
169 }
170 };
171
172 let accessor_impls = cases.iter().map(|variant| {
187 let ident = &variant.ident;
188 let sub_type = type_from_fields(&variant.fields);
189
190 let ident_string = ident.to_string();
191 let is_fn_name = Ident::new(&snake_name("is", &ident_string), variant.ident.span());
192 let try_as_fn_name = Ident::new(&snake_name("try_as", &ident_string), variant.ident.span());
193 let as_fn_name_str = snake_name("as", &ident_string);
194 let as_fn_name = Ident::new(&as_fn_name_str, variant.ident.span());
195 let try_as_mut_fn_name = Ident::new(&snake_name("try_as_mut", &ident_string), variant.ident.span());
196 let as_mut_fn_name_str = snake_name("as_mut", &ident_string);
197 let as_mut_fn_name = Ident::new(&as_mut_fn_name_str, variant.ident.span());
198 let try_into_fn_name = Ident::new(&snake_name("try_into", &ident_string), variant.ident.span());
199 let into_fn_name_str = snake_name("into", &ident_string);
200 let into_fn_name = Ident::new(&into_fn_name_str, variant.ident.span());
201
202 let error_msg = format!("invalid downcast: {name}::{{}} expecting {ident_string} found {{}}");
203 quote_spanned! {variant.span() =>
204 pub fn #is_fn_name(&self) -> bool {
205 match self{Self::#ident(_)=>true, _=>false}
206 }
207 pub fn #try_as_fn_name(&self) -> Option<&#sub_type> {
208 match self{Self::#ident(val)=>Some(val), _=>None}
209 }
210 pub fn #as_fn_name(&self) -> &#sub_type {
211 self.#try_as_fn_name().unwrap_or_else(|| panic!(#error_msg, #as_fn_name_str, self.variant_name()))
212 }
213 pub fn #try_as_mut_fn_name(&mut self) -> Option<&mut #sub_type> {
214 match self{Self::#ident(val)=>Some(val), _=>None}
215 }
216 pub fn #as_mut_fn_name(&mut self) -> &mut #sub_type {
217 let variant_name = self.variant_name();
218 self.#try_as_mut_fn_name().unwrap_or_else(|| panic!(#error_msg, #as_mut_fn_name_str, variant_name))
219 }
220 pub fn #try_into_fn_name(self) -> core::result::Result<#sub_type, Self> {
221 match self{Self::#ident(val)=>Ok(val), _=>Err(self)}
222 }
223 pub fn #into_fn_name(self) -> #sub_type {
224 self.#try_into_fn_name().unwrap_or_else(|t| panic!(#error_msg, #into_fn_name_str, t.variant_name()))
225 }
226 }
227 }).collect::<Vec<_>>();
228 let accessors_impl = quote!{
229 #[allow(dead_code)]
230 impl #generics #name #generics {
231 #(#accessor_impls)*
232 }
233 };
234
235 quote! {
253 #[allow(dead_code)]
254 #(#attrs)*
255 #vis enum #name #generics {
256 #(#cases_tokens),*
257 }
258
259 #(#from_impls)*
260
261 #(#try_from_impls)*
262
263 #variants_impl
264
265 #accessors_impl
266
267 }.into()
272 }
273}
274
/// An inherent `impl` block written inside `summum!` for one of its sum types;
/// its methods are rewritten so that their bodies run once per variant.
struct SummumImpl {
    /// The parsed impl block; its `items` are replaced during rendering.
    item_impl: ItemImpl,
    /// Identifier extracted from the impl's self type (via `ident_from_type_short`),
    /// used to look up the matching `SummumType` by name.
    item_type_name: Ident,
}
279
280impl SummumImpl {
281 fn parse(input: ParseStream, attrs: Vec<Attribute>) -> Result<Self> {
282 let mut item_impl: ItemImpl = input.parse()?;
283
284 if item_impl.trait_.is_some() {
285 return Err(Error::new(item_impl.span(), format!("impl for traits doesn't belong in summum block")));
286 }
287
288 item_impl.attrs = attrs;
289 let item_type_name = ident_from_type_short(&*item_impl.self_ty)?;
290
291 Ok(Self {
292 item_impl,
293 item_type_name
294 })
295 }
296
297 fn render(&mut self, types: &HashMap<String, SummumType>) -> TokenStream {
298 let item_impl = &mut self.item_impl;
299
300 let item_type = if let Some(item_type) = types.get(&self.item_type_name.to_string()) {
301 item_type
302 } else {
303 return quote_spanned! {
304 self.item_type_name.span() => compile_error!("can't find definition for type in summum block");
305 }.into();
306 };
307
308 let impl_span = item_impl.span();
309 let items = core::mem::take(&mut item_impl.items);
310 let mut new_items = vec![];
311 for item in items.into_iter() {
312 if let ImplItem::Fn(mut item) = item {
313
314 let mut variant_blocks = vec![];
316 for variant in item_type.cases.iter() {
317 let ident = &variant.ident;
318 let ident_string = ident.to_string();
319 let variant_t_name = format!("{}T", ident_string);
320
321 let sub_type = type_from_fields(&variant.fields);
322 let sub_type_string = quote!{ #sub_type }.to_string();
323
324 let block_tokenstream = replace_idents(item.block.to_token_stream(), &[
326 ("self", "_summum_self"),
327 ("VariantT", &variant_t_name),
328 ("InnerT", &sub_type_string),
329 ], &[
330 ("_inner_var", &|base| snake_name(base, &ident_string))
331 ]);
332
333 let block_tokenstream = match handle_excludes(block_tokenstream.into(), &ident_string) {
335 Ok(block_tokenstream) => block_tokenstream,
336 Err(err) => {return err.into();}
337 };
338
339 let block: Block = parse(quote_spanned!{item.block.span() => { #block_tokenstream } }.into()).expect("Error composing sub-block");
340 variant_blocks.push(block);
341 }
342
343 let item_fn_name = item.sig.ident.to_string();
345 if item_fn_name.ends_with("_inner_var") {
346 let base_fn_name = &item_fn_name[0..(item_fn_name.len() - "_inner_var".len())];
347 for (variant, block) in item_type.cases.iter().zip(variant_blocks) {
348 let mut new_item = item.clone();
349
350 let item_type_name = self.item_type_name.to_string();
351 let ident = &variant.ident;
352 let ident_string = ident.to_string();
353 let new_method_name = snake_name(base_fn_name, &ident_string);
354 new_item.sig.ident = Ident::new(&new_method_name, item.sig.ident.span());
355
356 let variant_t_name = format!("{}T", ident_string);
358 let sub_type = type_from_fields(&variant.fields);
359 let sub_type_string = quote!{ #sub_type }.to_string();
360 let sig_tokenstream = replace_idents(new_item.sig.to_token_stream(), &[
361 ("VariantT", &variant_t_name),
362 ("InnerT", &sub_type_string),
363 ], &[]);
364 new_item.sig = parse(quote_spanned!{item.sig.span() => #sig_tokenstream }.into()).expect("Error replacing signature types");
365
366 new_item.block = if sig_contains_self_arg(&new_item.sig) {
368 parse(quote_spanned!{item.span() =>
369 {
370 match self{
371 Self::#ident(_summum_self) => #block ,
372 _ => panic!("`{}::{}` method must be called with corresponding inner type", #item_type_name, #new_method_name)
373 }
374 }
375 }.into()).unwrap()
376 } else {
377 block
378 };
379
380 new_items.push(ImplItem::Fn(new_item));
381 }
382
383 } else {
384
385 let match_arms = item_type.cases.iter().zip(variant_blocks).map(|(variant, block)| {
387 let ident = &variant.ident;
388 quote_spanned! {item.span() =>
389 Self::#ident(_summum_self) => #block
390 }
391 }).collect::<Vec<_>>();
392
393 item.block = parse(quote_spanned!{item.span() =>
394 {
395 match self{
396 #(#match_arms),*
397 }
398 }
399 }.into()).unwrap();
400 new_items.push(ImplItem::Fn(item));
401 }
402 } else {
403 new_items.push(item);
404 }
405 }
406 item_impl.items = new_items;
407
408 quote_spanned!{impl_span =>
409 #item_impl
410 }.into()
411 }
412}
413
414#[derive(Default)]
415struct SummumItems {
416 types: HashMap<String, SummumType>,
417 impls: Vec<SummumImpl>
418}
419
420impl Parse for SummumItems {
421 fn parse(input: ParseStream) -> Result<Self> {
422 let mut items = Self::default();
423
424 while !input.is_empty() {
425 let attrs = input.call(Attribute::parse_outer)?;
426
427 if input.peek(Token![impl]) {
428 let next_impl = SummumImpl::parse(input, attrs)?;
429 items.impls.push(next_impl);
430 } else {
431 let next_type = SummumType::parse(input, attrs)?;
432 items.types.insert(next_type.name.to_string(), next_type);
433 }
434 }
435
436 Ok(items)
437 }
438}
439
440#[proc_macro]
442pub fn summum(input: TokenStream) -> TokenStream {
443 let mut new_stream = TokenStream::new();
444 let mut items: SummumItems = parse_macro_input!(input as SummumItems);
445
446 for item in items.types.values() {
447 new_stream.extend(item.render());
448 }
449
450 for item_impl in items.impls.iter_mut() {
451 new_stream.extend(item_impl.render(&items.types));
452 }
453
454 new_stream
455}
456
457fn ident_from_type_full(item_type: &Type) -> Ident {
459 let item_ident = quote!{ #item_type }.to_string();
460 let item_ident = AsUpperCamelCase(item_ident).to_string();
461 Ident::new(&item_ident, item_type.span())
462}
463
464fn snake_name(base: &str, ident: &str) -> String {
465 format!("{base}_{}", AsSnakeCase(ident))
466}
467
468fn type_from_fields(fields: &Fields) -> &Type {
469 if let Fields::Unnamed(field) = fields {
470 &field.unnamed.first().unwrap().ty
471 } else {panic!()}
472}
473
474fn type_from_fields_mut(fields: &mut Fields) -> &mut Type {
475 if let Fields::Unnamed(field) = fields {
476 &mut field.unnamed.first_mut().unwrap().ty
477 } else {panic!()}
478}
479
480fn cannonicalize_type_path(item_type: &mut Type) {
482 if let Type::Path(type_path) = item_type {
483 if let Some(segment) = type_path.path.segments.first_mut() {
484 let ident_span = segment.ident.span();
485 if let PathArguments::AngleBracketed(args) = &mut segment.arguments {
486 if args.colon2_token.is_none() {
487 args.colon2_token = Some(Token);
488 }
489 }
490 }
491 }
492}
493
494fn detect_uncovered_type(generic_type_params: &[&TypeParam], item_type: &Type) -> bool {
497 match item_type {
498 Type::Path(type_path) => {
499 if let Some(type_ident) = type_path.path.get_ident() {
500 for generic_type_params in generic_type_params {
501 if generic_type_params.ident.to_string() == type_ident.to_string() {
502 return true;
503 }
504 }
505 }
506 false
507 }
508 Type::Reference(type_ref) => detect_uncovered_type(generic_type_params, &type_ref.elem),
509 _ => false
510 }
511}
512
513fn type_params_from_generics(generics: &Generics) -> Vec<&TypeParam> {
514 let mut results = vec![];
515 for generic_param in generics.params.iter() {
516 if let GenericParam::Type(type_param) = generic_param {
517 results.push(type_param);
518 }
519 }
520 results
521}
522
523struct TypeIdentParseHelper(Ident);
525
526impl Parse for TypeIdentParseHelper {
527 fn parse(input: ParseStream) -> Result<Self> {
528
529 let mut result = Err(Error::new(input.span(), format!("invalid type")));
530 while !input.is_empty() {
531 if input.peek(Ident) {
532 let ident = input.parse::<Ident>()?;
533 if result.is_err() {
534 result = Ok(Self(ident));
535 }
536 } else {
537 _ = input.parse::<TokenTree>()?;
538 }
539 }
540
541 result
542 }
543}
544
545fn ident_from_type_short(item_type: &Type) -> Result<Ident> {
547 let type_stream = quote!{ #item_type };
548 let ident: TypeIdentParseHelper = parse2(type_stream)?;
549 Ok(ident.0)
550}
551
552fn replace_idents(input: proc_macro2::TokenStream, map: &[(&str, &str)], ends_with_map: &[(&str, &dyn Fn(&str) -> String)]) -> proc_macro2::TokenStream {
554 let mut new_stream = proc_macro2::TokenStream::new();
555
556 for item in input.into_iter() {
557 match item {
558 TokenTree::Ident(ident) => {
559 let ident_string = ident.to_string();
560 if let Some(replacement_string) = map.iter().find_map(|(key, val)| {
561 if key == &ident_string {
562 Some(val.to_string())
563 } else {
564 None
565 }
566 }).or_else(|| {
567 ends_with_map.iter().find_map(|(ends_with_key, func)| {
568 if ident_string.ends_with(ends_with_key) {
569 Some(func(&ident_string[0..(ident_string.len() - ends_with_key.len())]))
570 } else {
571 None
572 }
573 })
574 }) {
575 let replacement_stream = parse_str::<proc_macro2::TokenStream>(&replacement_string).expect("Error rendering type back to tokens");
576 let replacement_stream: proc_macro2::TokenStream = replacement_stream.into_iter()
577 .map(|mut item| {item.set_span(ident.span()); item} ).collect();
578 new_stream.extend([replacement_stream]);
579 } else {
580 new_stream.extend([TokenTree::Ident(ident)]);
581 }
582 },
583 TokenTree::Group(group) => {
584 let new_group_stream = replace_idents(group.stream(), map, ends_with_map);
585 let mut new_group = Group::new(group.delimiter(), new_group_stream);
586 new_group.set_span(group.span());
587 new_stream.extend([TokenTree::Group(new_group)]);
588 },
589 _ => {new_stream.extend([item]);}
590 }
591 }
592
593 new_stream
594}
595
596fn handle_excludes(input: proc_macro2::TokenStream, branch_ident: &str) -> core::result::Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
598 let mut new_stream = proc_macro2::TokenStream::new();
599
600 let mut input_iter = input.into_iter().peekable();
605 while let Some(item) = input_iter.next() {
606 match item {
607 TokenTree::Ident(ident) => {
608 let ident_string = ident.to_string();
609 if ident_string == "summum_exclude" || ident_string == "summum_restrict" {
610 let is_exclude = ident_string == "summum_exclude";
611
612 parse_punct(input_iter.next(), '!')?;
613 let next_item = input_iter.next();
614 let next_span = next_item.span();
615 let macro_args = if let Some(TokenTree::Group(macro_args_group)) = next_item {
616 let args_group_stream = macro_args_group.stream();
617 let macro_args_punct: Punctuated::<Ident, Token![,]> = parse_quote!( #args_group_stream );
618 let macro_args: Vec<String> = macro_args_punct.into_iter().map(|ident| ident.to_string()).collect();
619 macro_args
620 } else {
621 return Err(quote_spanned! {next_span => compile_error!("Expecting tuple of variants"); }.into());
622 };
623 if parse_punct(input_iter.peek(), ';').is_ok() {
624 let _ = input_iter.next();
625 }
626
627 let branch_in_list = macro_args.iter().find(|arg| arg.as_str() == branch_ident).is_some();
628
629 if (is_exclude && branch_in_list) || (!is_exclude && !branch_in_list) {
630 let unreachable_message = &format!("internal error: encountered {ident_string} for {branch_ident}");
631 let panic_tokens = quote_spanned!{ident.span() =>
632 {
633 #new_stream
634 panic!(#unreachable_message);
635 }
637 };
638 return Ok(panic_tokens);
639 }
640 } else {
641 new_stream.extend([TokenTree::Ident(ident)]);
642 }
643 },
644 TokenTree::Group(group) => {
645 let new_group_stream = handle_excludes(group.stream(), branch_ident)?;
646 let mut new_group = Group::new(group.delimiter(), new_group_stream);
647 new_group.set_span(group.span());
648 new_stream.extend([TokenTree::Group(new_group)]);
649 },
650 _ => {new_stream.extend([item]);}
651 }
652 }
653 Ok(new_stream)
654}
655
656fn parse_punct<T: core::borrow::Borrow<TokenTree>>(item: Option<T>, the_char: char) -> core::result::Result<(), proc_macro2::TokenStream> {
657 let item_ref = item.as_ref().map(|i| i.borrow());
658 let span = item_ref.span();
659 if let Some(TokenTree::Punct(p)) = item_ref {
660 if p.as_char() == the_char {
661 return Ok(());
662 }
663 }
664 let err_string = format!("expecting {the_char}");
665 return Err(quote_spanned! {span => compile_error!(#err_string); }.into());
666}
667
668fn sig_contains_self_arg(sig: &Signature) -> bool {
669 if let Some(first_arg) = sig.inputs.first() {
670 if let FnArg::Receiver(_rcvr) = first_arg {
671 return true;
672 }
673 }
674 false
675}
676
#[cfg(feature = "generated_example")]
#[allow(missing_docs)]
pub mod generated_example {
    use crate::summum;
    summum! {
        // `Copy` cannot be derived: the `Vec(Vec<T>)` variant owns a heap buffer.
        #[derive(Debug, Clone, PartialEq)]
        enum GeneratedExample<'a, T> {
            Slice(&'a [T]),
            Vec(Vec<T>),
        }

        // The impl must name the summum type declared above; an unknown name
        // expands to "can't find definition for type in summum block".
        impl<'a, T> GeneratedExample<'a, T> {
            fn get(&self, idx: usize) -> Option<&T> {
                // Inside a summum method, `self` is rewritten to the inner value of
                // each variant, so this calls the slice's / Vec's own `get`.
                self.get(idx)
            }
        }
    }
}
705
706