1use parse::{Attrs, MacroOpts};
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{parse_macro_input, ItemFn, Signature};
7
8mod parse;
9
/// Emitted when the entrypoint is not an `async fn`.
const NO_SYNC_ERR: &str = "The vexide entrypoint must be marked `async`.";
/// Emitted when the entrypoint is declared `unsafe`.
const NO_UNSAFE_ERR: &str = "The vexide entrypoint must not be marked `unsafe`.";
/// Emitted when the entrypoint does not take exactly one parameter. The path matches the type
/// the generated `_start` passes in (`::vexide::devices::peripherals::Peripherals::take()`).
const WRONG_ARGS_ERR: &str = "The vexide entrypoint must take a single parameter of type `vexide::devices::peripherals::Peripherals`.";
13
14fn verify_function_sig(sig: &Signature) -> Result<(), syn::Error> {
15 let mut error = None;
16
17 if sig.asyncness.is_none() {
18 let message = syn::Error::new_spanned(sig, NO_SYNC_ERR);
19 error.replace(message);
20 }
21 if sig.unsafety.is_some() {
22 let message = syn::Error::new_spanned(sig, NO_UNSAFE_ERR);
23 match error {
24 Some(ref mut e) => e.combine(message),
25 None => {
26 error.replace(message);
27 }
28 }
29 }
30 if sig.inputs.len() != 1 {
31 let message = syn::Error::new_spanned(sig, WRONG_ARGS_ERR);
32 match error {
33 Some(ref mut e) => e.combine(message),
34 None => {
35 error.replace(message);
36 }
37 }
38 }
39
40 match error {
41 Some(e) => Err(e),
42 None => Ok(()),
43 }
44}
45
46fn make_code_sig(opts: MacroOpts) -> proc_macro2::TokenStream {
47 let sig = if let Some(code_sig) = opts.code_sig {
48 quote! { #code_sig }
49 } else {
50 quote! { ::vexide::startup::CodeSignature::new(
51 ::vexide::startup::ProgramType::User,
52 ::vexide::startup::ProgramOwner::Partner,
53 ::vexide::startup::ProgramFlags::empty(),
54 ) }
55 };
56
57 quote! {
58 #[link_section = ".code_signature"]
59 #[used] static CODE_SIGNATURE: ::vexide::startup::CodeSignature = #sig;
61 }
62}
63
64fn make_entrypoint(inner: &ItemFn, opts: MacroOpts) -> proc_macro2::TokenStream {
65 match verify_function_sig(&inner.sig) {
66 Ok(()) => {}
67 Err(e) => return e.to_compile_error(),
68 }
69 let inner_ident = inner.sig.ident.clone();
70 let ret_type = match &inner.sig.output {
71 syn::ReturnType::Default => quote! { () },
72 syn::ReturnType::Type(_, ty) => quote! { #ty },
73 };
74
75 let banner_theme = if let Some(theme) = opts.banner_theme {
76 quote! { #theme }
77 } else {
78 quote! { ::vexide::startup::banner::themes::THEME_DEFAULT }
79 };
80
81 let banner_enabled = if opts.banner_enabled {
82 quote! { true }
83 } else {
84 quote! { false }
85 };
86
87 quote! {
88 #[no_mangle]
89 unsafe extern "C" fn _start() -> ! {
90 ::vexide::startup::startup::<#banner_enabled>(#banner_theme);
91
92 #inner
93 let termination: #ret_type = ::vexide::runtime::block_on(
94 #inner_ident(::vexide::devices::peripherals::Peripherals::take().unwrap())
95 );
96 ::vexide::program::Termination::report(termination);
97 ::vexide::program::exit();
98 }
99 }
100}
101
102#[proc_macro_attribute]
177pub fn main(attrs: TokenStream, item: TokenStream) -> TokenStream {
178 let item = parse_macro_input!(item as ItemFn);
179 let opts = MacroOpts::from(parse_macro_input!(attrs as Attrs));
180
181 let entrypoint = make_entrypoint(&item, opts.clone());
182 let code_signature = make_code_sig(opts);
183
184 quote! {
185 const _: () = {
186 #code_signature
187
188 #entrypoint
189 };
190 }
191 .into()
192}
193
// Unit tests for the token generation helpers. They inspect the stringified `TokenStream`
// output directly, so no vexide runtime is needed.
#[cfg(test)]
mod test {
    use syn::Ident;

    use super::*;

    #[test]
    fn wraps_main_fn() {
        let source = quote! {
            async fn main(_peripherals: Peripherals) {
                println!("Hello, world!");
            }
        };

        let input = syn::parse2::<ItemFn>(source.clone()).unwrap();
        let output = make_entrypoint(&input, MacroOpts::default());

        assert_eq!(
            output.to_string(),
            quote! {
                #[no_mangle]
                unsafe extern "C" fn _start() -> ! {
                    ::vexide::startup::startup::<true>(::vexide::startup::banner::themes::THEME_DEFAULT);

                    #source

                    let termination: () = ::vexide::runtime::block_on(
                        main(::vexide::devices::peripherals::Peripherals::take().unwrap())
                    );

                    ::vexide::program::Termination::report(termination);
                    ::vexide::program::exit();
                }
            }
            .to_string()
        );
    }

    #[test]
    fn toggles_banner_using_parsed_opts() {
        let source = quote! {
            async fn main(_peripherals: Peripherals) {
                println!("Hello, world!");
            }
        };
        let input = syn::parse2::<ItemFn>(source.clone()).unwrap();
        let entrypoint = make_entrypoint(
            &input,
            MacroOpts {
                banner_enabled: false,
                banner_theme: None,
                code_sig: None,
            },
        );
        assert!(entrypoint.to_string().contains("false"));
        assert!(!entrypoint.to_string().contains("true"));

        let entrypoint = make_entrypoint(
            &input,
            MacroOpts {
                banner_enabled: true,
                banner_theme: None,
                code_sig: None,
            },
        );
        assert!(entrypoint.to_string().contains("true"));
        assert!(!entrypoint.to_string().contains("false"));
    }

    #[test]
    fn uses_custom_code_sig_from_parsed_opts() {
        let code_sig = make_code_sig(MacroOpts {
            banner_enabled: false,
            banner_theme: None,
            code_sig: Some(Ident::new(
                "__custom_code_sig_ident__",
                proc_macro2::Span::call_site(),
            )),
        })
        .to_string();

        assert!(code_sig.contains(
            "static CODE_SIGNATURE : :: vexide :: startup :: CodeSignature = __custom_code_sig_ident__ ;"
        ));
    }

    #[test]
    fn requires_async() {
        let source = quote! {
            fn main(_peripherals: Peripherals) {
                println!("Hello, world!");
            }
        };

        let input = syn::parse2::<ItemFn>(source.clone()).unwrap();
        let output = make_entrypoint(&input, MacroOpts::default());

        assert!(output.to_string().contains(NO_SYNC_ERR));
    }

    #[test]
    fn requires_safe() {
        let source = quote! {
            async unsafe fn main(_peripherals: Peripherals) {
                println!("Hello, world!");
            }
        };

        let input = syn::parse2::<ItemFn>(source.clone()).unwrap();
        let output = make_entrypoint(&input, MacroOpts::default());

        assert!(output.to_string().contains(NO_UNSAFE_ERR));
    }

    #[test]
    fn disallows_0_args() {
        let source = quote! {
            async fn main() {
                println!("Hello, world!");
            }
        };

        let input = syn::parse2::<ItemFn>(source.clone()).unwrap();
        let output = make_entrypoint(&input, MacroOpts::default());

        assert!(output.to_string().contains(WRONG_ARGS_ERR));
    }

    #[test]
    fn disallows_2_args() {
        let source = quote! {
            async fn main(_peripherals: Peripherals, _other: Peripherals) {
                println!("Hello, world!");
            }
        };

        let input = syn::parse2::<ItemFn>(source.clone()).unwrap();
        let output = make_entrypoint(&input, MacroOpts::default());

        assert!(output.to_string().contains(WRONG_ARGS_ERR));
    }

    // Every signature violation must be reported in one pass, not just the first one found.
    #[test]
    fn reports_all_signature_errors_together() {
        let source = quote! {
            unsafe fn main() {}
        };

        let input = syn::parse2::<ItemFn>(source).unwrap();
        let output = make_entrypoint(&input, MacroOpts::default()).to_string();

        assert!(output.contains(NO_SYNC_ERR));
        assert!(output.contains(NO_UNSAFE_ERR));
        assert!(output.contains(WRONG_ARGS_ERR));
    }
}