1use crate::arithmetic::{biguint::BigUint, helpers_128bit, Rounding};
8use core::cmp::Ordering;
9use num_traits::{Bounded, One, Zero};
10
/// A rational number with an arbitrary-precision [`BigUint`] numerator (`.0`) and
/// denominator (`.1`).
///
/// Values are not kept in reduced form. [`RationalInfinite::from`] clamps a zero denominator
/// up to one, but the derived `Default` does not go through it. The `Ord` impl of this type
/// ranks a zero-denominator value above every value with a non-zero denominator.
/// NOTE(review): the derived `Default` is `BigUint::default()` for both fields — confirm
/// whether that yields a zero denominator rather than [`RationalInfinite::zero`].
#[derive(Clone, Default, Eq)]
pub struct RationalInfinite(BigUint, BigUint);
17
18impl RationalInfinite {
19 pub fn n(&self) -> &BigUint {
21 &self.0
22 }
23
24 pub fn d(&self) -> &BigUint {
26 &self.1
27 }
28
29 pub fn from(n: BigUint, d: BigUint) -> Self {
31 Self(n, d.max(BigUint::one()))
32 }
33
34 pub fn zero() -> Self {
36 Self(BigUint::zero(), BigUint::one())
37 }
38
39 pub fn one() -> Self {
41 Self(BigUint::one(), BigUint::one())
42 }
43}
44
45impl PartialOrd for RationalInfinite {
46 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
47 Some(self.cmp(other))
48 }
49}
50
51impl Ord for RationalInfinite {
52 fn cmp(&self, other: &Self) -> Ordering {
53 if self.d() == other.d() {
55 self.n().cmp(other.n())
56 } else if self.d().is_zero() {
57 Ordering::Greater
58 } else if other.d().is_zero() {
59 Ordering::Less
60 } else {
61 self.n().clone().mul(other.d()).cmp(&other.n().clone().mul(self.d()))
63 }
64 }
65}
66
67impl PartialEq for RationalInfinite {
68 fn eq(&self, other: &Self) -> bool {
69 self.cmp(other) == Ordering::Equal
70 }
71}
72
73impl From<Rational128> for RationalInfinite {
74 fn from(t: Rational128) -> Self {
75 Self(t.0.into(), t.1.into())
76 }
77}
78
/// A rational number stored as a `u128` numerator (`.0`) over a `u128` denominator (`.1`).
///
/// Values are not kept in reduced form: `2 / 10` and `1 / 5` have different fields but compare
/// equal. [`Rational128::from`] clamps a zero denominator up to one, but
/// [`Rational128::from_unchecked`] and the derived `Default` do not. `Rational128::default()`
/// is `0 / 0`, not [`Rational128::zero`]. The `Ord` impl ranks zero-denominator values above
/// every value with a non-zero denominator.
#[derive(Clone, Copy, Default, Eq)]
pub struct Rational128(u128, u128);
82
#[cfg(feature = "std")]
impl core::fmt::Debug for Rational128 {
    /// Print the raw fraction plus an `f64` approximation with eight decimal places.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (n, d) = (self.0, self.1);
        let approx = n as f64 / d as f64;
        write!(f, "Rational128({} / {} ≈ {:.8})", n, d, approx)
    }
}
89
90#[cfg(not(feature = "std"))]
91impl core::fmt::Debug for Rational128 {
92 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
93 write!(f, "Rational128({} / {})", self.0, self.1)
94 }
95}
96
97impl Rational128 {
98 pub fn zero() -> Self {
100 Self(0, 1)
101 }
102
103 pub fn one() -> Self {
105 Self(1, 1)
106 }
107
108 pub fn is_zero(&self) -> bool {
110 self.0.is_zero()
111 }
112
113 pub fn from(n: u128, d: u128) -> Self {
115 Self(n, d.max(1))
116 }
117
118 pub fn from_unchecked(n: u128, d: u128) -> Self {
120 Self(n, d)
121 }
122
123 pub fn n(&self) -> u128 {
125 self.0
126 }
127
128 pub fn d(&self) -> u128 {
130 self.1
131 }
132
133 pub fn to_den(self, den: u128) -> Option<Self> {
138 if den == self.1 {
139 Some(self)
140 } else {
141 helpers_128bit::multiply_by_rational_with_rounding(
142 self.0,
143 den,
144 self.1,
145 Rounding::NearestPrefDown,
146 )
147 .map(|n| Self(n, den))
148 }
149 }
150
151 pub fn lcm(&self, other: &Self) -> Option<u128> {
156 if self.1 == other.1 {
158 return Some(self.1);
159 }
160 let g = helpers_128bit::gcd(self.1, other.1);
161 helpers_128bit::multiply_by_rational_with_rounding(
162 self.1,
163 other.1,
164 g,
165 Rounding::NearestPrefDown,
166 )
167 }
168
169 pub fn lazy_saturating_add(self, other: Self) -> Self {
171 if other.is_zero() {
172 self
173 } else {
174 Self(self.0.saturating_add(other.0), self.1)
175 }
176 }
177
178 pub fn lazy_saturating_sub(self, other: Self) -> Self {
180 if other.is_zero() {
181 self
182 } else {
183 Self(self.0.saturating_sub(other.0), self.1)
184 }
185 }
186
187 pub fn checked_add(self, other: Self) -> Result<Self, &'static str> {
191 let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
192 let self_scaled =
193 self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
194 let other_scaled =
195 other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
196 let n = self_scaled
197 .0
198 .checked_add(other_scaled.0)
199 .ok_or("overflow while adding numerators")?;
200 Ok(Self(n, self_scaled.1))
201 }
202
203 pub fn checked_sub(self, other: Self) -> Result<Self, &'static str> {
207 let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
208 let self_scaled =
209 self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
210 let other_scaled =
211 other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
212
213 let n = self_scaled
214 .0
215 .checked_sub(other_scaled.0)
216 .ok_or("overflow while subtracting numerators")?;
217 Ok(Self(n, self_scaled.1))
218 }
219}
220
221impl Bounded for Rational128 {
222 fn min_value() -> Self {
223 Self(0, 1)
224 }
225
226 fn max_value() -> Self {
227 Self(Bounded::max_value(), 1)
228 }
229}
230
231impl<T: Into<u128>> From<T> for Rational128 {
232 fn from(t: T) -> Self {
233 Self::from(t.into(), 1)
234 }
235}
236
237impl PartialOrd for Rational128 {
238 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
239 Some(self.cmp(other))
240 }
241}
242
243impl Ord for Rational128 {
244 fn cmp(&self, other: &Self) -> Ordering {
245 if self.1 == other.1 {
247 self.0.cmp(&other.0)
248 } else if self.1.is_zero() {
249 Ordering::Greater
250 } else if other.1.is_zero() {
251 Ordering::Less
252 } else {
253 let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
255 let other_n =
256 helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
257 self_n.cmp(&other_n)
258 }
259 }
260}
261
262impl PartialEq for Rational128 {
263 fn eq(&self, other: &Self) -> bool {
264 if self.1 == other.1 {
266 self.0.eq(&other.0)
267 } else {
268 let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
269 let other_n =
270 helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
271 self_n.eq(&other_n)
272 }
273 }
274}
275
276pub trait MultiplyRational: Sized {
277 fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self>;
278}
279
/// Implement [`MultiplyRational`] for the narrow unsigned type `$ulow` by doing the
/// arithmetic in the wider type `$uhi`, where `self * n` cannot overflow.
macro_rules! impl_rrm {
    ($ulow:ty, $uhi:ty) => {
        impl MultiplyRational for $ulow {
            fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
                if d.is_zero() {
                    return None;
                }

                let wide_d = d as $uhi;
                let product = (self as $uhi) * (n as $uhi);
                let quotient = product / wide_d;
                // The remainder is strictly below `d`, so narrowing it back is lossless.
                let rem = (product % wide_d) as $ulow;

                // `d / 2 + d % 2` is ceil(d / 2): ties round up for `NearestPrefUp`, while
                // the strict `> d / 2` rounds ties down for `NearestPrefDown`.
                let round_up = match r {
                    Rounding::Up => rem > 0,
                    Rounding::NearestPrefUp => rem >= d / 2 + d % 2,
                    Rounding::NearestPrefDown => rem > d / 2,
                    Rounding::Down => false,
                };
                let quotient = if round_up { quotient.checked_add(1)? } else { quotient };

                // The result must fit back into the narrow type.
                if quotient <= (<$ulow>::max_value() as $uhi) {
                    Some(quotient as $ulow)
                } else {
                    None
                }
            }
        }
    };
}
312
// Each primitive width computes in the next-wider type, where `self * n` cannot overflow.
// `u128` has no wider primitive and gets a separate impl that defers to `helpers_128bit`.
impl_rrm!(u8, u16);
impl_rrm!(u16, u32);
impl_rrm!(u32, u64);
impl_rrm!(u64, u128);
317
318impl MultiplyRational for u128 {
319 fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
320 crate::arithmetic::helpers_128bit::multiply_by_rational_with_rounding(self, n, d, r)
321 }
322}
323
// Unit tests for `Rational128` arithmetic and ordering, and for the 128-bit helpers.
#[cfg(test)]
mod tests {
    use super::{helpers_128bit::*, *};
    use static_assertions::const_assert;

