1use crate::{Perbill, traits::{AtLeast32BitUnsigned, SaturatedConversion}};
21use core::ops::Sub;
22
/// A piecewise linear curve described by a borrowed list of points.
///
/// Each point is an `(abscissa, ordinate)` pair of `Perbill` values; the curve
/// is evaluated by linear interpolation between neighbouring points.
#[derive(PartialEq, Eq, tet_core::RuntimeDebug)]
pub struct PiecewiseLinear<'a> {
    /// The points of the curve as `(abscissa, ordinate)` pairs.
    ///
    /// Evaluation looks up the first point whose abscissa lies above the input
    /// and interpolates from the point just before it, so the points are
    /// expected to be sorted by ascending abscissa.
    pub points: &'a [(Perbill, Perbill)],
    /// NOTE(review): not read by any code visible in this part of the file;
    /// presumably the largest ordinate the curve can yield - confirm against
    /// the code that constructs curves.
    pub maximum: Perbill,
}
31
/// Returns the absolute difference between `a` and `b`, i.e. the larger value
/// minus the smaller one, so the subtraction never goes below zero.
fn abs_sub<N>(a: N, b: N) -> N
where
    N: Ord + Sub<Output = N> + Clone,
{
    if a < b {
        b - a
    } else {
        a - b
    }
}
35
36impl<'a> PiecewiseLinear<'a> {
37 pub fn calculate_for_fraction_times_denominator<N>(&self, n: N, d: N) -> N where
39 N: AtLeast32BitUnsigned + Clone
40 {
41 let n = n.min(d.clone());
42
43 if self.points.len() == 0 {
44 return N::zero()
45 }
46
47 let next_point_index = self.points.iter()
48 .position(|p| n < p.0 * d.clone());
49
50 let (prev, next) = if let Some(next_point_index) = next_point_index {
51 if let Some(previous_point_index) = next_point_index.checked_sub(1) {
52 (self.points[previous_point_index], self.points[next_point_index])
53 } else {
54 return self.points.first().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
56 }
57 } else {
58 return self.points.last().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
60 };
61
62 let delta_y = multiply_by_rational_saturating(
63 abs_sub(n.clone(), prev.0 * d.clone()),
64 abs_sub(next.1.deconstruct(), prev.1.deconstruct()),
65 next.0.deconstruct().saturating_sub(prev.0.deconstruct()),
67 );
68
69 if (n > prev.0 * d.clone()) == (next.1.deconstruct() > prev.1.deconstruct()) {
71 (prev.1 * d).saturating_add(delta_y)
72 } else {
74 (prev.1 * d).saturating_sub(delta_y)
75 }
76 }
77}
78
79fn multiply_by_rational_saturating<N>(value: N, p: u32, q: u32) -> N
83 where N: AtLeast32BitUnsigned + Clone
84{
85 let q = q.max(1);
86
87 let result_divisor_part = (value.clone() / q.into()).saturating_mul(p.into());
89
90 let result_remainder_part = {
91 let rem = value % q.into();
92
93 let rem_u32 = rem.saturated_into::<u32>();
95
96 let rem_part = rem_u32 as u64 * p as u64 / q as u64;
98
99 rem_part.saturated_into::<N>()
101 };
102
103 result_divisor_part.saturating_add(result_remainder_part)
105}
106
/// Checks `multiply_by_rational_saturating` against a `u128` reference
/// computation over a grid of values spanning the full `u64` / `u32` ranges.
#[test]
fn test_multiply_by_rational_saturating() {
    use std::convert::TryInto;

    const STEPS: u32 = 100;

    // Maps `step / STEPS` onto the full range of `u64`.
    let scale_u64 = |step: u32| -> u64 {
        (u128::from(step) * u128::from(u64::max_value()) / u128::from(STEPS))
            .try_into()
            .unwrap()
    };
    // Maps `step / STEPS` onto the full range of `u32`.
    let scale_u32 = |step: u32| -> u32 {
        (u64::from(step) * u64::from(u32::max_value()) / u64::from(STEPS))
            .try_into()
            .unwrap()
    };

    for value in (0..=STEPS).map(scale_u64) {
        for p in (0..=STEPS).map(scale_u32) {
            // `q` starts at step one so the reference division never divides by zero.
            for q in (1..=STEPS).map(scale_u32) {
                let expected: u64 = (u128::from(value) * u128::from(p) / u128::from(q))
                    .try_into()
                    .unwrap_or(u64::max_value());

                assert_eq!(multiply_by_rational_saturating(value, p, q), expected);
            }
        }
    }
}
131
/// Checks `calculate_for_fraction_times_denominator` on a three-point curve
/// against a closed-form reference, allowing a rounding error of one.
#[test]
fn test_calculate_for_fraction_times_denominator() {
    use std::convert::TryInto;

    const STEPS: u32 = 100;

    // Rises from 0.5 at x = 0 to 1 at x = 0.5, then falls to 0 at x = 1.
    let curve = PiecewiseLinear {
        points: &[
            (Perbill::from_parts(0), Perbill::from_parts(500_000_000)),
            (Perbill::from_parts(500_000_000), Perbill::from_parts(1_000_000_000)),
            (Perbill::from_parts(1_000_000_000), Perbill::from_parts(0)),
        ],
        maximum: Perbill::from_parts(1_000_000_000),
    };

    // Closed-form value of the curve above at `n / d`, multiplied by `d`.
    fn reference(n: u64, d: u64) -> u64 {
        let half_of_d = Perbill::from_parts(500_000_000) * d;
        if n > half_of_d {
            (u128::from(d) * 2 - u128::from(n) * 2).try_into().unwrap()
        } else {
            n + d / 2
        }
    }

    // Maps `step / STEPS` onto the full range of `u64`.
    let scale = |step: u32| -> u64 {
        (u128::from(step) * u128::from(u64::max_value()) / u128::from(STEPS))
            .try_into()
            .unwrap()
    };

    for d_step in 0..=STEPS {
        let d = scale(d_step);
        for n in (0..=d_step).map(scale) {
            let res = curve.calculate_for_fraction_times_denominator(n, d);
            let expected = reference(n, d);

            assert!(abs_sub(res, expected) <= 1);
        }
    }
}