Skip to main content

subsoil/arithmetic/
mod.rs

1// This file is part of Soil.
2
3// Copyright (C) Soil contributors.
4// Copyright (C) Parity Technologies (UK) Ltd.
5// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later WITH Classpath-exception-2.0
6
7//! Minimal fixed point arithmetic primitives and types for runtime.
8
/// Assert that `actual` lies within `tolerance` of `expected`, i.e. that
/// `expected - tolerance <= actual <= expected + tolerance`.
///
/// Each argument is expanded more than once (it appears twice in the condition and again in the
/// failure message), so avoid passing expressions with side effects.
///
/// Copied from `sp-runtime` and documented there.
#[macro_export]
macro_rules! assert_eq_error_rate {
	($actual:expr, $expected:expr, $tolerance:expr $(,)?) => {
		assert!(
			($actual) >= (($expected) - ($tolerance)) && ($actual) <= (($expected) + ($tolerance)),
			"{:?} != {:?} (with error rate {:?})",
			$actual,
			$expected,
			$tolerance,
		);
	};
}
22
23pub mod biguint;
24pub mod fixed_point;
25pub mod helpers_128bit;
26pub mod per_things;
27pub mod rational;
28pub mod traits;
29
30pub use fixed_point::{
31	FixedI128, FixedI64, FixedPointNumber, FixedPointOperand, FixedU128, FixedU64,
32};
33pub use per_things::{
34	InnerOf, MultiplyArg, PerThing, PerU16, Perbill, Percent, Permill, Perquintill, RationalArg,
35	ReciprocalArg, Rounding, SignedRounding, UpperOf,
36};
37pub use rational::{MultiplyRational, Rational128, RationalInfinite};
38
39use alloc::vec::Vec;
40use core::{cmp::Ordering, fmt::Debug};
41use traits::{BaseArithmetic, One, SaturatedConversion, Unsigned, Zero};
42
43use codec::{Decode, DecodeWithMemTracking, Encode, MaxEncodedLen};
44use scale_info::TypeInfo;
45
46#[cfg(feature = "serde")]
47use serde::{Deserialize, Serialize};
48
49/// Arithmetic errors.
50#[derive(
51	Eq,
52	PartialEq,
53	Clone,
54	Copy,
55	Encode,
56	Decode,
57	DecodeWithMemTracking,
58	Debug,
59	TypeInfo,
60	MaxEncodedLen,
61)]
62#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
63pub enum ArithmeticError {
64	/// Underflow.
65	Underflow,
66	/// Overflow.
67	Overflow,
68	/// Division by zero.
69	DivisionByZero,
70}
71
72impl From<ArithmeticError> for &'static str {
73	fn from(e: ArithmeticError) -> &'static str {
74		match e {
75			ArithmeticError::Underflow => "An underflow would occur",
76			ArithmeticError::Overflow => "An overflow would occur",
77			ArithmeticError::DivisionByZero => "Division by zero",
78		}
79	}
80}
81
82/// Trait for comparing two numbers with an threshold.
83///
84/// Returns:
85/// - `Ordering::Greater` if `self` is greater than `other + threshold`.
86/// - `Ordering::Less` if `self` is less than `other - threshold`.
87/// - `Ordering::Equal` otherwise.
88pub trait ThresholdOrd<T> {
89	/// Compare if `self` is `threshold` greater or less than `other`.
90	fn tcmp(&self, other: &T, threshold: T) -> Ordering;
91}
92
93impl<T> ThresholdOrd<T> for T
94where
95	T: Ord + PartialOrd + Copy + Clone + traits::Zero + traits::Saturating,
96{
97	fn tcmp(&self, other: &T, threshold: T) -> Ordering {
98		// early exit.
99		if threshold.is_zero() {
100			return self.cmp(other);
101		}
102
103		let upper_bound = other.saturating_add(threshold);
104		let lower_bound = other.saturating_sub(threshold);
105
106		if upper_bound <= lower_bound {
107			// defensive only. Can never happen.
108			self.cmp(other)
109		} else {
110			// upper_bound is guaranteed now to be bigger than lower.
111			match (self.cmp(&lower_bound), self.cmp(&upper_bound)) {
112				(Ordering::Greater, Ordering::Greater) => Ordering::Greater,
113				(Ordering::Less, Ordering::Less) => Ordering::Less,
114				_ => Ordering::Equal,
115			}
116		}
117	}
118}
119
/// A collection of `T` values that can be rescaled so that its elements add up to a chosen total.
///
/// Implementations may distribute the adjustment unevenly, so the order of items in the
/// collection may affect the result.
pub trait Normalizable<T> {
	/// Produce a normalized copy of `self` whose elements sum to `targeted_sum`.
	///
	/// `Ok` is returned only when the resulting elements are guaranteed to sum to exactly
	/// `targeted_sum`; otherwise an `Err` describes why normalization was not possible.
	fn normalize(&self, targeted_sum: T) -> Result<Vec<T>, &'static str>;
}
131
132macro_rules! impl_normalize_for_numeric {
133	($($numeric:ty),*) => {
134		$(
135			impl Normalizable<$numeric> for Vec<$numeric> {
136				fn normalize(&self, targeted_sum: $numeric) -> Result<Vec<$numeric>, &'static str> {
137					normalize(self.as_ref(), targeted_sum)
138				}
139			}
140		)*
141	};
142}
143
144impl_normalize_for_numeric!(u8, u16, u32, u64, u128);
145
146impl<P: PerThing> Normalizable<P> for Vec<P> {
147	fn normalize(&self, targeted_sum: P) -> Result<Vec<P>, &'static str> {
148		let uppers = self.iter().map(|p| <UpperOf<P>>::from(p.deconstruct())).collect::<Vec<_>>();
149
150		let normalized =
151			normalize(uppers.as_ref(), <UpperOf<P>>::from(targeted_sum.deconstruct()))?;
152
153		Ok(normalized
154			.into_iter()
155			.map(|i: UpperOf<P>| P::from_parts(i.saturated_into::<P::Inner>()))
156			.collect())
157	}
158}
159
160/// Normalize `input` so that the sum of all elements reaches `targeted_sum`.
161///
162/// This implementation is currently in a balanced position between being performant and accurate.
163///
164/// 1. We prefer storing original indices, and sorting the `input` only once. This will save the
165///    cost of sorting per round at the cost of a little bit of memory.
166/// 2. The granularity of increment/decrements is determined by the number of elements in `input`
167///    and their sum difference with `targeted_sum`, namely `diff = diff(sum(input), target_sum)`.
168///    This value is then distributed into `per_round = diff / input.len()` and `leftover = diff %
169///    round`. First, per_round is applied to all elements of input, and then we move to leftover,
170///    in which case we add/subtract 1 by 1 until `leftover` is depleted.
171///
172/// When the sum is less than the target, the above approach always holds. In this case, then each
173/// individual element is also less than target. Thus, by adding `per_round` to each item, neither
174/// of them can overflow the numeric bound of `T`. In fact, neither of the can go beyond
175/// `target_sum`*.
176///
177/// If sum is more than target, there is small twist. The subtraction of `per_round`
178/// form each element might go below zero. In this case, we saturate and add the error to the
179/// `leftover` value. This ensures that the result will always stay accurate, yet it might cause the
180/// execution to become increasingly slow, since leftovers are applied one by one.
181///
182/// All in all, the complicated case above is rare to happen in most use cases within this repo ,
183/// hence we opt for it due to its simplicity.
184///
185/// This function will return an error is if length of `input` cannot fit in `T`, or if `sum(input)`
186/// cannot fit inside `T`.
187///
188/// * This proof is used in the implementation as well.
189pub fn normalize<T>(input: &[T], targeted_sum: T) -> Result<Vec<T>, &'static str>
190where
191	T: Clone + Copy + Ord + BaseArithmetic + Unsigned + Debug,
192{
193	// compute sum and return error if failed.
194	let mut sum = T::zero();
195	for t in input.iter() {
196		sum = sum.checked_add(t).ok_or("sum of input cannot fit in `T`")?;
197	}
198
199	// convert count and return error if failed.
200	let count = input.len();
201	let count_t: T = count.try_into().map_err(|_| "length of `inputs` cannot fit in `T`")?;
202
203	// Nothing to do here.
204	if count.is_zero() {
205		return Ok(Vec::<T>::new());
206	}
207
208	let diff = targeted_sum.max(sum) - targeted_sum.min(sum);
209	if diff.is_zero() {
210		return Ok(input.to_vec());
211	}
212
213	let needs_bump = targeted_sum > sum;
214	let per_round = diff / count_t;
215	let mut leftover = diff % count_t;
216
217	// sort output once based on diff. This will require more data transfer and saving original
218	// index, but we sort only twice instead: once now and once at the very end.
219	let mut output_with_idx = input.iter().cloned().enumerate().collect::<Vec<(usize, T)>>();
220	output_with_idx.sort_by_key(|x| x.1);
221
222	if needs_bump {
223		// must increase the values a bit. Bump from the min element. Index of minimum is now zero
224		// because we did a sort. If at any point the min goes greater or equal the `max_threshold`,
225		// we move to the next minimum.
226		let mut min_index = 0;
227		// at this threshold we move to next index.
228		let threshold = targeted_sum / count_t;
229
230		if !per_round.is_zero() {
231			for _ in 0..count {
232				output_with_idx[min_index].1 = output_with_idx[min_index]
233					.1
234					.checked_add(&per_round)
235					.expect("Proof provided in the module doc; qed.");
236				if output_with_idx[min_index].1 >= threshold {
237					min_index += 1;
238					min_index %= count;
239				}
240			}
241		}
242
243		// continue with the previous min_index
244		while !leftover.is_zero() {
245			output_with_idx[min_index].1 = output_with_idx[min_index]
246				.1
247				.checked_add(&T::one())
248				.expect("Proof provided in the module doc; qed.");
249			if output_with_idx[min_index].1 >= threshold {
250				min_index += 1;
251				min_index %= count;
252			}
253			leftover -= One::one();
254		}
255	} else {
256		// must decrease the stakes a bit. decrement from the max element. index of maximum is now
257		// last. if at any point the max goes less or equal the `min_threshold`, we move to the next
258		// maximum.
259		let mut max_index = count - 1;
260		// at this threshold we move to next index.
261		let threshold = output_with_idx
262			.first()
263			.expect("length of input is greater than zero; it must have a first; qed")
264			.1;
265
266		if !per_round.is_zero() {
267			for _ in 0..count {
268				output_with_idx[max_index].1 =
269					output_with_idx[max_index].1.checked_sub(&per_round).unwrap_or_else(|| {
270						let remainder = per_round - output_with_idx[max_index].1;
271						leftover += remainder;
272						output_with_idx[max_index].1.saturating_sub(per_round)
273					});
274				if output_with_idx[max_index].1 <= threshold {
275					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
276				}
277			}
278		}
279
280		// continue with the previous max_index
281		while !leftover.is_zero() {
282			if let Some(next) = output_with_idx[max_index].1.checked_sub(&One::one()) {
283				output_with_idx[max_index].1 = next;
284				if output_with_idx[max_index].1 <= threshold {
285					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
286				}
287				leftover -= One::one();
288			} else {
289				max_index = max_index.checked_sub(1).unwrap_or(count - 1);
290			}
291		}
292	}
293
294	debug_assert_eq!(
295		output_with_idx.iter().fold(T::zero(), |acc, (_, x)| acc + *x),
296		targeted_sum,
297		"sum({:?}) != {:?}",
298		output_with_idx,
299		targeted_sum
300	);
301
302	// sort again based on the original index.
303	output_with_idx.sort_by_key(|x| x.0);
304	Ok(output_with_idx.into_iter().map(|(_, t)| t).collect())
305}
306
#[cfg(test)]
mod normalize_tests {
	use super::*;

