scuttle_proc/
lib.rs

1use darling::FromDeriveInput;
2use proc_macro::{self, TokenStream};
3use quote::quote;
4use syn::{self, parse_macro_input, TypeParamBound, WherePredicate};
5
6#[derive(FromDeriveInput, Default)]
7#[darling(default, attributes(kernel))]
8struct KernelOpts {
9    kernel: Option<syn::Expr>,
10}
11
12#[proc_macro_derive(KernelFunctions, attributes(kernel))]
13pub fn kernel_functions_derive(input: TokenStream) -> TokenStream {
14    // Construct a representation of the code as a syntax tree to manipulate
15    let ast = parse_macro_input!(input);
16    let opts = KernelOpts::from_derive_input(&ast)
17        .expect("Wrong options for KernelFunctions derive macro");
18
19    if let Some(kernel) = opts.kernel.as_ref() {
20        match kernel {
21            syn::Expr::Field(_) => (),
22            _ => panic!("kernel attribute must be a struct field access"),
23        }
24    }
25
26    // Build the trait implementation
27    impl_kernel_functions_macro(ast, opts)
28}
29
30fn impl_kernel_functions_macro(mut ast: syn::DeriveInput, opts: KernelOpts) -> TokenStream {
31    let name = &ast.ident;
32
33    // Check whether type has generic named O that is assumed to be the oracle
34    #[cfg(feature = "interrupt-oracle")]
35    {
36        let mut found_oracle = false;
37        for gen in ast.generics.type_params() {
38            if gen.ident == "O" {
39                found_oracle = true;
40                break;
41            }
42        }
43        if !found_oracle {
44            panic!("KernelFunctions derive needs a generic for the oracle type called 'O'")
45        }
46    }
47
48    let kernel = if let Some(kernel) = opts.kernel {
49        kernel
50    } else {
51        let ts: TokenStream = "self.kernel".parse().unwrap();
52        parse_macro_input!(ts)
53    };
54
55    ast.generics.make_where_clause();
56    let obounds = "where";
57    #[cfg(feature = "interrupt-oracle")]
58    let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
59    let obounds: TokenStream = obounds.parse().unwrap();
60    let obounds: syn::WhereClause = parse_macro_input!(obounds);
61    ast.generics
62        .where_clause
63        .as_mut()
64        .unwrap()
65        .predicates
66        .extend(obounds.predicates);
67
68    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
69
70    quote! {
71        impl #impl_generics KernelFunctions for #name #ty_generics #where_clause {
72            fn pareto_front(&self) -> crate::types::ParetoFront {
73                self.pareto_front.clone()
74            }
75
76            fn stats(&self) -> crate::Stats {
77                #kernel.stats
78            }
79
80            fn attach_logger<L: crate::WriteSolverLog + 'static>(&mut self, logger: L) {
81                #kernel.attach_logger(logger)
82            }
83
84            fn detach_logger(&mut self) -> Option<Box<dyn crate::WriteSolverLog>> {
85                #kernel.detach_logger()
86            }
87
88            fn interrupter(&mut self) -> crate::solver::Interrupter {
89                #kernel.interrupter()
90            }
91        }
92    }
93    .into()
94}
95
96#[derive(FromDeriveInput, Default)]
97#[darling(default, attributes(solve))]
98struct SolveOpts {
99    bounds: Option<syn::WhereClause>,
100    extended_stats: bool,
101    oracle_stats: bool,
102}
103
104#[proc_macro_derive(Solve, attributes(solve))]
105pub fn solve_derive(input: TokenStream) -> TokenStream {
106    let ast = parse_macro_input!(input);
107    let kopts = KernelOpts::from_derive_input(&ast)
108        .expect("Wrong options for KernelFunctions derive macro");
109    let sopts = SolveOpts::from_derive_input(&ast).expect("Wrong options for Solve derive macro");
110
111    if let Some(kernel) = kopts.kernel.as_ref() {
112        match kernel {
113            syn::Expr::Field(_) => (),
114            _ => panic!("kernel attribute must be a struct field access"),
115        }
116    }
117
118    // Build the trait implementation
119    impl_solve_macro(ast, kopts, sopts)
120}
121
122fn impl_solve_macro(mut ast: syn::DeriveInput, kopts: KernelOpts, sopts: SolveOpts) -> TokenStream {
123    let name = &ast.ident;
124
125    // Check whether type has generic named O that is assumed to be the oracle
126    let mut found_oracle = false;
127    for gen in ast.generics.type_params() {
128        if gen.ident == "O" {
129            found_oracle = true;
130            break;
131        }
132    }
133    if !found_oracle {
134        panic!("Solve derive needs a generic for the oracle type called 'O'")
135    }
136
137    let kernel = if let Some(kernel) = kopts.kernel {
138        kernel
139    } else {
140        let ts: TokenStream = "self.kernel".parse().unwrap();
141        parse_macro_input!(ts)
142    };
143
144    ast.generics.make_where_clause();
145    let obounds = "where";
146    #[cfg(feature = "interrupt-oracle")]
147    let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
148    #[cfg(feature = "phasing")]
149    let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
150    #[cfg(feature = "sol-tightening")]
151    let obounds = format!(
152        "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
153        obounds
154    );
155    #[cfg(feature = "limit-conflicts")]
156    let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
157    let obounds: TokenStream = obounds.parse().unwrap();
158    let obounds: syn::WhereClause = parse_macro_input!(obounds);
159    ast.generics
160        .where_clause
161        .as_mut()
162        .unwrap()
163        .predicates
164        .extend(obounds.predicates);
165    if let Some(add_bounds) = sopts.bounds {
166        ast.generics
167            .where_clause
168            .as_mut()
169            .unwrap()
170            .predicates
171            .extend(add_bounds.predicates)
172    }
173
174    // If O: SolveStats is satisfied, add oracle stats
175    // (this doesn't actually check that this bound is on O)
176    let mut oracle_stats = false;
177    if !sopts.extended_stats {
178        if let Some(ref where_clause) = ast.generics.where_clause {
179            for pred in where_clause.predicates.iter() {
180                if let WherePredicate::Type(typ) = pred {
181                    for bound in typ.bounds.iter() {
182                        if let TypeParamBound::Trait(tb) = bound {
183                            if tb.path.segments.last().unwrap().ident == "SolveStats" {
184                                oracle_stats = true;
185                            }
186                        }
187                    }
188                }
189            }
190        }
191    }
192
193    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
194
195    if sopts.extended_stats {
196        quote! {
197            impl #impl_generics Solve for #name #ty_generics #where_clause {
198                fn solve(&mut self, limits: Limits) -> Result<bool, Termination> {
199                    #kernel.start_solving(limits);
200                    self.alg_main()?;
201                    Ok(true)
202                }
203
204                fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
205                    use crate::ExtendedSolveStats;
206                    (#kernel.stats, Some(self.oracle_stats()),
207                    Some(self.encoding_stats()))
208                }
209            }
210        }
211    } else if oracle_stats {
212        quote!{
213            impl #impl_generics Solve for #name #ty_generics #where_clause {
214                fn solve(&mut self, limits: Limits) -> Result<bool, Termination> {
215                    #kernel.start_solving(limits);
216                    self.alg_main()?;
217                    Ok(true)
218                }
219
220                fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
221                    use rustsat::solvers::SolveStats;
222                    (#kernel.stats, Some(#kernel.oracle.stats()), None)
223                }
224            }
225        }
226    } else {
227        quote!{
228            impl #impl_generics Solve for #name #ty_generics #where_clause {
229                fn solve(&mut self, limits: Limits) -> Result<bool, Termination> {
230                    #kernel.start_solving(limits);
231                    self.alg_main()?;
232                    Ok(true)
233                }
234
235                fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
236                    (#kernel.stats, None, None)
237                }
238            }
239        }
240    }
241    .into()
242}
243
244#[proc_macro_attribute]
245pub fn oracle_bounds(_attr: TokenStream, item: TokenStream) -> TokenStream {
246    let ast: syn::Item = parse_macro_input!(item);
247    let impl_block = match ast {
248        syn::Item::Impl(impl_block) => impl_block,
249        _ => panic!("oracle_bounds attribute can only be used on impl blocks"),
250    };
251
252    insert_oracle_bounds(impl_block)
253}
254
255fn insert_oracle_bounds(mut impl_block: syn::ItemImpl) -> TokenStream {
256    // Check whether type has generic named O that is assumed to be the oracle
257    let mut found_oracle = false;
258    for gen in impl_block.generics.type_params() {
259        if gen.ident == "O" {
260            found_oracle = true;
261            break;
262        }
263    }
264    if !found_oracle {
265        panic!("oracle_bounds attribute needs a generic for the oracle type called 'O'")
266    }
267
268    let obounds = "where";
269    #[cfg(feature = "interrupt-oracle")]
270    let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
271    #[cfg(feature = "phasing")]
272    let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
273    #[cfg(feature = "sol-tightening")]
274    let obounds = format!(
275        "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
276        obounds
277    );
278    #[cfg(feature = "limit-conflicts")]
279    let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
280    let obounds: TokenStream = obounds.parse().unwrap();
281    let obounds: syn::WhereClause = parse_macro_input!(obounds);
282
283    impl_block.generics.make_where_clause();
284    impl_block
285        .generics
286        .where_clause
287        .as_mut()
288        .unwrap()
289        .predicates
290        .extend(obounds.predicates);
291
292    quote! { #impl_block }.into()
293    //}
294}