    const MAX128: u128 = u128::MAX;
    const MAX64: u128 = u64::MAX as u128;
    const MAX64_2: u128 = 2 * u64::MAX as u128;

    /// Shorthand constructor; bypasses the zero-denominator clamp of `Rational128::from`.
    fn r(p: u128, q: u128) -> Rational128 {
        Rational128(p, q)
    }

    /// Reference ("truth") value of `a * b / c`, truncated, computed in 256 bits.
    ///
    /// `c == 0` is treated as `c == 1`. A result that does not fit in `u128` returns `a`
    /// unchanged.
    fn mul_div(a: u128, b: u128, c: u128) -> u128 {
        use primitive_types::U256;
        if a.is_zero() {
            return Zero::zero();
        }
        let c = c.max(1);

        // 256-bit operands: the product of two u128 values cannot overflow.
        let ae: U256 = a.into();
        let be: U256 = b.into();
        let ce: U256 = c.into();

        let r = ae * be / ce;
        if r > u128::max_value().into() {
            a
        } else {
            r.as_u128()
        }
    }

    #[test]
    fn truth_value_function_works() {
        assert_eq!(mul_div(2u128.pow(100), 8, 4), 2u128.pow(101));
        assert_eq!(mul_div(2u128.pow(100), 4, 8), 2u128.pow(99));

        // An overflowing result falls back to returning `a`.
        assert_eq!(mul_div(MAX128 - 10, 2, 1), MAX128 - 10);
    }

