Skip to main content

subsoil/arithmetic/
rational.rs

1// This file is part of Soil.
2
3// Copyright (C) Soil contributors.
4// Copyright (C) Parity Technologies (UK) Ltd.
5// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later WITH Classpath-exception-2.0
6
7use crate::arithmetic::{biguint::BigUint, helpers_128bit, Rounding};
8use core::cmp::Ordering;
9use num_traits::{Bounded, One, Zero};
10
11/// A wrapper for any rational number with infinitely large numerator and denominator.
12///
13/// This type exists to facilitate `cmp` operation
14/// on values like `a/b < c/d` where `a, b, c, d` are all `BigUint`.
15#[derive(Clone, Default, Eq)]
16pub struct RationalInfinite(BigUint, BigUint);
17
18impl RationalInfinite {
19	/// Return the numerator reference.
20	pub fn n(&self) -> &BigUint {
21		&self.0
22	}
23
24	/// Return the denominator reference.
25	pub fn d(&self) -> &BigUint {
26		&self.1
27	}
28
29	/// Build from a raw `n/d`.
30	pub fn from(n: BigUint, d: BigUint) -> Self {
31		Self(n, d.max(BigUint::one()))
32	}
33
34	/// Zero.
35	pub fn zero() -> Self {
36		Self(BigUint::zero(), BigUint::one())
37	}
38
39	/// One.
40	pub fn one() -> Self {
41		Self(BigUint::one(), BigUint::one())
42	}
43}
44
45impl PartialOrd for RationalInfinite {
46	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
47		Some(self.cmp(other))
48	}
49}
50
51impl Ord for RationalInfinite {
52	fn cmp(&self, other: &Self) -> Ordering {
53		// handle some edge cases.
54		if self.d() == other.d() {
55			self.n().cmp(other.n())
56		} else if self.d().is_zero() {
57			Ordering::Greater
58		} else if other.d().is_zero() {
59			Ordering::Less
60		} else {
61			// (a/b) cmp (c/d) => (a*d) cmp (c*b)
62			self.n().clone().mul(other.d()).cmp(&other.n().clone().mul(self.d()))
63		}
64	}
65}
66
67impl PartialEq for RationalInfinite {
68	fn eq(&self, other: &Self) -> bool {
69		self.cmp(other) == Ordering::Equal
70	}
71}
72
73impl From<Rational128> for RationalInfinite {
74	fn from(t: Rational128) -> Self {
75		Self(t.0.into(), t.1.into())
76	}
77}
78
79/// A wrapper for any rational number with a 128 bit numerator and denominator.
80#[derive(Clone, Copy, Default, Eq)]
81pub struct Rational128(u128, u128);
82
83#[cfg(feature = "std")]
84impl core::fmt::Debug for Rational128 {
85	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
86		write!(f, "Rational128({} / {} ≈ {:.8})", self.0, self.1, self.0 as f64 / self.1 as f64)
87	}
88}
89
90#[cfg(not(feature = "std"))]
91impl core::fmt::Debug for Rational128 {
92	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
93		write!(f, "Rational128({} / {})", self.0, self.1)
94	}
95}
96
97impl Rational128 {
98	/// Zero.
99	pub fn zero() -> Self {
100		Self(0, 1)
101	}
102
103	/// One
104	pub fn one() -> Self {
105		Self(1, 1)
106	}
107
108	/// If it is zero or not
109	pub fn is_zero(&self) -> bool {
110		self.0.is_zero()
111	}
112
113	/// Build from a raw `n/d`.
114	pub fn from(n: u128, d: u128) -> Self {
115		Self(n, d.max(1))
116	}
117
118	/// Build from a raw `n/d`. This could lead to / 0 if not properly handled.
119	pub fn from_unchecked(n: u128, d: u128) -> Self {
120		Self(n, d)
121	}
122
123	/// Return the numerator.
124	pub fn n(&self) -> u128 {
125		self.0
126	}
127
128	/// Return the denominator.
129	pub fn d(&self) -> u128 {
130		self.1
131	}
132
133	/// Convert `self` to a similar rational number where denominator is the given `den`.
134	//
135	/// This only returns if the result is accurate. `None` is returned if the result cannot be
136	/// accurately calculated.
137	pub fn to_den(self, den: u128) -> Option<Self> {
138		if den == self.1 {
139			Some(self)
140		} else {
141			helpers_128bit::multiply_by_rational_with_rounding(
142				self.0,
143				den,
144				self.1,
145				Rounding::NearestPrefDown,
146			)
147			.map(|n| Self(n, den))
148		}
149	}
150
151	/// Get the least common divisor of `self` and `other`.
152	///
153	/// This only returns if the result is accurate. `None` is returned if the result cannot be
154	/// accurately calculated.
155	pub fn lcm(&self, other: &Self) -> Option<u128> {
156		// this should be tested better: two large numbers that are almost the same.
157		if self.1 == other.1 {
158			return Some(self.1);
159		}
160		let g = helpers_128bit::gcd(self.1, other.1);
161		helpers_128bit::multiply_by_rational_with_rounding(
162			self.1,
163			other.1,
164			g,
165			Rounding::NearestPrefDown,
166		)
167	}
168
169	/// A saturating add that assumes `self` and `other` have the same denominator.
170	pub fn lazy_saturating_add(self, other: Self) -> Self {
171		if other.is_zero() {
172			self
173		} else {
174			Self(self.0.saturating_add(other.0), self.1)
175		}
176	}
177
178	/// A saturating subtraction that assumes `self` and `other` have the same denominator.
179	pub fn lazy_saturating_sub(self, other: Self) -> Self {
180		if other.is_zero() {
181			self
182		} else {
183			Self(self.0.saturating_sub(other.0), self.1)
184		}
185	}
186
187	/// Addition. Simply tries to unify the denominators and add the numerators.
188	///
189	/// Overflow might happen during any of the steps. Error is returned in such cases.
190	pub fn checked_add(self, other: Self) -> Result<Self, &'static str> {
191		let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
192		let self_scaled =
193			self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
194		let other_scaled =
195			other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
196		let n = self_scaled
197			.0
198			.checked_add(other_scaled.0)
199			.ok_or("overflow while adding numerators")?;
200		Ok(Self(n, self_scaled.1))
201	}
202
203	/// Subtraction. Simply tries to unify the denominators and subtract the numerators.
204	///
205	/// Overflow might happen during any of the steps. None is returned in such cases.
206	pub fn checked_sub(self, other: Self) -> Result<Self, &'static str> {
207		let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
208		let self_scaled =
209			self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
210		let other_scaled =
211			other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
212
213		let n = self_scaled
214			.0
215			.checked_sub(other_scaled.0)
216			.ok_or("overflow while subtracting numerators")?;
217		Ok(Self(n, self_scaled.1))
218	}
219}
220
221impl Bounded for Rational128 {
222	fn min_value() -> Self {
223		Self(0, 1)
224	}
225
226	fn max_value() -> Self {
227		Self(Bounded::max_value(), 1)
228	}
229}
230
231impl<T: Into<u128>> From<T> for Rational128 {
232	fn from(t: T) -> Self {
233		Self::from(t.into(), 1)
234	}
235}
236
237impl PartialOrd for Rational128 {
238	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
239		Some(self.cmp(other))
240	}
241}
242
243impl Ord for Rational128 {
244	fn cmp(&self, other: &Self) -> Ordering {
245		// handle some edge cases.
246		if self.1 == other.1 {
247			self.0.cmp(&other.0)
248		} else if self.1.is_zero() {
249			Ordering::Greater
250		} else if other.1.is_zero() {
251			Ordering::Less
252		} else {
253			// Don't even compute gcd.
254			let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
255			let other_n =
256				helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
257			self_n.cmp(&other_n)
258		}
259	}
260}
261
262impl PartialEq for Rational128 {
263	fn eq(&self, other: &Self) -> bool {
264		// handle some edge cases.
265		if self.1 == other.1 {
266			self.0.eq(&other.0)
267		} else {
268			let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
269			let other_n =
270				helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
271			self_n.eq(&other_n)
272		}
273	}
274}
275
276pub trait MultiplyRational: Sized {
277	fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self>;
278}
279
280macro_rules! impl_rrm {
281	($ulow:ty, $uhi:ty) => {
282		impl MultiplyRational for $ulow {
283			fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
284				if d.is_zero() {
285					return None;
286				}
287
288				let sn = (self as $uhi) * (n as $uhi);
289				let mut result = sn / (d as $uhi);
290				let remainder = (sn % (d as $uhi)) as $ulow;
291				if match r {
292					Rounding::Up => remainder > 0,
293					// cannot be `(d + 1) / 2` since `d` might be `max_value` and overflow.
294					Rounding::NearestPrefUp => remainder >= d / 2 + d % 2,
295					Rounding::NearestPrefDown => remainder > d / 2,
296					Rounding::Down => false,
297				} {
298					result = match result.checked_add(1) {
299						Some(v) => v,
300						None => return None,
301					};
302				}
303				if result > (<$ulow>::max_value() as $uhi) {
304					None
305				} else {
306					Some(result as $ulow)
307				}
308			}
309		}
310	};
311}
312
313impl_rrm!(u8, u16);
314impl_rrm!(u16, u32);
315impl_rrm!(u32, u64);
316impl_rrm!(u64, u128);
317
318impl MultiplyRational for u128 {
319	fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
320		crate::arithmetic::helpers_128bit::multiply_by_rational_with_rounding(self, n, d, r)
321	}
322}
323
#[cfg(test)]
mod tests {
	use super::{helpers_128bit::*, *};
	use static_assertions::const_assert;

