tp_runtime/
curve.rs

1// This file is part of Tetcore.
2
3// Copyright (C) 2019-2021 Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Provides some utilities to define a piecewise linear function.
19
20use crate::{Perbill, traits::{AtLeast32BitUnsigned, SaturatedConversion}};
21use core::ops::Sub;
22
23/// Piecewise Linear function in [0, 1] -> [0, 1].
24#[derive(PartialEq, Eq, tet_core::RuntimeDebug)]
25pub struct PiecewiseLinear<'a> {
26	/// Array of points. Must be in order from the lowest abscissas to the highest.
27	pub points: &'a [(Perbill, Perbill)],
28	/// The maximum value that can be returned.
29	pub maximum: Perbill,
30}
31
/// Absolute difference `|a - b|` for any totally ordered type with subtraction.
///
/// The larger operand is always used as the minuend, so the subtraction never underflows for
/// unsigned types. Only a comparison by reference is needed, so no `Clone` bound is required.
fn abs_sub<N: Ord + Sub<Output = N>>(a: N, b: N) -> N {
	if a >= b {
		a - b
	} else {
		b - a
	}
}
35
36impl<'a> PiecewiseLinear<'a> {
37	/// Compute `f(n/d)*d` with `n <= d`. This is useful to avoid loss of precision.
38	pub fn calculate_for_fraction_times_denominator<N>(&self, n: N, d: N) -> N where
39		N: AtLeast32BitUnsigned + Clone
40	{
41		let n = n.min(d.clone());
42
43		if self.points.len() == 0 {
44			return N::zero()
45		}
46
47		let next_point_index = self.points.iter()
48			.position(|p| n < p.0 * d.clone());
49
50		let (prev, next) = if let Some(next_point_index) = next_point_index {
51			if let Some(previous_point_index) = next_point_index.checked_sub(1) {
52				(self.points[previous_point_index], self.points[next_point_index])
53			} else {
54				// There is no previous points, take first point ordinate
55				return self.points.first().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
56			}
57		} else {
58			// There is no next points, take last point ordinate
59			return self.points.last().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
60		};
61
62		let delta_y = multiply_by_rational_saturating(
63			abs_sub(n.clone(), prev.0 * d.clone()),
64			abs_sub(next.1.deconstruct(), prev.1.deconstruct()),
65			// Must not saturate as prev abscissa > next abscissa
66			next.0.deconstruct().saturating_sub(prev.0.deconstruct()),
67		);
68
69		// If both subtractions are same sign then result is positive
70		if (n > prev.0 * d.clone()) == (next.1.deconstruct() > prev.1.deconstruct()) {
71			(prev.1 * d).saturating_add(delta_y)
72		// Otherwise result is negative
73		} else {
74			(prev.1 * d).saturating_sub(delta_y)
75		}
76	}
77}
78
79// Compute value * p / q.
80// This is guaranteed not to overflow on whatever values nor lose precision.
81// `q` must be superior to zero.
82fn multiply_by_rational_saturating<N>(value: N, p: u32, q: u32) -> N
83	where N: AtLeast32BitUnsigned + Clone
84{
85	let q = q.max(1);
86
87	// Mul can saturate if p > q
88	let result_divisor_part = (value.clone() / q.into()).saturating_mul(p.into());
89
90	let result_remainder_part = {
91		let rem = value % q.into();
92
93		// Fits into u32 because q is u32 and remainder < q
94		let rem_u32 = rem.saturated_into::<u32>();
95
96		// Multiplication fits into u64 as both term are u32
97		let rem_part = rem_u32 as u64 * p as u64 / q as u64;
98
99		// Can saturate if p > q
100		rem_part.saturated_into::<N>()
101	};
102
103	// Can saturate if p > q
104	result_divisor_part.saturating_add(result_remainder_part)
105}
106
#[test]
fn test_multiply_by_rational_saturating() {
	use std::convert::TryInto;

	const STEPS: u32 = 100;

	// Maps step `i` of `STEPS` onto `i * max / STEPS`, rounded down.
	let scale = |i: u32, max: u128| -> u128 { i as u128 * max / STEPS as u128 };

	for value_step in 0..=STEPS {
		let value: u64 = scale(value_step, u64::max_value() as u128).try_into().unwrap();

		for p_step in 0..=STEPS {
			let p: u32 = scale(p_step, u32::max_value() as u128).try_into().unwrap();

			// `q` starts at one step so the denominator is never zero.
			for q_step in 1..=STEPS {
				let q: u32 = scale(q_step, u32::max_value() as u128).try_into().unwrap();

				let expected: u64 = (value as u128 * p as u128 / q as u128)
					.try_into()
					.unwrap_or(u64::max_value());

				assert_eq!(multiply_by_rational_saturating(value, p, q), expected);
			}
		}
	}
}
131
#[test]
fn test_calculate_for_fraction_times_denominator() {
	use std::convert::TryInto;

	// f(x) = x + 1/2 on [0, 1/2] and f(x) = 2 - 2x on [1/2, 1].
	let curve = PiecewiseLinear {
		points: &[
			(Perbill::from_parts(0_000_000_000), Perbill::from_parts(0_500_000_000)),
			(Perbill::from_parts(0_500_000_000), Perbill::from_parts(1_000_000_000)),
			(Perbill::from_parts(1_000_000_000), Perbill::from_parts(0_000_000_000)),
		],
		maximum: Perbill::from_parts(1_000_000_000),
	};

	// Exact reference for `f(n / d) * d` on the curve above, assuming `n <= d`.
	fn formal_calculate_for_fraction_times_denominator(n: u64, d: u64) -> u64 {
		if n <= Perbill::from_parts(0_500_000_000) * d {
			n + d / 2
		} else {
			(d as u128 * 2 - n as u128 * 2).try_into().unwrap()
		}
	}

	let div = 100u32;
	for d in 0..=div {
		for n in 0..=d {
			let d: u64 = (d as u128 * u64::max_value() as u128 / div as u128)
				.try_into().unwrap();
			let n: u64 = (n as u128 * u64::max_value() as u128 / div as u128)
				.try_into().unwrap();

			let res = curve.calculate_for_fraction_times_denominator(n, d);
			let expected = formal_calculate_for_fraction_times_denominator(n, d);

			// Allow one unit of rounding error.
			assert!(abs_sub(res, expected) <= 1);
		}
	}

	// `n > d` is clamped to `d`: f(1) * 10 = 0.
	assert_eq!(curve.calculate_for_fraction_times_denominator(20u64, 10u64), 0);

	// A curve without points always evaluates to zero.
	let empty = PiecewiseLinear { points: &[], maximum: Perbill::from_parts(1_000_000_000) };
	assert_eq!(empty.calculate_for_fraction_times_denominator(3u64, 10u64), 0);

	// A single-point curve is constant at that point's ordinate on both sides of it.
	let single = PiecewiseLinear {
		points: &[(Perbill::from_parts(0_500_000_000), Perbill::from_parts(0_250_000_000))],
		maximum: Perbill::from_parts(1_000_000_000),
	};
	assert_eq!(single.calculate_for_fraction_times_denominator(10u64, 100u64), 25);
	assert_eq!(single.calculate_for_fraction_times_denominator(90u64, 100u64), 25);
}