1#![feature(proc_macro_diagnostic, proc_macro_span, proc_macro_quote)]
2extern crate proc_macro;
3
4use proc_macro::{quote, Delimiter, Diagnostic, Level, TokenStream, TokenTree};
5
6#[proc_macro_attribute]
7pub fn main(args: TokenStream, item: TokenStream) -> TokenStream {
8 if !args.is_empty() {
9 let start = args.clone().into_iter().next().unwrap().span();
10 let end = args.clone().into_iter().last().unwrap().span();
11 let span = start.join(end).unwrap();
12 Diagnostic::spanned(
13 vec![span],
14 Level::Error,
15 "Attribute macro veneer_macros::main does not accept any arguments",
16 )
17 .emit()
18 }
19
20 let signature = item
21 .clone()
22 .into_iter()
23 .take_while(|t| {
24 if let TokenTree::Group(group) = t {
25 group.delimiter() != Delimiter::Brace
26 } else {
27 true
28 }
29 })
30 .collect::<Vec<_>>();
31
32 let start = item.clone().into_iter().next().unwrap().span();
33 let end = item.clone().into_iter().last().unwrap().span();
34 let span = start.join(end).unwrap();
35 let not_a_fn = Diagnostic::spanned(
36 vec![span],
37 Level::Error,
38 "Attribute macro veneer_macros::main may only be applied to functions which take no arguments",
39 );
40
41 let name = match (signature.get(0), signature.get(1), signature.get(2)) {
42 (Some(TokenTree::Ident(f)), Some(TokenTree::Ident(name)), Some(TokenTree::Group(args))) => {
43 if f.to_string() == "fn" && args.delimiter() == Delimiter::Parenthesis {
44 name
45 } else {
46 not_a_fn.emit();
47 return item;
48 }
49 }
50 _ => {
51 not_a_fn.emit();
52 return item;
53 }
54 };
55
56 let name = TokenTree::from(name.clone());
57
58 let header = if signature.len() == 3 {
59 quote! {
60 #[no_mangle]
61 unsafe extern "C" fn __veneer_main() {
62 $name();
63 veneer::syscalls::exit(0);
64 }
65 }
66 } else {
67 quote! {
68 #[no_mangle]
69 unsafe extern "C" fn __veneer_main() {
70 let exit_code = match $name() {
71 Ok(()) => 0,
72 Err(e) => {
73 veneer::eprintln!("Error: {}", e);
74 1
75 },
76 };
77 veneer::syscalls::exit(exit_code);
78 }
79 }
80 };
81 header.into_iter().chain(item.into_iter()).collect()
82}