tp_npos_elections_compact/
lib.rs

1// This file is part of Tetcore.
2
3// Copyright (C) 2020-2021 Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
//! Proc macro for generating an NPoS compact assignment type.
19
20use proc_macro::TokenStream;
21use proc_macro2::{TokenStream as TokenStream2, Span, Ident};
22use proc_macro_crate::crate_name;
23use quote::quote;
24use syn::{parse::{Parse, ParseStream, Result}};
25
26mod assignment;
27mod codec;
28
/// Prefix of the field names of the generated compact struct (`votes1`, `votes2`, ...).
const PREFIX: &str = "votes";
31
32pub(crate) fn syn_err(message: &'static str) -> syn::Error {
33	syn::Error::new(Span::call_site(), message)
34}
35
36/// Generates a struct to store the election result in a small way. This can encode a structure
37/// which is the equivalent of a `tp_npos_elections::Assignment<_>`.
38///
39/// The following data types can be configured by the macro.
40///
41/// - The identifier of the voter. This can be any type that supports `tetsy-scale-codec`'s compact
42///   encoding.
43/// - The identifier of the target. This can be any type that supports `tetsy-scale-codec`'s
44///   compact encoding.
45/// - The accuracy of the ratios. This must be one of the `PerThing` types defined in
46///   `arithmetic`.
47///
48/// Moreover, the maximum number of edges per voter (distribution per assignment) also need to be
49/// specified. Attempting to convert from/to an assignment with more distributions will fail.
50///
51///
52/// For example, the following generates a public struct with name `TestSolution` with `u16` voter
53/// type, `u8` target type and `Perbill` accuracy with maximum of 8 edges per voter.
54///
55/// ```ignore
56/// generate_solution_type!(pub struct TestSolution<u16, u8, Perbill>::(8))
57/// ```
58///
59/// The given struct provides function to convert from/to Assignment:
60///
61/// - `fn from_assignment<..>(..)`
62/// - `fn into_assignment<..>(..)`
63///
64/// The generated struct is by default deriving both `Encode` and `Decode`. This is okay but could
65/// lead to many 0s in the solution. If prefixed with `#[compact]`, then a custom compact encoding
66/// for numbers will be used, similar to how `tetsy-scale-codec`'s `Compact` works.
67///
68/// ```ignore
69/// generate_solution_type!(
70///     #[compact]
71///     pub struct TestSolutionCompact<u16, u8, Perbill>::(8)
72/// )
73/// ```
74#[proc_macro]
75pub fn generate_solution_type(item: TokenStream) -> TokenStream {
76	let SolutionDef {
77		vis,
78		ident,
79		count,
80		voter_type,
81		target_type,
82		weight_type,
83		compact_encoding,
84	} = syn::parse_macro_input!(item as SolutionDef);
85
86	let imports = imports().unwrap_or_else(|e| e.to_compile_error());
87
88	let solution_struct = struct_def(
89		vis,
90		ident.clone(),
91		count,
92		voter_type.clone(),
93		target_type.clone(),
94		weight_type.clone(),
95		compact_encoding,
96	).unwrap_or_else(|e| e.to_compile_error());
97
98	quote!(
99		#imports
100		#solution_struct
101	)
102	.into()
103}
104
105fn struct_def(
106	vis: syn::Visibility,
107	ident: syn::Ident,
108	count: usize,
109	voter_type: syn::Type,
110	target_type: syn::Type,
111	weight_type: syn::Type,
112	compact_encoding: bool,
113) -> Result<TokenStream2> {
114	if count <= 2 {
115		Err(syn_err("cannot build compact solution struct with capacity less than 3."))?
116	}
117
118	let singles = {
119		let name = field_name_for(1);
120		// NOTE: we use the visibility of the struct for the fields as well.. could be made better.
121		quote!(
122			#vis #name: Vec<(#voter_type, #target_type)>,
123		)
124	};
125
126	let doubles = {
127		let name = field_name_for(2);
128		quote!(
129			#vis #name: Vec<(#voter_type, (#target_type, #weight_type), #target_type)>,
130		)
131	};
132
133	let rest = (3..=count)
134		.map(|c| {
135			let field_name = field_name_for(c);
136			let array_len = c - 1;
137			quote!(
138				#vis #field_name: Vec<(
139					#voter_type,
140					[(#target_type, #weight_type); #array_len],
141					#target_type
142				)>,
143			)
144		})
145		.collect::<TokenStream2>();
146
147	let len_impl = len_impl(count);
148	let edge_count_impl = edge_count_impl(count);
149	let unique_targets_impl = unique_targets_impl(count);
150	let remove_voter_impl = remove_voter_impl(count);
151
152	let derives_and_maybe_compact_encoding = if compact_encoding {
153		// custom compact encoding.
154		let compact_impl = codec::codec_impl(
155			ident.clone(),
156			voter_type.clone(),
157			target_type.clone(),
158			weight_type.clone(),
159			count,
160		);
161		quote!{
162			#compact_impl
163			#[derive(Default, PartialEq, Eq, Clone, Debug)]
164		}
165	} else {
166		// automatically derived.
167		quote!(#[derive(Default, PartialEq, Eq, Clone, Debug, _npos::codec::Encode, _npos::codec::Decode)])
168	};
169
170	let from_impl = assignment::from_impl(count);
171	let into_impl = assignment::into_impl(count, weight_type.clone());
172
173	Ok(quote! (
174		/// A struct to encode a election assignment in a compact way.
175		#derives_and_maybe_compact_encoding
176		#vis struct #ident { #singles #doubles #rest }
177
178		use _npos::__OrInvalidIndex;
179		impl _npos::CompactSolution for #ident {
180			const LIMIT: usize = #count;
181			type Voter = #voter_type;
182			type Target = #target_type;
183			type Accuracy = #weight_type;
184
185			fn voter_count(&self) -> usize {
186				let mut all_len = 0usize;
187				#len_impl
188				all_len
189			}
190
191			fn edge_count(&self) -> usize {
192				let mut all_edges = 0usize;
193				#edge_count_impl
194				all_edges
195			}
196
197			fn unique_targets(&self) -> Vec<Self::Target> {
198				// NOTE: this implementation returns the targets sorted, but we don't use it yet per
199				// se, nor is the API enforcing it.
200				let mut all_targets: Vec<Self::Target> = Vec::with_capacity(self.average_edge_count());
201				let mut maybe_insert_target = |t: Self::Target| {
202					match all_targets.binary_search(&t) {
203						Ok(_) => (),
204						Err(pos) => all_targets.insert(pos, t)
205					}
206				};
207
208				#unique_targets_impl
209
210				all_targets
211			}
212
213			fn remove_voter(&mut self, to_remove: Self::Voter) -> bool {
214				#remove_voter_impl
215				return false
216			}
217
218			fn from_assignment<FV, FT, A>(
219				assignments: Vec<_npos::Assignment<A, #weight_type>>,
220				index_of_voter: FV,
221				index_of_target: FT,
222			) -> Result<Self, _npos::Error>
223				where
224					A: _npos::IdentifierT,
225					for<'r> FV: Fn(&'r A) -> Option<Self::Voter>,
226					for<'r> FT: Fn(&'r A) -> Option<Self::Target>,
227			{
228				let mut compact: #ident = Default::default();
229
230				for _npos::Assignment { who, distribution } in assignments {
231					match distribution.len() {
232						0 => continue,
233						#from_impl
234						_ => {
235							return Err(_npos::Error::CompactTargetOverflow);
236						}
237					}
238				};
239				Ok(compact)
240			}
241
242			fn into_assignment<A: _npos::IdentifierT>(
243				self,
244				voter_at: impl Fn(Self::Voter) -> Option<A>,
245				target_at: impl Fn(Self::Target) -> Option<A>,
246			) -> Result<Vec<_npos::Assignment<A, #weight_type>>, _npos::Error> {
247				let mut assignments: Vec<_npos::Assignment<A, #weight_type>> = Default::default();
248				#into_impl
249				Ok(assignments)
250			}
251		}
252	))
253}
254
255fn remove_voter_impl(count: usize) -> TokenStream2 {
256	let field_name = field_name_for(1);
257	let single = quote! {
258		if let Some(idx) = self.#field_name.iter().position(|(x, _)| *x == to_remove) {
259			self.#field_name.remove(idx);
260			return true
261		}
262	};
263
264	let field_name = field_name_for(2);
265	let double = quote! {
266		if let Some(idx) = self.#field_name.iter().position(|(x, _, _)| *x == to_remove) {
267			self.#field_name.remove(idx);
268			return true
269		}
270	};
271
272	let rest = (3..=count)
273		.map(|c| {
274			let field_name = field_name_for(c);
275			quote! {
276				if let Some(idx) = self.#field_name.iter().position(|(x, _, _)| *x == to_remove) {
277					self.#field_name.remove(idx);
278					return true
279				}
280			}
281		})
282		.collect::<TokenStream2>();
283
284	quote! {
285		#single
286		#double
287		#rest
288	}
289}
290
291fn len_impl(count: usize) -> TokenStream2 {
292	(1..=count).map(|c| {
293		let field_name = field_name_for(c);
294		quote!(
295			all_len = all_len.saturating_add(self.#field_name.len());
296		)
297	}).collect::<TokenStream2>()
298}
299
300fn edge_count_impl(count: usize) -> TokenStream2 {
301	(1..=count).map(|c| {
302		let field_name = field_name_for(c);
303		quote!(
304			all_edges = all_edges.saturating_add(
305				self.#field_name.len().saturating_mul(#c as usize)
306			);
307		)
308	}).collect::<TokenStream2>()
309}
310
311fn unique_targets_impl(count: usize) -> TokenStream2 {
312	let unique_targets_impl_single = {
313		let field_name = field_name_for(1);
314		quote! {
315			self.#field_name.iter().for_each(|(_, t)| {
316				maybe_insert_target(*t);
317			});
318		}
319	};
320
321	let unique_targets_impl_double = {
322		let field_name = field_name_for(2);
323		quote! {
324			self.#field_name.iter().for_each(|(_, (t1, _), t2)| {
325				maybe_insert_target(*t1);
326				maybe_insert_target(*t2);
327			});
328		}
329	};
330
331	let unique_targets_impl_rest = (3..=count).map(|c| {
332		let field_name = field_name_for(c);
333		quote! {
334			self.#field_name.iter().for_each(|(_, inners, t_last)| {
335				inners.iter().for_each(|(t, _)| {
336					maybe_insert_target(*t);
337				});
338				maybe_insert_target(*t_last);
339			});
340		}
341	}).collect::<TokenStream2>();
342
343	quote! {
344		#unique_targets_impl_single
345		#unique_targets_impl_double
346		#unique_targets_impl_rest
347	}
348}
349
350fn imports() -> Result<TokenStream2> {
351	if std::env::var("CARGO_PKG_NAME").unwrap() == "tp-npos-elections" {
352		Ok(quote! {
353			use crate as _npos;
354		})
355	} else {
356		match crate_name("tp-npos-elections") {
357			Ok(tp_npos_elections) => {
358				let ident = syn::Ident::new(&tp_npos_elections, Span::call_site());
359				Ok(quote!( extern crate #ident as _npos; ))
360			},
361			Err(e) => Err(syn::Error::new(Span::call_site(), &e)),
362		}
363	}
364}
365
366struct SolutionDef {
367	vis: syn::Visibility,
368	ident: syn::Ident,
369	voter_type: syn::Type,
370	target_type: syn::Type,
371	weight_type: syn::Type,
372	count: usize,
373	compact_encoding: bool,
374}
375
376fn check_compact_attr(input: ParseStream) -> Result<bool> {
377	let mut attrs = input.call(syn::Attribute::parse_outer).unwrap_or_default();
378	if attrs.len() == 1 {
379		let attr = attrs.pop().expect("Vec with len 1 can be popped.");
380		if attr.path.segments.len() == 1 {
381			let segment = attr.path.segments.first().expect("Vec with len 1 can be popped.");
382			if segment.ident == Ident::new("compact", Span::call_site()) {
383				Ok(true)
384			} else {
385				Err(syn_err("generate_solution_type macro can only accept #[compact] attribute."))
386			}
387		} else {
388			Err(syn_err("generate_solution_type macro can only accept #[compact] attribute."))
389		}
390	} else {
391		Ok(false)
392	}
393}
394
395/// #[compact] pub struct CompactName::<u32, u32, u32>()
396impl Parse for SolutionDef {
397	fn parse(input: ParseStream) -> syn::Result<Self> {
398		// optional #[compact]
399		let compact_encoding = check_compact_attr(input)?;
400
401		// <vis> struct <name>
402		let vis: syn::Visibility = input.parse()?;
403		let _ = <syn::Token![struct]>::parse(input)?;
404		let ident: syn::Ident = input.parse()?;
405
406		// ::<V, T, W>
407		let _ = <syn::Token![::]>::parse(input)?;
408		let generics: syn::AngleBracketedGenericArguments = input.parse()?;
409
410		if generics.args.len() != 3 {
411			return Err(syn_err("Must provide 3 generic args."))
412		}
413
414		let mut types: Vec<syn::Type> = generics.args.iter().map(|t|
415			match t {
416				syn::GenericArgument::Type(ty) => Ok(ty.clone()),
417				_ => Err(syn_err("Wrong type of generic provided. Must be a `type`.")),
418			}
419		).collect::<Result<_>>()?;
420
421		let weight_type = types.pop().expect("Vector of length 3 can be popped; qed");
422		let target_type = types.pop().expect("Vector of length 2 can be popped; qed");
423		let voter_type = types.pop().expect("Vector of length 1 can be popped; qed");
424
425		// (<count>)
426		let count_expr: syn::ExprParen = input.parse()?;
427		let expr = count_expr.expr;
428		let expr_lit = match *expr {
429			syn::Expr::Lit(count_lit) => count_lit.lit,
430			_ => return Err(syn_err("Count must be literal."))
431		};
432		let int_lit = match expr_lit {
433			syn::Lit::Int(int_lit) => int_lit,
434			_ => return Err(syn_err("Count must be int literal."))
435		};
436		let count = int_lit.base10_parse::<usize>()?;
437
438		Ok(Self { vis, ident, voter_type, target_type, weight_type, count, compact_encoding } )
439	}
440}
441
442fn field_name_for(n: usize) -> Ident {
443	Ident::new(&format!("{}{}", PREFIX, n), Span::call_site())
444}