    #[test]
    fn to_denom_works() {
        // Exact rescaling up and down.
        assert_eq!(r(1, 5).to_den(10), Some(r(2, 10)));
        assert_eq!(r(4, 10).to_den(5), Some(r(2, 5)));

        // Inexact rescaling rounds to the nearest numerator.
        assert_eq!(r(MAX128 - 10, MAX128).to_den(10), Some(r(10, 10)));
        assert_eq!(r(MAX128 / 2, MAX128).to_den(10), Some(r(5, 10)));

        assert_eq!(
            r(MAX128 / 2, MAX128).to_den(1_000_000_000),
            Some(r(500_000_000, 1_000_000_000))
        );

        assert_eq!(r(MAX128 / 2, MAX128).to_den(MAX128 / 2), Some(r(MAX128 / 4, MAX128 / 2)));
    }

    #[test]
    fn gcd_works() {
        assert_eq!(gcd(10, 5), 5);
        assert_eq!(gcd(7, 22), 1);
    }

    #[test]
    fn lcm_works() {
        assert_eq!(r(3, 10).lcm(&r(4, 15)).unwrap(), 30);
        assert_eq!(r(5, 30).lcm(&r(1, 7)).unwrap(), 210);
        assert_eq!(r(5, 30).lcm(&r(1, 10)).unwrap(), 30);

        // The LCM of two coprime near-`u128::MAX` denominators does not fit.
        assert_eq!(r(1_000_000_000, MAX128).lcm(&r(7_000_000_000, MAX128 - 1)), None);
        // The product of two coprime u64-sized denominators still fits.
        assert_eq!(
            r(1_000_000_000, MAX64).lcm(&r(7_000_000_000, MAX64 - 1)),
            Some(340282366920938463408034375210639556610),
        );
        const_assert!(340282366920938463408034375210639556610 < MAX128);
        const_assert!(340282366920938463408034375210639556610 == MAX64 * (MAX64 - 1));
    }

    #[test]
    fn add_works() {
        assert_eq!(r(3, 10).checked_add(r(1, 10)).unwrap(), r(2, 5));
        assert_eq!(r(3, 10).checked_add(r(3, 7)).unwrap(), r(51, 70));

        assert_eq!(
            r(1, MAX128).checked_add(r(1, MAX128 - 1)),
            Err("failed to scale to denominator"),
        );
        assert_eq!(
            r(7, MAX128).checked_add(r(MAX128, MAX128)),
            Err("overflow while adding numerators"),
        );
        assert_eq!(
            r(MAX128, MAX128).checked_add(r(MAX128, MAX128)),
            Err("overflow while adding numerators"),
        );
    }

