1use darling::FromDeriveInput;
2use proc_macro::{self, TokenStream};
3use quote::quote;
4use syn::{self, parse_macro_input, TypeParamBound, WherePredicate};
5
6#[derive(FromDeriveInput, Default)]
7#[darling(default, attributes(kernel))]
8struct KernelOpts {
9 kernel: Option<syn::Expr>,
10}
11
12#[proc_macro_derive(KernelFunctions, attributes(kernel))]
13pub fn kernel_functions_derive(input: TokenStream) -> TokenStream {
14 let ast = parse_macro_input!(input);
16 let opts = KernelOpts::from_derive_input(&ast)
17 .expect("Wrong options for KernelFunctions derive macro");
18
19 if let Some(kernel) = opts.kernel.as_ref() {
20 match kernel {
21 syn::Expr::Field(_) => (),
22 _ => panic!("kernel attribute must be a struct field access"),
23 }
24 }
25
26 impl_kernel_functions_macro(ast, opts)
28}
29
30fn impl_kernel_functions_macro(mut ast: syn::DeriveInput, opts: KernelOpts) -> TokenStream {
31 let name = &ast.ident;
32
33 #[cfg(feature = "interrupt-oracle")]
35 {
36 let mut found_oracle = false;
37 for gen in ast.generics.type_params() {
38 if gen.ident == "O" {
39 found_oracle = true;
40 break;
41 }
42 }
43 if !found_oracle {
44 panic!("KernelFunctions derive needs a generic for the oracle type called 'O'")
45 }
46 }
47
48 let kernel = if let Some(kernel) = opts.kernel {
49 kernel
50 } else {
51 let ts: TokenStream = "self.kernel".parse().unwrap();
52 parse_macro_input!(ts)
53 };
54
55 ast.generics.make_where_clause();
56 let obounds = "where";
57 #[cfg(feature = "interrupt-oracle")]
58 let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
59 let obounds: TokenStream = obounds.parse().unwrap();
60 let obounds: syn::WhereClause = parse_macro_input!(obounds);
61 ast.generics
62 .where_clause
63 .as_mut()
64 .unwrap()
65 .predicates
66 .extend(obounds.predicates);
67
68 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
69
70 quote! {
71 impl #impl_generics KernelFunctions for #name #ty_generics #where_clause {
72 fn pareto_front(&self) -> crate::types::ParetoFront {
73 self.pareto_front.clone()
74 }
75
76 fn stats(&self) -> crate::Stats {
77 #kernel.stats
78 }
79
80 fn attach_logger<L: crate::WriteSolverLog + 'static>(&mut self, logger: L) {
81 #kernel.attach_logger(logger)
82 }
83
84 fn detach_logger(&mut self) -> Option<Box<dyn crate::WriteSolverLog>> {
85 #kernel.detach_logger()
86 }
87
88 fn interrupter(&mut self) -> crate::solver::Interrupter {
89 #kernel.interrupter()
90 }
91 }
92 }
93 .into()
94}
95
96#[derive(FromDeriveInput, Default)]
97#[darling(default, attributes(solve))]
98struct SolveOpts {
99 bounds: Option<syn::WhereClause>,
100 extended_stats: bool,
101 oracle_stats: bool,
102}
103
104#[proc_macro_derive(Solve, attributes(solve))]
105pub fn solve_derive(input: TokenStream) -> TokenStream {
106 let ast = parse_macro_input!(input);
107 let kopts = KernelOpts::from_derive_input(&ast)
108 .expect("Wrong options for KernelFunctions derive macro");
109 let sopts = SolveOpts::from_derive_input(&ast).expect("Wrong options for Solve derive macro");
110
111 if let Some(kernel) = kopts.kernel.as_ref() {
112 match kernel {
113 syn::Expr::Field(_) => (),
114 _ => panic!("kernel attribute must be a struct field access"),
115 }
116 }
117
118 impl_solve_macro(ast, kopts, sopts)
120}
121
122fn impl_solve_macro(mut ast: syn::DeriveInput, kopts: KernelOpts, sopts: SolveOpts) -> TokenStream {
123 let name = &ast.ident;
124
125 let mut found_oracle = false;
127 for gen in ast.generics.type_params() {
128 if gen.ident == "O" {
129 found_oracle = true;
130 break;
131 }
132 }
133 if !found_oracle {
134 panic!("Solve derive needs a generic for the oracle type called 'O'")
135 }
136
137 let kernel = if let Some(kernel) = kopts.kernel {
138 kernel
139 } else {
140 let ts: TokenStream = "self.kernel".parse().unwrap();
141 parse_macro_input!(ts)
142 };
143
144 ast.generics.make_where_clause();
145 let obounds = "where";
146 #[cfg(feature = "interrupt-oracle")]
147 let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
148 #[cfg(feature = "phasing")]
149 let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
150 #[cfg(feature = "sol-tightening")]
151 let obounds = format!(
152 "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
153 obounds
154 );
155 #[cfg(feature = "limit-conflicts")]
156 let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
157 let obounds: TokenStream = obounds.parse().unwrap();
158 let obounds: syn::WhereClause = parse_macro_input!(obounds);
159 ast.generics
160 .where_clause
161 .as_mut()
162 .unwrap()
163 .predicates
164 .extend(obounds.predicates);
165 if let Some(add_bounds) = sopts.bounds {
166 ast.generics
167 .where_clause
168 .as_mut()
169 .unwrap()
170 .predicates
171 .extend(add_bounds.predicates)
172 }
173
174 let mut oracle_stats = false;
177 if !sopts.extended_stats {
178 if let Some(ref where_clause) = ast.generics.where_clause {
179 for pred in where_clause.predicates.iter() {
180 if let WherePredicate::Type(typ) = pred {
181 for bound in typ.bounds.iter() {
182 if let TypeParamBound::Trait(tb) = bound {
183 if tb.path.segments.last().unwrap().ident == "SolveStats" {
184 oracle_stats = true;
185 }
186 }
187 }
188 }
189 }
190 }
191 }
192
193 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
194
195 if sopts.extended_stats {
196 quote! {
197 impl #impl_generics Solve for #name #ty_generics #where_clause {
198 fn solve(&mut self, limits: Limits) -> Result<bool, Termination> {
199 #kernel.start_solving(limits);
200 self.alg_main()?;
201 Ok(true)
202 }
203
204 fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
205 use crate::ExtendedSolveStats;
206 (#kernel.stats, Some(self.oracle_stats()),
207 Some(self.encoding_stats()))
208 }
209 }
210 }
211 } else if oracle_stats {
212 quote!{
213 impl #impl_generics Solve for #name #ty_generics #where_clause {
214 fn solve(&mut self, limits: Limits) -> Result<bool, Termination> {
215 #kernel.start_solving(limits);
216 self.alg_main()?;
217 Ok(true)
218 }
219
220 fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
221 use rustsat::solvers::SolveStats;
222 (#kernel.stats, Some(#kernel.oracle.stats()), None)
223 }
224 }
225 }
226 } else {
227 quote!{
228 impl #impl_generics Solve for #name #ty_generics #where_clause {
229 fn solve(&mut self, limits: Limits) -> Result<bool, Termination> {
230 #kernel.start_solving(limits);
231 self.alg_main()?;
232 Ok(true)
233 }
234
235 fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
236 (#kernel.stats, None, None)
237 }
238 }
239 }
240 }
241 .into()
242}
243
244#[proc_macro_attribute]
245pub fn oracle_bounds(_attr: TokenStream, item: TokenStream) -> TokenStream {
246 let ast: syn::Item = parse_macro_input!(item);
247 let impl_block = match ast {
248 syn::Item::Impl(impl_block) => impl_block,
249 _ => panic!("oracle_bounds attribute can only be used on impl blocks"),
250 };
251
252 insert_oracle_bounds(impl_block)
253}
254
255fn insert_oracle_bounds(mut impl_block: syn::ItemImpl) -> TokenStream {
256 let mut found_oracle = false;
258 for gen in impl_block.generics.type_params() {
259 if gen.ident == "O" {
260 found_oracle = true;
261 break;
262 }
263 }
264 if !found_oracle {
265 panic!("oracle_bounds attribute needs a generic for the oracle type called 'O'")
266 }
267
268 let obounds = "where";
269 #[cfg(feature = "interrupt-oracle")]
270 let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
271 #[cfg(feature = "phasing")]
272 let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
273 #[cfg(feature = "sol-tightening")]
274 let obounds = format!(
275 "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
276 obounds
277 );
278 #[cfg(feature = "limit-conflicts")]
279 let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
280 let obounds: TokenStream = obounds.parse().unwrap();
281 let obounds: syn::WhereClause = parse_macro_input!(obounds);
282
283 impl_block.generics.make_where_clause();
284 impl_block
285 .generics
286 .where_clause
287 .as_mut()
288 .unwrap()
289 .predicates
290 .extend(obounds.predicates);
291
292 quote! { #impl_block }.into()
293 }