	#[test]
	fn work_for_all_types() {
		macro_rules! test_for {
			($type:ty) => {
				assert_eq!(
					normalize(vec![8 as $type, 9, 7, 10].as_ref(), 40).unwrap(),
					vec![10, 10, 10, 10],
				);
			};
		}
		// it should work for all types as long as the length of vector can be converted to T.
		test_for!(u128);
		test_for!(u64);
		test_for!(u32);
		test_for!(u16);
		test_for!(u8);
	}

	#[test]
	fn fails_if_input_sum_large() {
		assert!(normalize(vec![1u8; 255].as_ref(), 10).is_ok());
		assert_eq!(normalize(vec![1u8; 256].as_ref(), 10), Err("sum of input cannot fit in `T`"));
	}

	#[test]
	fn fails_if_input_len_large() {
		// The sum (zero) fits in `u8`, but the length (256) does not.
		assert_eq!(
			normalize(&[0u8; 256][..], 10),
			Err("length of `inputs` cannot fit in `T`")
		);
	}

	#[test]
	fn empty_and_already_normalized_inputs() {
		// An empty input is returned as-is, whatever the target.
		assert_eq!(normalize::<u32>(&[], 10), Ok(Vec::new()));
		// An input that already sums to the target is returned unchanged.
		assert_eq!(normalize(&[1u32, 2, 3][..], 6), Ok(vec![1, 2, 3]));
	}

