Skip to main content

typhoon_program_id_macro/
lib.rs

1use {
2    cargo_manifest::Manifest,
3    heck::ToUpperCamelCase,
4    proc_macro::TokenStream,
5    proc_macro2::Span,
6    quote::{format_ident, quote, ToTokens},
7    std::env::var,
8    syn::{parse::Parse, parse_macro_input, Ident, LitStr},
9};
10
11#[proc_macro]
12pub fn program_id(item: TokenStream) -> TokenStream {
13    parse_macro_input!(item as ProgramId)
14        .to_token_stream()
15        .into()
16}
17
18struct ProgramId {
19    pub name: Ident,
20    pub id: String,
21}
22
23impl Parse for ProgramId {
24    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
25        let id: LitStr = input.parse()?;
26        let name = generate_name()?;
27
28        Ok(ProgramId {
29            id: id.value(),
30            name,
31        })
32    }
33}
34
35impl ToTokens for ProgramId {
36    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
37        let id = &self.id;
38        let name = &self.name;
39
40        quote! {
41            declare_id!(#id);
42
43            pub struct #name;
44
45            impl CheckProgramId for #name {
46                #[inline(always)]
47                fn address_eq(program_id: &Address) -> bool {
48                    address_eq(program_id, &crate::ID)
49                }
50            }
51        }
52        .to_tokens(tokens);
53    }
54}
55
56fn get_cargo_toml() -> syn::Result<String> {
57    let crate_dir = var("CARGO_MANIFEST_DIR")
58        .map_err(|_| syn::Error::new(Span::call_site(), "Not in valid rust project."))?;
59
60    Ok(format!("{crate_dir}/Cargo.toml"))
61}
62
63fn generate_name() -> syn::Result<Ident> {
64    let cargo_toml = get_cargo_toml()?;
65    let manifest = Manifest::from_path(cargo_toml)
66        .map_err(|_| syn::Error::new(Span::call_site(), "Invalid Cargo.toml"))?;
67    let package_section = manifest.package.ok_or(syn::Error::new(
68        Span::call_site(),
69        "Invalid package section",
70    ))?;
71
72    Ok(format_ident!(
73        "{}Program",
74        package_section.name.to_upper_camel_case()
75    ))
76}