    #[test]
    fn sub_works() {
        assert_eq!(r(3, 10).checked_sub(r(1, 10)).unwrap(), r(1, 5));
        assert_eq!(r(6, 10).checked_sub(r(3, 7)).unwrap(), r(12, 70));

        assert_eq!(
            r(2, MAX128).checked_sub(r(1, MAX128 - 1)),
            Err("failed to scale to denominator"),
        );
        assert_eq!(
            r(7, MAX128).checked_sub(r(MAX128, MAX128)),
            Err("overflow while subtracting numerators"),
        );
        // A negative result is reported as an underflow of the unsigned numerator.
        assert_eq!(r(1, 10).checked_sub(r(2, 10)), Err("overflow while subtracting numerators"));
    }

    #[test]
    fn ordering_and_eq_works() {
        assert!(r(1, 2) > r(1, 3));
        assert!(r(1, 2) > r(2, 6));

        assert!(r(1, 2) < r(6, 6));
        assert!(r(2, 1) > r(2, 6));

        assert!(r(5, 10) == r(1, 2));
        assert!(r(1, 2) == r(1, 2));

        assert!(r(1, 1490000000000200000) > r(1, 1490000000000200001));
    }

    #[test]
    fn multiply_by_rational_with_rounding_works() {
        assert_eq!(multiply_by_rational_with_rounding(7, 2, 3, Rounding::Down).unwrap(), 7 * 2 / 3);
        assert_eq!(
            multiply_by_rational_with_rounding(7, 20, 30, Rounding::Down).unwrap(),
            7 * 2 / 3
        );
        assert_eq!(
            multiply_by_rational_with_rounding(20, 7, 30, Rounding::Down).unwrap(),
            7 * 2 / 3
        );

        // `a` near `u128::MAX`: the product overflows u128 but the result fits.
        assert_eq!(
            multiply_by_rational_with_rounding(MAX128, 2, 3, Rounding::Down).unwrap(),
            MAX128 / 3 * 2,
        );
        assert_eq!(
            multiply_by_rational_with_rounding(MAX128, 5, 7, Rounding::Down).unwrap(),
            (MAX128 / 7 * 5) + (3 * 5 / 7),
        );
        assert_eq!(
            multiply_by_rational_with_rounding(MAX128, 11, 13, Rounding::Down).unwrap(),
            (MAX128 / 13 * 11) + (8 * 11 / 13),
        );
        assert_eq!(
            multiply_by_rational_with_rounding(MAX128, 555, 1000, Rounding::Down).unwrap(),
            (MAX128 / 1000 * 555) + (455 * 555 / 1000),
        );

        assert_eq!(
            multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64, MAX64, Rounding::Down)
                .unwrap(),
            2 * MAX64 - 1
        );
        assert_eq!(
            multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64 - 1, MAX64, Rounding::Down)
                .unwrap(),
            2 * MAX64 - 3
        );

        assert_eq!(
            multiply_by_rational_with_rounding(MAX64 + 100, MAX64_2, MAX64_2 / 2, Rounding::Down)
                .unwrap(),
            (MAX64 + 100) * 2,
        );
        assert_eq!(
            multiply_by_rational_with_rounding(
                MAX64 + 100,
                MAX64_2 / 100,
                MAX64_2 / 200,
                Rounding::Down
            )
            .unwrap(),
            (MAX64 + 100) * 2,
        );

        assert_eq!(
            multiply_by_rational_with_rounding(
                2u128.pow(66) - 1,
                2u128.pow(65) - 1,
                2u128.pow(65),
                Rounding::Down
            )
            .unwrap(),
            73786976294838206461,
        );
        assert_eq!(
            multiply_by_rational_with_rounding(1_000_000_000, MAX128 / 8, MAX128 / 2, Rounding::Up)
                .unwrap(),
            250000000
        );

        assert_eq!(
            multiply_by_rational_with_rounding(
                29459999999999999988000u128,
                1000000000000000000u128,
                10000000000000000000u128,
                Rounding::Down
            )
            .unwrap(),
            2945999999999999998800u128
        );
    }

    #[test]
    fn multiply_by_rational_with_rounding_a_b_are_interchangeable() {
        assert_eq!(
            multiply_by_rational_with_rounding(10, MAX128, MAX128 / 2, Rounding::NearestPrefDown),
            Some(20)
        );
        assert_eq!(
            multiply_by_rational_with_rounding(MAX128, 10, MAX128 / 2, Rounding::NearestPrefDown),
            Some(20)
        );
    }

    // Ignored: fuzzer-derived case kept for reference; the expected value is unverified here.
    #[test]
    #[ignore]
    fn multiply_by_rational_with_rounding_fuzzed_equation() {
        assert_eq!(
            multiply_by_rational_with_rounding(
                154742576605164960401588224,
                9223376310179529214,
                549756068598,
                Rounding::NearestPrefDown
            ),
            Some(2596149632101417846585204209223679)
        );
    }
}