varisat_internal_macros/
lib.rs

1//! Internal macros for the Varisat SAT solver.
2#![recursion_limit = "128"]
3use std::fmt::Write;
4
5use proc_macro2::TokenStream;
6use quote::quote;
7use syn::{
8    parse_quote, punctuated::Punctuated, Attribute, Fields, Ident, Lit, LitStr, Meta,
9    MetaNameValue, Token,
10};
11use synstructure::decl_derive;
12
13/// Get the doc comment as LitStr from the attributes
14fn doc_from_attrs(attrs: &[Attribute]) -> Vec<LitStr> {
15    let mut lines = vec![];
16
17    for attr in attrs.iter() {
18        if let Ok(Meta::NameValue(MetaNameValue {
19            path,
20            lit: Lit::Str(doc_str),
21            ..
22        })) = attr.parse_meta()
23        {
24            if let Some(ident) = path.get_ident() {
25                if ident == "doc" {
26                    lines.push(doc_str);
27                }
28            }
29        }
30    }
31
32    lines
33}
34
35/// Find a field inside the doc comment
36fn get_doc_field(name: &str, attrs: &[Attribute]) -> Option<LitStr> {
37    let re = regex::Regex::new(&format!(r"\[{}: (.+?)\](  |$)", regex::escape(name))).unwrap();
38
39    for doc_str in doc_from_attrs(attrs) {
40        if let Some(expr_str) = re.captures(&doc_str.value()) {
41            let expr_str = expr_str.get(1).unwrap().as_str();
42            let expr_str = LitStr::new(expr_str, doc_str.span());
43            return Some(expr_str);
44        }
45    }
46
47    None
48}
49
50/// Derives a default instance from the documentation.
51fn derive_doc_default(s: synstructure::Structure) -> TokenStream {
52    let variant = match s.variants() {
53        [variant] => variant,
54        _ => panic!("DocDefault requires a struct"),
55    };
56
57    let body = variant.construct(|field, _| {
58        get_doc_field("default", &field.attrs)
59            .map(|expr_str| {
60                expr_str
61                    .parse::<TokenStream>()
62                    .expect("error parsing default expression")
63            })
64            .unwrap_or_else(|| parse_quote!(Default::default()))
65    });
66
67    s.gen_impl(quote! {
68        gen impl Default for @Self {
69            fn default() -> Self {
70                #body
71            }
72        }
73    })
74}
75
76decl_derive!([DocDefault] => derive_doc_default);
77
78/// Derives an update struct and method for a config struct.
79fn derive_config_update(s: synstructure::Structure) -> TokenStream {
80    let variant = match s.variants() {
81        [variant] => variant,
82        _ => panic!("ConfigUpdate requires a struct"),
83    };
84
85    let fields = match variant.ast().fields {
86        Fields::Named(fields_named) => &fields_named.named,
87        _ => panic!("ConfigUpdate requires named fields"),
88    };
89
90    assert!(
91        s.referenced_ty_params().is_empty(),
92        "ConfigUpdate doesn't support type parameters"
93    );
94
95    let ident = &s.ast().ident;
96    let update_struct_ident = Ident::new(&format!("{}Update", ident), ident.span());
97
98    let vis = &s.ast().vis;
99
100    let update_struct_body = fields
101        .iter()
102        .map(|field| {
103            let ty = &field.ty;
104            let mut field = field.clone();
105            field.ty = parse_quote!(Option<#ty>);
106            field
107        })
108        .collect::<Punctuated<_, Token![,]>>();
109
110    let check_ranges = fields
111        .iter()
112        .map(|field| {
113            if let Some(range) = get_doc_field("range", &field.attrs) {
114                // TODO use toml instead of fmt::Debug for errors?
115                let ident = &field.ident;
116                let error_msg = format!(
117                    "{} must be in range {} but was set to {{:?}}",
118                    quote!(#ident),
119                    range.value()
120                );
121                let range = range
122                    .parse::<TokenStream>()
123                    .expect("error parsing range expression");
124                quote! {
125                    if let Some(value) = &self.#ident {
126                        anyhow::ensure!((#range).contains(value), #error_msg, value);
127                    }
128                }
129            } else {
130                quote!()
131            }
132        })
133        .collect::<TokenStream>();
134
135    let apply_updates = fields
136        .iter()
137        .map(|field| {
138            let ident = &field.ident;
139            quote! {
140                if let Some(value) = &self.#ident {
141                    config.#ident = value.clone();
142                }
143            }
144        })
145        .collect::<TokenStream>();
146
147    let merge_updates = fields
148        .iter()
149        .map(|field| {
150            let ident = &field.ident;
151            quote! {
152                if let Some(value) = config_update.#ident {
153                    self.#ident = Some(value);
154                }
155            }
156        })
157        .collect::<TokenStream>();
158
159    let mut help_str = String::new();
160
161    for field in fields.iter() {
162        let ident = &field.ident;
163        writeln!(&mut help_str, "{}:", quote!(#ident)).unwrap();
164        for line in doc_from_attrs(&field.attrs).iter() {
165            if line.value().is_empty() {
166                writeln!(&mut help_str).unwrap();
167            } else {
168                writeln!(&mut help_str, "   {}", line.value()).unwrap();
169            }
170        }
171        writeln!(&mut help_str).unwrap();
172    }
173
174    let doc = format!("Updates configuration values of [`{}`].", ident);
175
176    quote! {
177        #[doc = #doc]
178        #[derive(Default, serde::Serialize, serde::Deserialize)]
179        #[serde(deny_unknown_fields)]
180        #vis struct #update_struct_ident {
181            #update_struct_body
182        }
183
184        impl #ident {
185            /// Return a string describing all supported configuration options.
186            pub fn help() -> &'static str {
187                #help_str
188            }
189        }
190
191        impl #update_struct_ident {
192            /// Create an empty config update.
193            pub fn new() -> #update_struct_ident {
194                #update_struct_ident::default()
195            }
196
197            /// Apply the configuration update.
198            ///
199            /// If an error occurs, the configuration is not changed.
200            pub fn apply(&self, config: &mut #ident) -> Result<(), anyhow::Error> {
201                #check_ranges
202                #apply_updates
203                Ok(())
204            }
205
206            /// Merge two configuration updates.
207            ///
208            /// Add the given update, overwriting values of the receiving update.
209            pub fn merge(&mut self, config_update: #update_struct_ident) {
210                #merge_updates
211            }
212        }
213    }
214}
215
216decl_derive!([ConfigUpdate] => derive_config_update);