	const MAX128: u128 = u128::MAX;
	const MAX64: u128 = u64::MAX as u128;
	const MAX64_2: u128 = 2 * u64::MAX as u128;

	/// Shorthand for building a `Rational128` from raw parts, without denominator clamping.
	fn r(p: u128, q: u128) -> Rational128 {
		Rational128(p, q)
	}

	/// Reference `a * b / c` computed in 256 bits.
	///
	/// A zero `c` is treated as one, and `a` is returned when the result does not fit in `u128`.
	fn mul_div(a: u128, b: u128, c: u128) -> u128 {
		use primitive_types::U256;
		if a.is_zero() {
			return Zero::zero();
		}
		let c = c.max(1);

		// e for extended
		let ae: U256 = a.into();
		let be: U256 = b.into();
		let ce: U256 = c.into();

		let r = ae * be / ce;
		if r > u128::MAX.into() {
			a
		} else {
			r.as_u128()
		}
	}

	#[test]
	fn truth_value_function_works() {
		assert_eq!(mul_div(2u128.pow(100), 8, 4), 2u128.pow(101));
		assert_eq!(mul_div(2u128.pow(100), 4, 8), 2u128.pow(99));

		// and it returns a if result cannot fit
		assert_eq!(mul_div(MAX128 - 10, 2, 1), MAX128 - 10);
	}

