solverforge_macros/
lib.rs

1//! Macros for SolverForge domain models.
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::Parser;
6use syn::{parse_macro_input, Attribute, DeriveInput, Expr, ItemStruct, Lit, Meta};
7
8mod planning_entity;
9mod planning_solution;
10mod problem_fact;
11
12/// Checks if attribute stream contains the `serde` flag.
13fn has_serde_flag(attr: TokenStream) -> bool {
14    if attr.is_empty() {
15        return false;
16    }
17    let parser = syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated;
18    if let Ok(nested) = parser.parse(attr) {
19        for meta in nested {
20            if let Meta::Path(path) = meta {
21                if path.is_ident("serde") {
22                    return true;
23                }
24            }
25        }
26    }
27    false
28}
29
30#[proc_macro_attribute]
31pub fn planning_entity(attr: TokenStream, item: TokenStream) -> TokenStream {
32    let has_serde = has_serde_flag(attr);
33    let input = parse_macro_input!(item as ItemStruct);
34    let name = &input.ident;
35    let vis = &input.vis;
36    let generics = &input.generics;
37    let attrs: Vec<_> = input.attrs.iter().collect();
38    let fields = &input.fields;
39
40    let serde_derives = if has_serde {
41        quote! { ::serde::Serialize, ::serde::Deserialize, }
42    } else {
43        quote! {}
44    };
45
46    let expanded = quote! {
47        #[derive(Clone, Debug, PartialEq, Eq, Hash, #serde_derives ::solverforge::PlanningEntityImpl)]
48        #(#attrs)*
49        #vis struct #name #generics #fields
50    };
51    expanded.into()
52}
53
54/// Parses planning_solution attribute flags: serde, constraints = "path".
55fn parse_solution_flags(attr: TokenStream) -> (bool, Option<String>) {
56    let mut has_serde = false;
57    let mut constraints_path = None;
58
59    if attr.is_empty() {
60        return (has_serde, constraints_path);
61    }
62
63    let parser = syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated;
64    if let Ok(nested) = parser.parse(attr) {
65        for meta in nested {
66            match meta {
67                Meta::Path(path) if path.is_ident("serde") => {
68                    has_serde = true;
69                }
70                Meta::NameValue(nv) if nv.path.is_ident("constraints") => {
71                    if let Expr::Lit(expr_lit) = &nv.value {
72                        if let Lit::Str(lit_str) = &expr_lit.lit {
73                            constraints_path = Some(lit_str.value());
74                        }
75                    }
76                }
77                _ => {}
78            }
79        }
80    }
81
82    (has_serde, constraints_path)
83}
84
85#[proc_macro_attribute]
86pub fn planning_solution(attr: TokenStream, item: TokenStream) -> TokenStream {
87    let (has_serde, constraints_path) = parse_solution_flags(attr);
88    let input = parse_macro_input!(item as ItemStruct);
89    let name = &input.ident;
90    let vis = &input.vis;
91    let generics = &input.generics;
92    let attrs: Vec<_> = input.attrs.iter().collect();
93    let fields = &input.fields;
94
95    let serde_derives = if has_serde {
96        quote! { ::serde::Serialize, ::serde::Deserialize, }
97    } else {
98        quote! {}
99    };
100
101    let constraints_attr =
102        constraints_path.map(|p| quote! { #[solverforge_constraints_path = #p] });
103
104    let expanded = quote! {
105        #[derive(Clone, Debug, #serde_derives ::solverforge::PlanningSolutionImpl)]
106        #constraints_attr
107        #(#attrs)*
108        #vis struct #name #generics #fields
109    };
110    expanded.into()
111}
112
113#[proc_macro_attribute]
114pub fn problem_fact(attr: TokenStream, item: TokenStream) -> TokenStream {
115    let has_serde = has_serde_flag(attr);
116    let input = parse_macro_input!(item as ItemStruct);
117    let name = &input.ident;
118    let vis = &input.vis;
119    let generics = &input.generics;
120    let attrs: Vec<_> = input.attrs.iter().collect();
121    let fields = &input.fields;
122
123    let serde_derives = if has_serde {
124        quote! { ::serde::Serialize, ::serde::Deserialize, }
125    } else {
126        quote! {}
127    };
128
129    let expanded = quote! {
130        #[derive(Clone, Debug, PartialEq, Eq, #serde_derives ::solverforge::ProblemFactImpl)]
131        #(#attrs)*
132        #vis struct #name #generics #fields
133    };
134    expanded.into()
135}
136
137#[proc_macro_derive(
138    PlanningEntityImpl,
139    attributes(
140        planning_id,
141        planning_variable,
142        planning_list_variable,
143        planning_pin,
144        inverse_relation_shadow_variable,
145        previous_element_shadow_variable,
146        next_element_shadow_variable,
147        cascading_update_shadow_variable
148    )
149)]
150pub fn derive_planning_entity(input: TokenStream) -> TokenStream {
151    let input = parse_macro_input!(input as DeriveInput);
152    planning_entity::expand_derive(input)
153        .unwrap_or_else(|e| e.to_compile_error())
154        .into()
155}
156
157#[proc_macro_derive(
158    PlanningSolutionImpl,
159    attributes(
160        planning_entity_collection,
161        problem_fact_collection,
162        planning_score,
163        value_range_provider,
164        shadow_variable_updates,
165        basic_variable_config,
166        solverforge_constraints_path
167    )
168)]
169pub fn derive_planning_solution(input: TokenStream) -> TokenStream {
170    let input = parse_macro_input!(input as DeriveInput);
171    planning_solution::expand_derive(input)
172        .unwrap_or_else(|e| e.to_compile_error())
173        .into()
174}
175
176#[proc_macro_derive(ProblemFactImpl, attributes(planning_id))]
177pub fn derive_problem_fact(input: TokenStream) -> TokenStream {
178    let input = parse_macro_input!(input as DeriveInput);
179    problem_fact::expand_derive(input)
180        .unwrap_or_else(|e| e.to_compile_error())
181        .into()
182}
183
184fn has_attribute(attrs: &[Attribute], name: &str) -> bool {
185    attrs.iter().any(|attr| attr.path().is_ident(name))
186}
187
188fn get_attribute<'a>(attrs: &'a [Attribute], name: &str) -> Option<&'a Attribute> {
189    attrs.iter().find(|attr| attr.path().is_ident(name))
190}
191
192fn parse_attribute_bool(attr: &Attribute, key: &str) -> Option<bool> {
193    if let Meta::List(meta_list) = &attr.meta {
194        let parser = syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated;
195        if let Ok(nested) = parser.parse2(meta_list.tokens.clone()) {
196            for meta in nested {
197                if let Meta::NameValue(nv) = meta {
198                    if nv.path.is_ident(key) {
199                        if let Expr::Lit(expr_lit) = &nv.value {
200                            if let Lit::Bool(lit_bool) = &expr_lit.lit {
201                                return Some(lit_bool.value());
202                            }
203                        }
204                    }
205                }
206            }
207        }
208    }
209    None
210}
211
212fn parse_attribute_string(attr: &Attribute, key: &str) -> Option<String> {
213    if let Meta::List(meta_list) = &attr.meta {
214        let parser = syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated;
215        if let Ok(nested) = parser.parse2(meta_list.tokens.clone()) {
216            for meta in nested {
217                if let Meta::NameValue(nv) = meta {
218                    if nv.path.is_ident(key) {
219                        if let Expr::Lit(expr_lit) = &nv.value {
220                            if let Lit::Str(lit_str) = &expr_lit.lit {
221                                return Some(lit_str.value());
222                            }
223                        }
224                    }
225                }
226            }
227        }
228    }
229    None
230}
231
232fn parse_attribute_list(attr: &Attribute, key: &str) -> Vec<String> {
233    let mut result = Vec::new();
234    if let Meta::List(meta_list) = &attr.meta {
235        let parser = syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated;
236        if let Ok(nested) = parser.parse2(meta_list.tokens.clone()) {
237            for meta in nested {
238                if let Meta::NameValue(nv) = meta {
239                    if nv.path.is_ident(key) {
240                        if let Expr::Lit(expr_lit) = &nv.value {
241                            if let Lit::Str(lit_str) = &expr_lit.lit {
242                                result.push(lit_str.value());
243                            }
244                        }
245                    }
246                }
247            }
248        }
249    }
250    result
251}