valistr_proc_macro/
lib.rs

1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::Span;
3use quote::{quote, quote_spanned, TokenStreamExt};
4use syn::{
5    parse::{Parse, Parser},
6    parse_macro_input,
7    spanned::Spanned,
8    Field, Fields, FieldsNamed, Ident, ItemStruct,
9};
10
11mod utils;
12
13struct ValistrArgs {
14    regex: String,
15    // for future features
16    // container_field: Option<String>,
17    // hook_fn: Option<String>,
18}
19
20impl Parse for ValistrArgs {
21    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
22        let regex = input.parse::<syn::LitStr>()?.value();
23        Ok(ValistrArgs { regex })
24    }
25}
26
27/// This procedural macro creates immutable string wrapper types with values validated with regexes.
28#[proc_macro_attribute]
29pub fn valistr(attr: TokenStream1, item: TokenStream1) -> TokenStream1 {
30    let mut input = parse_macro_input!(item as ItemStruct);
31    let args = parse_macro_input!(attr as ValistrArgs);
32
33    if input.fields.len() > 0 {
34        return quote_spanned!(input.fields.span() => compile_error!("Only unit structs are supported");).into();
35    }
36
37    // collect regex
38    let regex_str = utils::ensure_regex_anchors(&args.regex);
39    let regex_lit = syn::LitStr::new(&regex_str, Span::call_site());
40    let regex = regex::Regex::new(&regex_str).unwrap();
41
42    // collect named groups with simple identifiers
43    let named_groups = regex
44        .capture_names()
45        .enumerate()
46        .filter_map(|(index, name)| {
47            name.filter(|name| utils::is_simple_ident(*name))
48                .map(|name| (index, name.to_string()))
49        })
50        .collect::<Vec<_>>();
51
52    // create the field `value: String` to store the value
53    let mut fields = FieldsNamed::parse.parse2(quote!({value: String})).unwrap();
54
55    // create the fields to store the capture groups, [`regex::Captures`] cannot be used directly here
56    for (_, group_name) in &named_groups {
57        let group_name_ident = Ident::new(group_name, Span::call_site());
58        fields.named.push(
59            Field::parse_named
60                .parse2(quote!(#group_name_ident: Option<(usize, usize)>))
61                .unwrap(),
62        );
63    }
64
65    // set the new fields
66    input.fields = Fields::Named(fields);
67
68    // create the `validator` method
69    let validator = quote!(
70        #[doc = "The regex to validate the value."]
71        pub fn validator() -> &'static regex::Regex {
72            static VALIDATOR: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
73            VALIDATOR.get_or_init(|| regex::Regex::new(#regex_lit).unwrap())
74        }
75    );
76
77    // create the getter methods
78    let mut getter_methods = quote!();
79    for (_, group_name) in &named_groups {
80        let group_name_ident = Ident::new(group_name, Span::call_site());
81        let get_method_ident = Ident::new(&format!("get_{}", group_name), Span::call_site());
82        let get_method = quote!(
83            #[doc = concat!("Get the value of the capture group `", stringify!(#group_name_ident), "`.")]
84            pub fn #get_method_ident(&self) -> Option<&str> {
85                self.#group_name_ident.map(|(start, end)| &self.value[start..end])
86            }
87        );
88        getter_methods.append_all(get_method);
89    }
90
91    // create the `new` method
92    let mut new_fn_capture_group_mappers = quote!();
93    for (index, group_name) in &named_groups {
94        let group_name_ident = Ident::new(group_name, Span::call_site());
95        let capture_group_mapper = quote!(
96            let #group_name_ident = captures.get(#index).map(|m| (m.start(), m.end()));
97        );
98        new_fn_capture_group_mappers.append_all(capture_group_mapper);
99    }
100
101    let named_group_names: Vec<_> = named_groups
102        .iter()
103        .map(|(_, name)| Ident::new(name, Span::call_site()))
104        .collect();
105    let new = quote!(
106        #[doc = "Try to create a new instance of the struct from a value. Returns `None` if the value fails to validate with the regex."]
107        pub fn new(value: impl Into<String>) -> Option<Self> {
108            let validator = Self::validator();
109            let value = value.into();
110
111            if let Some(captures) = validator.captures(&value) {
112                #new_fn_capture_group_mappers
113
114                Some(Self {
115                    value,
116                    #(#named_group_names,)*
117                })
118            } else {
119                None
120            }
121        }
122    );
123
124    let struct_name = &input.ident;
125
126    quote!(
127        #input
128
129        impl #struct_name {
130            #validator
131            #getter_methods
132            #new
133        }
134
135        #[automatically_derived]
136        impl std::ops::Deref for #struct_name {
137            type Target = String;
138
139            fn deref(&self) -> &Self::Target {
140                &self.value
141            }
142        }
143
144        #[automatically_derived]
145        impl std::ops::DerefMut for #struct_name {
146            fn deref_mut(&mut self) -> &mut Self::Target {
147                &mut self.value
148            }
149        }
150
151        #[automatically_derived]
152        impl std::convert::TryFrom<&str> for #struct_name {
153            type Error = ();
154
155            fn try_from(value: &str) -> Result<Self, Self::Error> {
156                Self::new(value).ok_or(())
157            }
158        }
159
160        #[automatically_derived]
161        impl std::convert::TryFrom<String> for #struct_name {
162            type Error = ();
163
164            fn try_from(value: String) -> Result<Self, Self::Error> {
165                Self::new(value).ok_or(())
166            }
167        }
168
169        #[automatically_derived]
170        impl std::fmt::Display for #struct_name {
171            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
172                std::fmt::Display::fmt(&self.value, f)
173            }
174        }
175
176        #[automatically_derived]
177        impl std::fmt::Debug for #struct_name {
178            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
179                std::fmt::Debug::fmt(&self.value, f)
180            }
181        }
182    )
183    .into()
184}