	#[test]
	fn to_denom_works() {
		// simple up and down
		assert_eq!(r(1, 5).to_den(10), Some(r(2, 10)));
		assert_eq!(r(4, 10).to_den(5), Some(r(2, 5)));

		// up and down with large numbers
		assert_eq!(r(MAX128 - 10, MAX128).to_den(10), Some(r(10, 10)));
		assert_eq!(r(MAX128 / 2, MAX128).to_den(10), Some(r(5, 10)));

		// large to perbill. This is very well needed for npos-elections.
		assert_eq!(
			r(MAX128 / 2, MAX128).to_den(1_000_000_000),
			Some(r(500_000_000, 1_000_000_000))
		);

		// large to large
		assert_eq!(r(MAX128 / 2, MAX128).to_den(MAX128 / 2), Some(r(MAX128 / 4, MAX128 / 2)));
	}

	#[test]
	fn gcd_works() {
		assert_eq!(gcd(10, 5), 5);
		assert_eq!(gcd(7, 22), 1);
	}

	#[test]
	fn lcm_works() {
		// simple stuff
		assert_eq!(r(3, 10).lcm(&r(4, 15)).unwrap(), 30);
		assert_eq!(r(5, 30).lcm(&r(1, 7)).unwrap(), 210);
		assert_eq!(r(5, 30).lcm(&r(1, 10)).unwrap(), 30);

		// large numbers
		assert_eq!(r(1_000_000_000, MAX128).lcm(&r(7_000_000_000, MAX128 - 1)), None,);
		assert_eq!(
			r(1_000_000_000, MAX64).lcm(&r(7_000_000_000, MAX64 - 1)),
			Some(340282366920938463408034375210639556610),
		);
		const_assert!(340282366920938463408034375210639556610 < MAX128);
		const_assert!(340282366920938463408034375210639556610 == MAX64 * (MAX64 - 1));
	}

	#[test]
	fn add_works() {
		// works
		assert_eq!(r(3, 10).checked_add(r(1, 10)).unwrap(), r(2, 5));
		assert_eq!(r(3, 10).checked_add(r(3, 7)).unwrap(), r(51, 70));

		// errors
		assert_eq!(
			r(1, MAX128).checked_add(r(1, MAX128 - 1)),
			Err("failed to scale to denominator"),
		);
		assert_eq!(
			r(7, MAX128).checked_add(r(MAX128, MAX128)),
			Err("overflow while adding numerators"),
		);
		assert_eq!(
			r(MAX128, MAX128).checked_add(r(MAX128, MAX128)),
			Err("overflow while adding numerators"),
		);
	}

