Skip to main content

struct_record/
lib.rs

1use std::str::FromStr;
2
3use convert_case::Casing;
4use proc_macro2::{TokenStream, TokenTree};
5use quote::quote;
6use syn::spanned::Spanned;
7
8struct RecordParams(syn::Type, syn::Ident, Option<syn::LitStr>);
9impl syn::parse::Parse for RecordParams {
10  fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
11    let value_type = input.parse()?;
12    let _ = input.parse::<syn::Token![,]>();
13    let ident = input.parse()?;
14    let _ = input.parse::<syn::Token![,]>();
15    let header = input.parse().ok();
16    Ok(RecordParams(value_type, ident, header))
17  }
18}
19
20const GENERIC_ERROR_MSG: &str = "Expected to be used with an enum";
21
22#[proc_macro_attribute]
23pub fn record(
24  params: proc_macro::TokenStream,
25  ast: proc_macro::TokenStream,
26) -> proc_macro::TokenStream {
27  let RecordParams(value_type, ident, header) = syn::parse2(params.into()).expect("Invalid params");
28
29  let ast_tokens: TokenStream = ast.clone().into();
30
31  let ast_iter = ast_tokens.clone().into_iter();
32
33  let mut iter_after_enum = ast_iter.skip_while(|next| {
34    if let TokenTree::Ident(ident) = next
35      && ident.to_string() != "enum"
36    {
37      true
38    } else {
39      false
40    }
41  });
42
43  let _ = iter_after_enum.next().expect(GENERIC_ERROR_MSG);
44  let _ = iter_after_enum.next().expect(GENERIC_ERROR_MSG);
45  let TokenTree::Group(enum_props_group) = iter_after_enum.next().expect(GENERIC_ERROR_MSG) else {
46    panic!("{GENERIC_ERROR_MSG}")
47  };
48
49  let struct_props: Vec<TokenStream> = enum_props_group
50    .stream()
51    .into_iter()
52    .filter_map(|token| {
53      if let TokenTree::Ident(ident) = token {
54        let ident_str = ident.to_string();
55        let snake = ident_str.to_case(convert_case::Case::Snake);
56        let snake_ident = syn::Ident::new(&snake, ident_str.span());
57
58        Some(quote! {
59          pub #snake_ident: #value_type ,
60        })
61      } else {
62        None
63      }
64    })
65    .collect();
66
67  let header_token = header
68    .map(|token| TokenStream::from_str(&token.value()).expect(GENERIC_ERROR_MSG))
69    .unwrap_or(TokenStream::new());
70
71  let output = quote! {
72    #ast_tokens
73    #header_token
74    struct #ident {
75      #(#struct_props)*
76    }
77  };
78
79  // panic!("{output}");
80
81  output.into()
82}