scuttle_proc/
lib.rs

1use darling::FromDeriveInput;
2use proc_macro::{self, TokenStream};
3use quote::quote;
4use syn::{self, parse_macro_input};
5
6#[derive(FromDeriveInput, Default)]
7#[darling(default, attributes(kernel))]
8struct KernelOpts {
9    kernel: Option<syn::Expr>,
10}
11
12#[proc_macro_derive(KernelFunctions, attributes(kernel))]
13pub fn kernel_functions_derive(input: TokenStream) -> TokenStream {
14    // Construct a representation of the code as a syntax tree to manipulate
15    let ast = parse_macro_input!(input);
16    let opts = KernelOpts::from_derive_input(&ast)
17        .expect("Wrong options for KernelFunctions derive macro");
18
19    if let Some(kernel) = opts.kernel.as_ref() {
20        match kernel {
21            syn::Expr::Field(_) => (),
22            _ => panic!("kernel attribute must be a struct field access"),
23        }
24    }
25
26    // Build the trait implementation
27    impl_kernel_functions_macro(ast, opts)
28}
29
30fn impl_kernel_functions_macro(mut ast: syn::DeriveInput, opts: KernelOpts) -> TokenStream {
31    let name = &ast.ident;
32
33    // Check whether type has generic named O that is assumed to be the oracle
34    #[cfg(feature = "interrupt-oracle")]
35    {
36        let mut found_oracle = false;
37        for gen in ast.generics.type_params() {
38            if gen.ident == "O" {
39                found_oracle = true;
40                break;
41            }
42        }
43        if !found_oracle {
44            panic!("KernelFunctions derive needs a generic for the oracle type called 'O'")
45        }
46    }
47
48    let kernel = if let Some(kernel) = opts.kernel {
49        kernel
50    } else {
51        let ts: TokenStream = "self.kernel".parse().unwrap();
52        parse_macro_input!(ts)
53    };
54
55    ast.generics.make_where_clause();
56    let obounds = "where";
57    #[cfg(feature = "interrupt-oracle")]
58    let obounds = format!(
59        "{} O: rustsat::solvers::Interrupt, ProofW: std::io::Write,",
60        obounds
61    );
62    let obounds: TokenStream = obounds.parse().unwrap();
63    let obounds: syn::WhereClause = parse_macro_input!(obounds);
64    ast.generics
65        .where_clause
66        .as_mut()
67        .unwrap()
68        .predicates
69        .extend(obounds.predicates);
70
71    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
72
73    quote! {
74        impl #impl_generics KernelFunctions for #name #ty_generics #where_clause {
75            fn pareto_front(&self) -> crate::types::ParetoFront {
76                self.pareto_front.clone()
77            }
78
79            fn stats(&self) -> crate::Stats {
80                #kernel.stats
81            }
82
83            fn attach_logger<L: crate::WriteSolverLog + 'static>(&mut self, logger: L) {
84                #kernel.attach_logger(logger)
85            }
86
87            fn detach_logger(&mut self) -> Option<Box<dyn crate::WriteSolverLog>> {
88                #kernel.detach_logger()
89            }
90
91            fn interrupter(&mut self) -> crate::algs::Interrupter {
92                #kernel.interrupter()
93            }
94        }
95    }
96    .into()
97}
98
99#[derive(FromDeriveInput, Default)]
100#[darling(default, attributes(solve))]
101struct SolveOpts {
102    bounds: Option<syn::WhereClause>,
103    oracle_stats: bool,
104}
105
106#[proc_macro_derive(Solve, attributes(solve))]
107pub fn solve_derive(input: TokenStream) -> TokenStream {
108    let ast = parse_macro_input!(input);
109    let kopts = KernelOpts::from_derive_input(&ast)
110        .expect("Wrong options for KernelFunctions derive macro");
111    let sopts = SolveOpts::from_derive_input(&ast).expect("Wrong options for Solve derive macro");
112
113    if let Some(kernel) = kopts.kernel.as_ref() {
114        match kernel {
115            syn::Expr::Field(_) => (),
116            _ => panic!("kernel attribute must be a struct field access"),
117        }
118    }
119
120    // Build the trait implementation
121    impl_solve_macro(ast, kopts, sopts)
122}
123
124fn impl_solve_macro(mut ast: syn::DeriveInput, kopts: KernelOpts, sopts: SolveOpts) -> TokenStream {
125    let name = &ast.ident;
126
127    // Check whether type has generic named O that is assumed to be the oracle
128    let mut found_oracle = false;
129    for gen in ast.generics.type_params() {
130        if gen.ident == "O" {
131            found_oracle = true;
132            break;
133        }
134    }
135    if !found_oracle {
136        panic!("Solve derive needs a generic for the oracle type called 'O'")
137    }
138
139    let kernel = if let Some(kernel) = kopts.kernel {
140        kernel
141    } else {
142        let ts: TokenStream = "self.kernel".parse().unwrap();
143        parse_macro_input!(ts)
144    };
145
146    ast.generics.make_where_clause();
147    let obounds = "where";
148    #[cfg(feature = "interrupt-oracle")]
149    let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
150    #[cfg(feature = "phasing")]
151    let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
152    #[cfg(feature = "sol-tightening")]
153    let obounds = format!(
154        "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
155        obounds
156    );
157    #[cfg(feature = "limit-conflicts")]
158    let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
159    let obounds: TokenStream = obounds.parse().unwrap();
160    let obounds: syn::WhereClause = parse_macro_input!(obounds);
161    ast.generics
162        .where_clause
163        .as_mut()
164        .unwrap()
165        .predicates
166        .extend(obounds.predicates);
167    if let Some(add_bounds) = sopts.bounds {
168        ast.generics
169            .where_clause
170            .as_mut()
171            .unwrap()
172            .predicates
173            .extend(add_bounds.predicates)
174    }
175
176    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
177
178    quote! {
179        impl #impl_generics Solve for #name #ty_generics #where_clause {
180            fn solve(&mut self, limits: Limits) -> crate::MaybeTerminatedError {
181                #kernel.start_solving(limits);
182                self.alg_main()
183            }
184
185            fn all_stats(&self) -> (crate::Stats, Option<rustsat::solvers::SolverStats>, Option<Vec<crate::EncodingStats>>) {
186                use crate::ExtendedSolveStats;
187                (#kernel.stats, Some(self.oracle_stats()),
188                Some(self.encoding_stats()))
189            }
190        }
191    }.into()
192}
193
194#[proc_macro_attribute]
195pub fn oracle_bounds(_attr: TokenStream, item: TokenStream) -> TokenStream {
196    let ast: syn::Item = parse_macro_input!(item);
197    let impl_block = match ast {
198        syn::Item::Impl(impl_block) => impl_block,
199        _ => panic!("oracle_bounds attribute can only be used on impl blocks"),
200    };
201
202    insert_oracle_bounds(impl_block)
203}
204
205fn insert_oracle_bounds(mut impl_block: syn::ItemImpl) -> TokenStream {
206    // Check whether type has generic named O that is assumed to be the oracle
207    let mut found_oracle = false;
208    for gen in impl_block.generics.type_params() {
209        if gen.ident == "O" {
210            found_oracle = true;
211            break;
212        }
213    }
214    if !found_oracle {
215        panic!("oracle_bounds attribute needs a generic for the oracle type called 'O'")
216    }
217
218    let obounds = "where";
219    #[cfg(feature = "interrupt-oracle")]
220    let obounds = format!("{} O: rustsat::solvers::Interrupt,", obounds);
221    #[cfg(feature = "phasing")]
222    let obounds = format!("{} O: rustsat::solvers::PhaseLit,", obounds);
223    #[cfg(feature = "sol-tightening")]
224    let obounds = format!(
225        "{} O: rustsat::solvers::FlipLit + rustsat::solvers::FreezeVar,",
226        obounds
227    );
228    #[cfg(feature = "limit-conflicts")]
229    let obounds = format!("{} O: rustsat::solvers::LimitConflicts,", obounds);
230    let obounds: TokenStream = obounds.parse().unwrap();
231    let obounds: syn::WhereClause = parse_macro_input!(obounds);
232
233    impl_block.generics.make_where_clause();
234    impl_block
235        .generics
236        .where_clause
237        .as_mut()
238        .unwrap()
239        .predicates
240        .extend(obounds.predicates);
241
242    quote! { #impl_block }.into()
243    //}
244}