	#[test]
	fn does_not_fail_on_subtraction_overflow() {
		assert_eq!(normalize(vec![1u8, 100, 100].as_ref(), 10).unwrap(), vec![1, 9, 0]);
		assert_eq!(normalize(vec![1u8, 8, 9].as_ref(), 1).unwrap(), vec![0, 1, 0]);
	}

	#[test]
	fn works_for_vec() {
		assert_eq!(vec![8u32, 9, 7, 10].normalize(40).unwrap(), vec![10u32, 10, 10, 10]);
	}

	#[test]
	fn works_for_per_thing() {
		assert_eq!(
			vec![Perbill::from_percent(33), Perbill::from_percent(33), Perbill::from_percent(33)]
				.normalize(Perbill::one())
				.unwrap(),
			vec![
				Perbill::from_parts(333333334),
				Perbill::from_parts(333333333),
				Perbill::from_parts(333333333)
			]
		);

		assert_eq!(
			vec![Perbill::from_percent(20), Perbill::from_percent(15), Perbill::from_percent(30)]
				.normalize(Perbill::one())
				.unwrap(),
			vec![
				Perbill::from_parts(316666668),
				Perbill::from_parts(383333332),
				Perbill::from_parts(300000000)
			]
		);
	}

	#[test]
	fn can_work_for_peru16() {
		// PerU16 is a rather special case; since its inner type is exactly the same as its
		// capacity, we could have a situation where the sum cannot be calculated in the inner
		// type. Calculating using the upper type of the per_thing should ensure this is okay.
		assert_eq!(
			vec![PerU16::from_percent(40), PerU16::from_percent(40), PerU16::from_percent(40)]
				.normalize(PerU16::one())
				.unwrap(),
			vec![
				PerU16::from_parts(21845), // 33%
				PerU16::from_parts(21845), // 33%
				PerU16::from_parts(21845)  // 33%
			]
		);
	}

