valistr_proc_macro/
lib.rs

1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::Span;
3use quote::{quote, quote_spanned, TokenStreamExt};
4use syn::{
5    parse::{Parse, Parser},
6    parse_macro_input,
7    spanned::Spanned,
8    Field, Fields, FieldsNamed, Ident, ItemStruct,
9};
10
11mod utils;
12
/// Arguments accepted by the `#[valistr(...)]` attribute.
///
/// At the moment the only argument is the regex pattern, written as a single string literal.
// TODO: possible future options: `container_field`, `hook_fn`.
struct ValistrArgs {
    /// The regex pattern exactly as written in the attribute.
    regex: String,
}
19
20impl Parse for ValistrArgs {
21    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
22        let regex = input.parse::<syn::LitStr>()?.value();
23        Ok(ValistrArgs { regex })
24    }
25}
26
27/// This procedural macro creates immutable string wrapper types with values validated with regexes.
28#[proc_macro_attribute]
29pub fn valistr(attr: TokenStream1, item: TokenStream1) -> TokenStream1 {
30    let mut input = parse_macro_input!(item as ItemStruct);
31    let args = parse_macro_input!(attr as ValistrArgs);
32
33    if !input.fields.is_empty() {
34        return quote_spanned!(input.fields.span() => compile_error!("Only unit structs are supported");).into();
35    }
36
37    // get the path to the regex re-exported by valistr
38    let regex_path = match utils::get_regex_reexport_path() {
39        Ok(path) => path,
40        Err(err) => return err.into(),
41    };
42
43    // collect regex
44    let regex_str = utils::ensure_regex_anchors(&args.regex);
45    let regex_lit = syn::LitStr::new(&regex_str, Span::call_site());
46    let regex = regex::Regex::new(&regex_str).unwrap();
47
48    // collect named groups with simple identifiers
49    let named_groups = regex
50        .capture_names()
51        .enumerate()
52        .filter_map(|(index, name)| {
53            name.filter(|name| utils::is_simple_ident(name))
54                .map(|name| (index, name.to_string()))
55        })
56        .collect::<Vec<_>>();
57
58    // create the field `value: String` to store the value
59    let mut fields = FieldsNamed::parse.parse2(quote!({value: String})).unwrap();
60
61    // create the fields to store the capture groups, [`regex::Captures`] cannot be used directly here
62    for (_, group_name) in &named_groups {
63        let group_name_ident = Ident::new(group_name, Span::call_site());
64        fields.named.push(
65            Field::parse_named
66                .parse2(quote!(#group_name_ident: Option<(usize, usize)>))
67                .unwrap(),
68        );
69    }
70
71    // set the new fields
72    input.fields = Fields::Named(fields);
73
74    // create the `validator` method
75    let validator = quote!(
76        #[doc = "The regex to validate the value."]
77        pub fn validator() -> &'static #regex_path::Regex {
78            static VALIDATOR: std::sync::OnceLock<#regex_path::Regex> = std::sync::OnceLock::new();
79            VALIDATOR.get_or_init(|| #regex_path::Regex::new(#regex_lit).unwrap())
80        }
81    );
82
83    // create the getter methods
84    let mut getter_methods = quote!();
85    for (_, group_name) in &named_groups {
86        let group_name_ident = Ident::new(group_name, Span::call_site());
87        let get_method_ident = Ident::new(&format!("get_{group_name}"), Span::call_site());
88        let get_method = quote!(
89            #[doc = concat!("Get the value of the capture group `", stringify!(#group_name_ident), "`.")]
90            pub fn #get_method_ident(&self) -> Option<&str> {
91                self.#group_name_ident.map(|(start, end)| &self.value[start..end])
92            }
93        );
94        getter_methods.append_all(get_method);
95    }
96
97    // create the `new` method
98    let mut new_fn_capture_group_mappers = quote!();
99    for (index, group_name) in &named_groups {
100        let group_name_ident = Ident::new(group_name, Span::call_site());
101        let capture_group_mapper = quote!(
102            let #group_name_ident = captures.get(#index).map(|m| (m.start(), m.end()));
103        );
104        new_fn_capture_group_mappers.append_all(capture_group_mapper);
105    }
106
107    let named_group_names: Vec<_> = named_groups
108        .iter()
109        .map(|(_, name)| Ident::new(name, Span::call_site()))
110        .collect();
111    let new = quote!(
112        #[doc = "Try to create a new instance of the struct from a value. Returns `None` if the value fails to validate with the regex."]
113        pub fn new(value: impl Into<String>) -> Option<Self> {
114            let validator = Self::validator();
115            let value = value.into();
116
117            if let Some(captures) = validator.captures(&value) {
118                #new_fn_capture_group_mappers
119
120                Some(Self {
121                    value,
122                    #(#named_group_names,)*
123                })
124            } else {
125                None
126            }
127        }
128    );
129
130    let struct_name = &input.ident;
131
132    quote!(
133        #input
134
135        impl #struct_name {
136            #validator
137            #getter_methods
138            #new
139        }
140
141        #[automatically_derived]
142        impl std::ops::Deref for #struct_name {
143            type Target = String;
144
145            fn deref(&self) -> &Self::Target {
146                &self.value
147            }
148        }
149
150        #[automatically_derived]
151        impl std::ops::DerefMut for #struct_name {
152            fn deref_mut(&mut self) -> &mut Self::Target {
153                &mut self.value
154            }
155        }
156
157        #[automatically_derived]
158        impl std::convert::TryFrom<&str> for #struct_name {
159            type Error = ();
160
161            fn try_from(value: &str) -> Result<Self, Self::Error> {
162                Self::new(value).ok_or(())
163            }
164        }
165
166        #[automatically_derived]
167        impl std::convert::TryFrom<String> for #struct_name {
168            type Error = ();
169
170            fn try_from(value: String) -> Result<Self, Self::Error> {
171                Self::new(value).ok_or(())
172            }
173        }
174
175        #[automatically_derived]
176        impl std::fmt::Display for #struct_name {
177            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
178                std::fmt::Display::fmt(&self.value, f)
179            }
180        }
181
182        #[automatically_derived]
183        impl std::fmt::Debug for #struct_name {
184            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
185                std::fmt::Debug::fmt(&self.value, f)
186            }
187        }
188    )
189    .into()
190}