	#[test]
	fn sub_works() {
		// works
		assert_eq!(r(3, 10).checked_sub(r(1, 10)).unwrap(), r(1, 5));
		assert_eq!(r(6, 10).checked_sub(r(3, 7)).unwrap(), r(12, 70));

		// errors
		assert_eq!(
			r(2, MAX128).checked_sub(r(1, MAX128 - 1)),
			Err("failed to scale to denominator"),
		);
		assert_eq!(
			r(7, MAX128).checked_sub(r(MAX128, MAX128)),
			Err("overflow while subtracting numerators"),
		);
		assert_eq!(r(1, 10).checked_sub(r(2, 10)), Err("overflow while subtracting numerators"));
	}

	#[test]
	fn ordering_and_eq_works() {
		assert!(r(1, 2) > r(1, 3));
		assert!(r(1, 2) > r(2, 6));

		assert!(r(1, 2) < r(6, 6));
		assert!(r(2, 1) > r(2, 6));

		assert!(r(5, 10) == r(1, 2));
		assert!(r(1, 2) == r(1, 2));

		assert!(r(1, 1490000000000200000) > r(1, 1490000000000200001));
	}

	#[test]
	fn multiply_by_rational_with_rounding_works() {
		assert_eq!(multiply_by_rational_with_rounding(7, 2, 3, Rounding::Down).unwrap(), 7 * 2 / 3);
		assert_eq!(
			multiply_by_rational_with_rounding(7, 20, 30, Rounding::Down).unwrap(),
			7 * 2 / 3
		);
		assert_eq!(
			multiply_by_rational_with_rounding(20, 7, 30, Rounding::Down).unwrap(),
			7 * 2 / 3
		);

		assert_eq!(
			// MAX128 % 3 == 0
			multiply_by_rational_with_rounding(MAX128, 2, 3, Rounding::Down).unwrap(),
			MAX128 / 3 * 2,
		);
		assert_eq!(
			// MAX128 % 7 == 3
			multiply_by_rational_with_rounding(MAX128, 5, 7, Rounding::Down).unwrap(),
			(MAX128 / 7 * 5) + (3 * 5 / 7),
		);
		assert_eq!(
			// MAX128 % 13 == 8
			multiply_by_rational_with_rounding(MAX128, 11, 13, Rounding::Down).unwrap(),
			(MAX128 / 13 * 11) + (8 * 11 / 13),
		);
		assert_eq!(
			// MAX128 % 1000 == 455
			multiply_by_rational_with_rounding(MAX128, 555, 1000, Rounding::Down).unwrap(),
			(MAX128 / 1000 * 555) + (455 * 555 / 1000),
		);

		assert_eq!(
			multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64, MAX64, Rounding::Down)
				.unwrap(),
			2 * MAX64 - 1
		);
		assert_eq!(
			multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64 - 1, MAX64, Rounding::Down)
				.unwrap(),
			2 * MAX64 - 3
		);

		assert_eq!(
			multiply_by_rational_with_rounding(MAX64 + 100, MAX64_2, MAX64_2 / 2, Rounding::Down)
				.unwrap(),
			(MAX64 + 100) * 2,
		);
		assert_eq!(
			multiply_by_rational_with_rounding(
				MAX64 + 100,
				MAX64_2 / 100,
				MAX64_2 / 200,
				Rounding::Down
			)
			.unwrap(),
			(MAX64 + 100) * 2,
		);

		assert_eq!(
			multiply_by_rational_with_rounding(
				2u128.pow(66) - 1,
				2u128.pow(65) - 1,
				2u128.pow(65),
				Rounding::Down
			)
			.unwrap(),
			73786976294838206461,
		);
		assert_eq!(
			multiply_by_rational_with_rounding(1_000_000_000, MAX128 / 8, MAX128 / 2, Rounding::Up)
				.unwrap(),
			250000000
		);

		assert_eq!(
			multiply_by_rational_with_rounding(
				29459999999999999988000u128,
				1000000000000000000u128,
				10000000000000000000u128,
				Rounding::Down
			)
			.unwrap(),
			2945999999999999998800u128
		);
	}

	#[test]
	fn multiply_by_rational_with_rounding_a_b_are_interchangeable() {
		assert_eq!(
			multiply_by_rational_with_rounding(10, MAX128, MAX128 / 2, Rounding::NearestPrefDown),
			Some(20)
		);
		assert_eq!(
			multiply_by_rational_with_rounding(MAX128, 10, MAX128 / 2, Rounding::NearestPrefDown),
			Some(20)
		);
	}

	// NOTE(review): this case is ignored by default; the reason is not recorded here.
	#[test]
	#[ignore]
	fn multiply_by_rational_with_rounding_fuzzed_equation() {
		assert_eq!(
			multiply_by_rational_with_rounding(
				154742576605164960401588224,
				9223376310179529214,
				549756068598,
				Rounding::NearestPrefDown
			),
			Some(2596149632101417846585204209223679)
		);
	}
}