	#[test]
	fn normalize_works_all_le() {
		assert_eq!(normalize(vec![8u32, 9, 7, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![7u32, 7, 7, 7].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![7u32, 7, 7, 10].as_ref(), 40).unwrap(), vec![11, 11, 8, 10]);

		assert_eq!(normalize(vec![7u32, 8, 7, 10].as_ref(), 40).unwrap(), vec![11, 8, 11, 10]);

		assert_eq!(normalize(vec![7u32, 7, 8, 10].as_ref(), 40).unwrap(), vec![11, 11, 8, 10]);
	}

	#[test]
	fn normalize_works_some_ge() {
		assert_eq!(normalize(vec![8u32, 11, 9, 10].as_ref(), 40).unwrap(), vec![10, 11, 9, 10]);
	}

	#[test]
	fn always_inc_min() {
		assert_eq!(normalize(vec![10u32, 7, 10, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);
		assert_eq!(normalize(vec![10u32, 10, 7, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);
		assert_eq!(normalize(vec![10u32, 10, 10, 7].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);
	}

	#[test]
	fn normalize_works_all_ge() {
		assert_eq!(normalize(vec![12u32, 11, 13, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![13u32, 13, 13, 13].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![13u32, 13, 13, 10].as_ref(), 40).unwrap(), vec![12, 9, 9, 10]);

		assert_eq!(normalize(vec![13u32, 12, 13, 10].as_ref(), 40).unwrap(), vec![9, 12, 9, 10]);

		assert_eq!(normalize(vec![13u32, 13, 12, 10].as_ref(), 40).unwrap(), vec![9, 9, 12, 10]);
	}
}
426
#[cfg(test)]
mod per_and_fixed_examples {
	use super::*;

	#[docify::export]
	#[test]
	fn percent_mult() {
		let percent = Percent::from_rational(5u32, 100u32); // aka, 5%
		let five_percent_of_100 = percent * 100u32; // 5% of 100 is 5.
		assert_eq!(five_percent_of_100, 5)
	}

	#[docify::export]
	#[test]
	fn perbill_example() {
		let p = Perbill::from_percent(80);
		// 800000000 parts per billion, i.e. a representation of 0.800000000.
		// Precision is in the billionths place.
		assert_eq!(p.deconstruct(), 800000000);
	}

	#[docify::export]
	#[test]
	fn percent_example() {
		let percent = Percent::from_rational(190u32, 400u32);
		assert_eq!(percent.deconstruct(), 47);
	}

	#[docify::export]
	#[test]
	fn fixed_u64_block_computation_example() {
		// Calculate a very rudimentary on-chain price from supply / demand
		// Supply: Cores available per block
		// Demand: Cores being ordered per block
		let price = FixedU64::from_rational(5u128, 10u128);

		// 0.5 DOT per core
		assert_eq!(price, FixedU64::from_float(0.5));

		// Now, the story has changed - lots of demand means we buy as many cores as there are
		// available. This also means that price goes up! For the sake of simplicity, we don't care
		// about who gets a core - just about our very simple price model

		// Calculate a very rudimentary on-chain price from supply / demand
		// Supply: Cores available per block
		// Demand: Cores being ordered per block
		let price = FixedU64::from_rational(19u128, 10u128);

		// 1.9 DOT per core
		assert_eq!(price, FixedU64::from_float(1.9));
	}

	#[docify::export]
	#[test]
	fn fixed_u64() {
		// The difference between this and per-things is that per-things operate within the realm
		// of [0, 1]. In cases where we need values > 1, we can use fixed types such as FixedU64.

		let rational_1 = FixedU64::from_rational(10, 5); // "200%" aka 2.
		let rational_2 = FixedU64::from_rational_with_rounding(5, 10, Rounding::Down); // "50%" aka 0.50...

		assert_eq!(rational_1, (2u64).into());
		assert_eq!(rational_2.into_perbill(), Perbill::from_float(0.5));
	}

	#[docify::export]
	#[test]
	fn fixed_u64_operation_example() {
		let rational_1 = FixedU64::from_rational(10, 5); // "200%" aka 2.
		let rational_2 = FixedU64::from_rational(8, 5); // "160%" aka 1.6.

		let addition = rational_1 + rational_2;
		let multiplication = rational_1 * rational_2;
		let division = rational_1 / rational_2;
		let subtraction = rational_1 - rational_2;

		assert_eq!(addition, FixedU64::from_float(3.6));
		assert_eq!(multiplication, FixedU64::from_float(3.2));
		assert_eq!(division, FixedU64::from_float(1.25));
		assert_eq!(subtraction, FixedU64::from_float(0.4));
	}
}
508
#[cfg(test)]
mod threshold_compare_tests {
	use super::*;
	use crate::arithmetic::traits::Saturating;
	use core::cmp::Ordering;

	#[test]
	fn epsilon_ord_works() {
		let b = 115u32;
		let e = Perbill::from_percent(10).mul_ceil(b);

		// 10% of 115 is 11.5, which `mul_ceil` rounds up to 12, so everything in
		// [115 - 12 (103), 115 + 12 (127)] compares as equal.
		assert_eq!((103u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((104u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((115u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((120u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((126u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((127u32).tcmp(&b, e), Ordering::Equal);

		assert_eq!((128u32).tcmp(&b, e), Ordering::Greater);
		assert_eq!((102u32).tcmp(&b, e), Ordering::Less);
	}

	#[test]
	fn epsilon_ord_works_with_small_epc() {
		let b = 115u32;
		// way less than 1 percent. threshold will be zero. Result should be same as normal ord.
		let e = Perbill::from_parts(100) * b;

		// With a zero threshold, `tcmp` must agree with plain `cmp` everywhere, including at the
		// points that compared as equal in `epsilon_ord_works`.
		assert_eq!((103u32).tcmp(&b, e), (103u32).cmp(&b));
		assert_eq!((104u32).tcmp(&b, e), (104u32).cmp(&b));
		assert_eq!((115u32).tcmp(&b, e), (115u32).cmp(&b));
		assert_eq!((120u32).tcmp(&b, e), (120u32).cmp(&b));
		assert_eq!((126u32).tcmp(&b, e), (126u32).cmp(&b));
		assert_eq!((127u32).tcmp(&b, e), (127u32).cmp(&b));

		assert_eq!((128u32).tcmp(&b, e), (128u32).cmp(&b));
		assert_eq!((102u32).tcmp(&b, e), (102u32).cmp(&b));
	}

	#[test]
	fn peru16_rational_does_not_overflow() {
		// A historical example that will panic only for per_thing types that are created with
		// the maximum capacity of their type, e.g. PerU16.
		let _ = PerU16::from_rational(17424870u32, 17424870);
	}

	#[test]
	fn saturating_mul_works() {
		assert_eq!(Saturating::saturating_mul(2, i32::MIN), i32::MIN);
		assert_eq!(Saturating::saturating_mul(2, i32::MAX), i32::MAX);
	}

	#[test]
	fn saturating_pow_works() {
		assert_eq!(Saturating::saturating_pow(i32::MIN, 0), 1);
		assert_eq!(Saturating::saturating_pow(i32::MAX, 0), 1);
		assert_eq!(Saturating::saturating_pow(i32::MIN, 3), i32::MIN);
		assert_eq!(Saturating::saturating_pow(i32::MIN, 2), i32::MAX);
		assert_eq!(Saturating::saturating_pow(i32::MAX, 2), i32::MAX);
	}
}