1use darling::FromDeriveInput;
2use proc_macro::{self, TokenStream};
3use quote::quote;
4use syn::{self, parse_macro_input};
5
6#[derive(FromDeriveInput, Default)]
7#[darling(default, attributes(kernel))]
8struct KernelOpts {
9 kernel: Option<syn::Expr>,
10}
11
12#[proc_macro_derive(KernelFunctions, attributes(kernel))]
13pub fn kernel_functions_derive(input: TokenStream) -> TokenStream {
14 let ast = parse_macro_input!(input);
16 let opts = KernelOpts::from_derive_input(&ast)
17 .expect("Wrong options for KernelFunctions derive macro");
18
19 if let Some(kernel) = opts.kernel.as_ref() {
20 match kernel {
21 syn::Expr::Field(_) => (),
22 _ => panic!("kernel attribute must be a struct field access"),
23 }
24 }
25
26 impl_kernel_functions_macro(ast, opts)
28}
29
30fn impl_kernel_functions_macro(mut ast: syn::DeriveInput, opts: KernelOpts) -> TokenStream {
31 let name = &ast.ident;
32
33 #[cfg(feature = "interrupt-oracle")]
35 {
36 let mut found_oracle = false;
37 for gen in ast.generics.type_params() {
38 if gen.ident == "O" {
39 found_oracle = true;
40 break;
41 }
42 }
43 if !found_oracle {
44 panic!("KernelFunctions derive needs a generic for the oracle type called 'O'")
45 }
46 }
47
48 let kernel = if let Some(kernel) = opts.kernel {
49 kernel
50 } else {
51 let ts: TokenStream = "self.kernel".parse().unwrap();
52 parse_macro_input!(ts)
53 };
54
55 ast.generics.make_where_clause();
56 let obounds = "where";
57 #[cfg(feature = "interrupt-oracle")]
58 let obounds = format!(
59 "{} O: rustsat::solvers::Interrupt, ProofW: std::io::Write,",
60 obounds
61 );
62 let obounds: TokenStream = obounds.parse().unwrap();
63 let obounds: syn::WhereClause = parse_macro_input!(obounds);
64 ast.generics
65 .where_clause
66 .as_mut()
67 .unwrap()
68 .predicates
69 .extend(obounds.predicates);
70
71 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
72
73 quote! {
74 impl #impl_generics KernelFunctions for #name #ty_generics #where_clause {
75 fn pareto_front(&self) -> crate::types::ParetoFront {
76 self.pareto_front.clone()
77 }
78
79 fn stats(&self) -> crate::Stats {
80 #kernel.stats
81 }
82
83 fn attach_logger<L: crate::WriteSolverLog + 'static>(&mut self, logger: L) {
84 #kernel.attach_logger(logger)
85 }
86
87 fn detach_logger(&mut self) -> Option<Box<dyn crate::WriteSolverLog>> {
88 #kernel.detach_logger()
89 }
90
91 fn interrupter(&mut self) -> crate::algs::Interrupter {
92 #kernel.interrupter()
93 }
94 }
95 }
96 .into()
97}
98
99#[derive(FromDeriveInput, Default)]
100#[darling(default, attributes(solve))]
101struct SolveOpts {
102 bounds: Option<syn::WhereClause>,
103 oracle_stats: bool,
104}
105
106#[proc_macro_derive(Solve, attributes(solve))]
107pub fn solve_derive(input: TokenStream) -> TokenStream {
108 let ast = parse_macro_input!(input);
109 let kopts = KernelOpts::from_derive_input(&ast)
110 .expect("Wrong options for KernelFunctions derive macro");
111 let sopts = SolveOpts::from_derive_input(&ast).expect("Wrong options for Solve derive macro");
112
113 if let Some(kernel) = kopts.kernel.as_ref() {
114 match kernel {
115 syn::Expr::Field(_) => (),
116 _ => panic!("kernel attribute must be a struct field access"),
117 }
118 }
119
120 impl_solve_macro(ast, kopts, sopts)
122}
123
124fn impl_solve_macro(mut ast: syn::DeriveInput, kopts: KernelOpts, sopts: SolveOpts) -> TokenStream {
125 let name = &ast.ident;
126
127 let mut found_oracle = false;
129 for gen in ast.generics.type_params() {
130 if gen.ident == "O" {
131 found_oracle = true;
132 break;
133 }
134 }
135 if !found_oracle {
136 panic!("Solve derive needs a generic for the oracle type called 'O'")
137 }
138
139 let kernel = if let Some(kernel) = kopts.kernel {
140 kernel
141 } else {
142 let ts: TokenStream = "self.kernel".parse().unwrap();
143 parse_macro_input!(ts)
144 };
145
146 ast.generics.make_where_clause();
147 let obounds = "where";
148 #[cfg(feature = "interrupt-oracle")]
149 let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
150 #[cfg(feature = "phasing")]
151 let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
152 #[cfg(feature = "sol-tightening")]
153 let obounds = format!(
154 "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
155 obounds
156 );
157 #[cfg(feature = "limit-conflicts")]
158 let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
159 let obounds: TokenStream = obounds.parse().unwrap();
160 let obounds: syn::WhereClause = parse_macro_input!(obounds);
161 ast.generics
162 .where_clause
163 .as_mut()
164 .unwrap()
165 .predicates
166 .extend(obounds.predicates);
167 if let Some(add_bounds) = sopts.bounds {
168 ast.generics
169 .where_clause
170 .as_mut()
171 .unwrap()
172 .predicates
173 .extend(add_bounds.predicates)
174 }
175
176 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
177
178 quote! {
179 impl #impl_generics Solve for #name #ty_generics #where_clause {
180 fn solve(&mut self, limits: Limits) -> crate::MaybeTerminatedError {
181 #kernel.start_solving(limits);
182 self.alg_main()
183 }
184
185 fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
186 use crate::ExtendedSolveStats;
187 (#kernel.stats, Some(self.oracle_stats()),
188 Some(self.encoding_stats()))
189 }
190 }
191 }.into()
192}
193
194#[proc_macro_attribute]
195pub fn oracle_bounds(_attr: TokenStream, item: TokenStream) -> TokenStream {
196 let ast: syn::Item = parse_macro_input!(item);
197 let impl_block = match ast {
198 syn::Item::Impl(impl_block) => impl_block,
199 _ => panic!("oracle_bounds attribute can only be used on impl blocks"),
200 };
201
202 insert_oracle_bounds(impl_block)
203}
204
205fn insert_oracle_bounds(mut impl_block: syn::ItemImpl) -> TokenStream {
206 let mut found_oracle = false;
208 for gen in impl_block.generics.type_params() {
209 if gen.ident == "O" {
210 found_oracle = true;
211 break;
212 }
213 }
214 if !found_oracle {
215 panic!("oracle_bounds attribute needs a generic for the oracle type called 'O'")
216 }
217
218 let obounds = "where";
219 #[cfg(feature = "interrupt-oracle")]
220 let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
221 #[cfg(feature = "phasing")]
222 let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
223 #[cfg(feature = "sol-tightening")]
224 let obounds = format!(
225 "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
226 obounds
227 );
228 #[cfg(feature = "limit-conflicts")]
229 let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
230 let obounds: TokenStream = obounds.parse().unwrap();
231 let obounds: syn::WhereClause = parse_macro_input!(obounds);
232
233 impl_block.generics.make_where_clause();
234 impl_block
235 .generics
236 .where_clause
237 .as_mut()
238 .unwrap()
239 .predicates
240 .extend(obounds.predicates);
241
242 quote! { #impl_block }.into()
243 }