tp_npos_elections_compact/
lib.rs1use proc_macro::TokenStream;
21use proc_macro2::{TokenStream as TokenStream2, Span, Ident};
22use proc_macro_crate::crate_name;
23use quote::quote;
24use syn::{parse::{Parse, ParseStream, Result}};
25
26mod assignment;
27mod codec;
28
/// Prefix of every generated solution field name; field `n` is named `votes{n}`
/// (see `field_name_for`).
const PREFIX: &str = "votes";
31
32pub(crate) fn syn_err(message: &'static str) -> syn::Error {
33 syn::Error::new(Span::call_site(), message)
34}
35
36#[proc_macro]
75pub fn generate_solution_type(item: TokenStream) -> TokenStream {
76 let SolutionDef {
77 vis,
78 ident,
79 count,
80 voter_type,
81 target_type,
82 weight_type,
83 compact_encoding,
84 } = syn::parse_macro_input!(item as SolutionDef);
85
86 let imports = imports().unwrap_or_else(|e| e.to_compile_error());
87
88 let solution_struct = struct_def(
89 vis,
90 ident.clone(),
91 count,
92 voter_type.clone(),
93 target_type.clone(),
94 weight_type.clone(),
95 compact_encoding,
96 ).unwrap_or_else(|e| e.to_compile_error());
97
98 quote!(
99 #imports
100 #solution_struct
101 )
102 .into()
103}
104
105fn struct_def(
106 vis: syn::Visibility,
107 ident: syn::Ident,
108 count: usize,
109 voter_type: syn::Type,
110 target_type: syn::Type,
111 weight_type: syn::Type,
112 compact_encoding: bool,
113) -> Result<TokenStream2> {
114 if count <= 2 {
115 Err(syn_err("cannot build compact solution struct with capacity less than 3."))?
116 }
117
118 let singles = {
119 let name = field_name_for(1);
120 quote!(
122 #vis #name: Vec<(#voter_type, #target_type)>,
123 )
124 };
125
126 let doubles = {
127 let name = field_name_for(2);
128 quote!(
129 #vis #name: Vec<(#voter_type, (#target_type, #weight_type), #target_type)>,
130 )
131 };
132
133 let rest = (3..=count)
134 .map(|c| {
135 let field_name = field_name_for(c);
136 let array_len = c - 1;
137 quote!(
138 #vis #field_name: Vec<(
139 #voter_type,
140 [(#target_type, #weight_type); #array_len],
141 #target_type
142 )>,
143 )
144 })
145 .collect::<TokenStream2>();
146
147 let len_impl = len_impl(count);
148 let edge_count_impl = edge_count_impl(count);
149 let unique_targets_impl = unique_targets_impl(count);
150 let remove_voter_impl = remove_voter_impl(count);
151
152 let derives_and_maybe_compact_encoding = if compact_encoding {
153 let compact_impl = codec::codec_impl(
155 ident.clone(),
156 voter_type.clone(),
157 target_type.clone(),
158 weight_type.clone(),
159 count,
160 );
161 quote!{
162 #compact_impl
163 #[derive(Default, PartialEq, Eq, Clone, Debug)]
164 }
165 } else {
166 quote!(#[derive(Default, PartialEq, Eq, Clone, Debug, _npos::codec::Encode, _npos::codec::Decode)])
168 };
169
170 let from_impl = assignment::from_impl(count);
171 let into_impl = assignment::into_impl(count, weight_type.clone());
172
173 Ok(quote! (
174 #derives_and_maybe_compact_encoding
176 #vis struct #ident { #singles #doubles #rest }
177
178 use _npos::__OrInvalidIndex;
179 impl _npos::CompactSolution for #ident {
180 const LIMIT: usize = #count;
181 type Voter = #voter_type;
182 type Target = #target_type;
183 type Accuracy = #weight_type;
184
185 fn voter_count(&self) -> usize {
186 let mut all_len = 0usize;
187 #len_impl
188 all_len
189 }
190
191 fn edge_count(&self) -> usize {
192 let mut all_edges = 0usize;
193 #edge_count_impl
194 all_edges
195 }
196
197 fn unique_targets(&self) -> Vec<Self::Target> {
198 let mut all_targets: Vec<Self::Target> = Vec::with_capacity(self.average_edge_count());
201 let mut maybe_insert_target = |t: Self::Target| {
202 match all_targets.binary_search(&t) {
203 Ok(_) => (),
204 Err(pos) => all_targets.insert(pos, t)
205 }
206 };
207
208 #unique_targets_impl
209
210 all_targets
211 }
212
213 fn remove_voter(&mut self, to_remove: Self::Voter) -> bool {
214 #remove_voter_impl
215 return false
216 }
217
218 fn from_assignment<FV, FT, A>(
219 assignments: Vec<_npos::Assignment<A, #weight_type>>,
220 index_of_voter: FV,
221 index_of_target: FT,
222 ) -> Result<Self, _npos::Error>
223 where
224 A: _npos::IdentifierT,
225 for<'r> FV: Fn(&'r A) -> Option<Self::Voter>,
226 for<'r> FT: Fn(&'r A) -> Option<Self::Target>,
227 {
228 let mut compact: #ident = Default::default();
229
230 for _npos::Assignment { who, distribution } in assignments {
231 match distribution.len() {
232 0 => continue,
233 #from_impl
234 _ => {
235 return Err(_npos::Error::CompactTargetOverflow);
236 }
237 }
238 };
239 Ok(compact)
240 }
241
242 fn into_assignment<A: _npos::IdentifierT>(
243 self,
244 voter_at: impl Fn(Self::Voter) -> Option<A>,
245 target_at: impl Fn(Self::Target) -> Option<A>,
246 ) -> Result<Vec<_npos::Assignment<A, #weight_type>>, _npos::Error> {
247 let mut assignments: Vec<_npos::Assignment<A, #weight_type>> = Default::default();
248 #into_impl
249 Ok(assignments)
250 }
251 }
252 ))
253}
254
255fn remove_voter_impl(count: usize) -> TokenStream2 {
256 let field_name = field_name_for(1);
257 let single = quote! {
258 if let Some(idx) = self.#field_name.iter().position(|(x, _)| *x == to_remove) {
259 self.#field_name.remove(idx);
260 return true
261 }
262 };
263
264 let field_name = field_name_for(2);
265 let double = quote! {
266 if let Some(idx) = self.#field_name.iter().position(|(x, _, _)| *x == to_remove) {
267 self.#field_name.remove(idx);
268 return true
269 }
270 };
271
272 let rest = (3..=count)
273 .map(|c| {
274 let field_name = field_name_for(c);
275 quote! {
276 if let Some(idx) = self.#field_name.iter().position(|(x, _, _)| *x == to_remove) {
277 self.#field_name.remove(idx);
278 return true
279 }
280 }
281 })
282 .collect::<TokenStream2>();
283
284 quote! {
285 #single
286 #double
287 #rest
288 }
289}
290
291fn len_impl(count: usize) -> TokenStream2 {
292 (1..=count).map(|c| {
293 let field_name = field_name_for(c);
294 quote!(
295 all_len = all_len.saturating_add(self.#field_name.len());
296 )
297 }).collect::<TokenStream2>()
298}
299
300fn edge_count_impl(count: usize) -> TokenStream2 {
301 (1..=count).map(|c| {
302 let field_name = field_name_for(c);
303 quote!(
304 all_edges = all_edges.saturating_add(
305 self.#field_name.len().saturating_mul(#c as usize)
306 );
307 )
308 }).collect::<TokenStream2>()
309}
310
311fn unique_targets_impl(count: usize) -> TokenStream2 {
312 let unique_targets_impl_single = {
313 let field_name = field_name_for(1);
314 quote! {
315 self.#field_name.iter().for_each(|(_, t)| {
316 maybe_insert_target(*t);
317 });
318 }
319 };
320
321 let unique_targets_impl_double = {
322 let field_name = field_name_for(2);
323 quote! {
324 self.#field_name.iter().for_each(|(_, (t1, _), t2)| {
325 maybe_insert_target(*t1);
326 maybe_insert_target(*t2);
327 });
328 }
329 };
330
331 let unique_targets_impl_rest = (3..=count).map(|c| {
332 let field_name = field_name_for(c);
333 quote! {
334 self.#field_name.iter().for_each(|(_, inners, t_last)| {
335 inners.iter().for_each(|(t, _)| {
336 maybe_insert_target(*t);
337 });
338 maybe_insert_target(*t_last);
339 });
340 }
341 }).collect::<TokenStream2>();
342
343 quote! {
344 #unique_targets_impl_single
345 #unique_targets_impl_double
346 #unique_targets_impl_rest
347 }
348}
349
350fn imports() -> Result<TokenStream2> {
351 if std::env::var("CARGO_PKG_NAME").unwrap() == "tp-npos-elections" {
352 Ok(quote! {
353 use crate as _npos;
354 })
355 } else {
356 match crate_name("tp-npos-elections") {
357 Ok(tp_npos_elections) => {
358 let ident = syn::Ident::new(&tp_npos_elections, Span::call_site());
359 Ok(quote!( extern crate #ident as _npos; ))
360 },
361 Err(e) => Err(syn::Error::new(Span::call_site(), &e)),
362 }
363 }
364}
365
366struct SolutionDef {
367 vis: syn::Visibility,
368 ident: syn::Ident,
369 voter_type: syn::Type,
370 target_type: syn::Type,
371 weight_type: syn::Type,
372 count: usize,
373 compact_encoding: bool,
374}
375
376fn check_compact_attr(input: ParseStream) -> Result<bool> {
377 let mut attrs = input.call(syn::Attribute::parse_outer).unwrap_or_default();
378 if attrs.len() == 1 {
379 let attr = attrs.pop().expect("Vec with len 1 can be popped.");
380 if attr.path.segments.len() == 1 {
381 let segment = attr.path.segments.first().expect("Vec with len 1 can be popped.");
382 if segment.ident == Ident::new("compact", Span::call_site()) {
383 Ok(true)
384 } else {
385 Err(syn_err("generate_solution_type macro can only accept #[compact] attribute."))
386 }
387 } else {
388 Err(syn_err("generate_solution_type macro can only accept #[compact] attribute."))
389 }
390 } else {
391 Ok(false)
392 }
393}
394
395impl Parse for SolutionDef {
397 fn parse(input: ParseStream) -> syn::Result<Self> {
398 let compact_encoding = check_compact_attr(input)?;
400
401 let vis: syn::Visibility = input.parse()?;
403 let _ = <syn::Token![struct]>::parse(input)?;
404 let ident: syn::Ident = input.parse()?;
405
406 let _ = <syn::Token![::]>::parse(input)?;
408 let generics: syn::AngleBracketedGenericArguments = input.parse()?;
409
410 if generics.args.len() != 3 {
411 return Err(syn_err("Must provide 3 generic args."))
412 }
413
414 let mut types: Vec<syn::Type> = generics.args.iter().map(|t|
415 match t {
416 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
417 _ => Err(syn_err("Wrong type of generic provided. Must be a `type`.")),
418 }
419 ).collect::<Result<_>>()?;
420
421 let weight_type = types.pop().expect("Vector of length 3 can be popped; qed");
422 let target_type = types.pop().expect("Vector of length 2 can be popped; qed");
423 let voter_type = types.pop().expect("Vector of length 1 can be popped; qed");
424
425 let count_expr: syn::ExprParen = input.parse()?;
427 let expr = count_expr.expr;
428 let expr_lit = match *expr {
429 syn::Expr::Lit(count_lit) => count_lit.lit,
430 _ => return Err(syn_err("Count must be literal."))
431 };
432 let int_lit = match expr_lit {
433 syn::Lit::Int(int_lit) => int_lit,
434 _ => return Err(syn_err("Count must be int literal."))
435 };
436 let count = int_lit.base10_parse::<usize>()?;
437
438 Ok(Self { vis, ident, voter_type, target_type, weight_type, count, compact_encoding } )
439 }
440}
441
442fn field_name_for(n: usize) -> Ident {
443 Ident::new(&format!("{}{}", PREFIX, n), Span::call_site())
444}