1use std::cell::RefCell;
11
12use num::{FromPrimitive, Integer, One, Zero};
13use rust_fixed_point_decimal_core::ten_pow;
14
15use crate::{
16 prec_constraints::{PrecLimitCheck, True},
17 Decimal, MAX_PREC,
18};
19
/// Rounding modes used when a value has to be reduced to fewer digits.
///
/// Each variant describes how a non-integral quotient is turned into an
/// integer. A per-thread default mode can be changed with
/// [`RoundingMode::set_default`]; it starts out as
/// [`RoundingMode::RoundHalfEven`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RoundingMode {
    /// Round away from zero if the last digit of the value truncated toward
    /// zero is 0 or 5, otherwise round toward zero.
    Round05Up,
    /// Round toward positive infinity.
    RoundCeiling,
    /// Round toward zero (truncate).
    RoundDown,
    /// Round toward negative infinity.
    RoundFloor,
    /// Round to the nearest integer; ties are rounded toward zero.
    RoundHalfDown,
    /// Round to the nearest integer; ties are rounded to the even neighbour.
    RoundHalfEven,
    /// Round to the nearest integer; ties are rounded away from zero.
    RoundHalfUp,
    /// Round away from zero.
    RoundUp,
}
42
43thread_local!(
44 static DFLT_ROUNDING_MODE: RefCell<RoundingMode> =
45 RefCell::new(RoundingMode::RoundHalfEven)
46);
47
48impl Default for RoundingMode {
49 fn default() -> Self {
51 DFLT_ROUNDING_MODE.with(|m| *m.borrow())
52 }
53}
54
55impl RoundingMode {
56 pub fn set_default(mode: RoundingMode) {
58 DFLT_ROUNDING_MODE.with(|m| *m.borrow_mut() = mode)
59 }
60}
61
/// Rounding of a value to a given number of fractional digits.
pub trait Round: Sized {
    /// Returns `self` rounded to `n_frac_digits` fractional digits.
    ///
    /// A negative `n_frac_digits` rounds to a multiple of a power of ten,
    /// e.g. `-2` rounds to a multiple of 100.
    fn round(self, n_frac_digits: i8) -> Self;

    /// Like [`Round::round`], but returns `None` when the rounded value can
    /// not be represented by `Self`.
    fn checked_round(self, n_frac_digits: i8) -> Option<Self>;
}
78
79impl<const P: u8> Round for Decimal<P>
80where
81 PrecLimitCheck<{ P <= MAX_PREC }>: True,
82{
83 fn round(self, n_frac_digits: i8) -> Self {
104 if n_frac_digits >= P as i8 {
105 self.clone()
106 } else if n_frac_digits < P as i8 - 38 {
107 Self::ZERO
108 } else {
109 let shift: u8 = (P as i8 - n_frac_digits) as u8;
111 let divisor = ten_pow(shift);
112 Self::new_raw(div_i128_rounded(self.coeff, divisor, None) * divisor)
113 }
114 }
115
116 fn checked_round(self, n_frac_digits: i8) -> Option<Self> {
138 if n_frac_digits >= P as i8 {
139 Some(self.clone())
140 } else if n_frac_digits < P as i8 - 38 {
141 Some(Self::ZERO)
142 } else {
143 let shift: u8 = (P as i8 - n_frac_digits) as u8;
145 let divisor = ten_pow(shift);
146 match div_i128_rounded(self.coeff, divisor, None)
147 .checked_mul(divisor)
148 {
149 None => None,
150 Some(coeff) => Some(Self::new_raw(coeff)),
151 }
152 }
153 }
154}
155
/// Conversion of a value into `T`, rounding away whatever precision `T`
/// can not hold.
pub trait RoundInto<T>: Sized {
    /// Converts `self` into `T`, rounding as necessary.
    fn round_into(self) -> T;
}
166
167impl<const P: u8> RoundInto<i128> for Decimal<P>
168where
169 PrecLimitCheck<{ P <= MAX_PREC }>: True,
170{
171 #[inline(always)]
187 fn round_into(self: Self) -> i128 {
188 div_i128_rounded(self.coeff, ten_pow(P), None)
189 }
190}
191
192impl<const P: u8, const Q: u8> RoundInto<Decimal<Q>> for Decimal<P>
193where
194 PrecLimitCheck<{ P <= MAX_PREC }>: True,
195 PrecLimitCheck<{ Q <= MAX_PREC }>: True,
196 PrecLimitCheck<{ Q < P }>: True,
197{
198 #[inline(always)]
219 fn round_into(self: Self) -> Decimal<Q> {
220 Decimal::<Q> {
221 coeff: div_i128_rounded(self.coeff, ten_pow(P - Q), None),
222 }
223 }
224}
225
226pub(crate) fn div_i128_rounded(
230 mut divident: i128,
231 mut divisor: i128,
232 mode: Option<RoundingMode>,
233) -> i128 {
234 let zero = i128::zero();
235 let one = i128::one();
236 let five = i128::from_u8(5).unwrap();
237 if divisor < 0 {
238 divident = -divident;
239 divisor = -divisor;
240 }
241 let (quot, rem) = divident.div_mod_floor(&divisor);
242 if rem == zero {
244 return quot;
246 }
247 let mode = match mode {
250 None => RoundingMode::default(),
251 Some(mode) => mode,
252 };
253 match mode {
254 RoundingMode::Round05Up => {
255 if quot >= zero && quot % five == zero
260 || quot < zero && (quot + one) % five != zero
261 {
262 return quot + one;
263 }
264 }
265 RoundingMode::RoundCeiling => {
266 return quot + one;
269 }
270 RoundingMode::RoundDown => {
271 if quot < zero {
274 return quot + one;
275 }
276 }
277 RoundingMode::RoundFloor => {
278 return quot;
281 }
282 RoundingMode::RoundHalfDown => {
283 let rem_doubled = rem << 1;
288 if rem_doubled > divisor || rem_doubled == divisor && quot < zero {
289 return quot + one;
290 }
291 }
292 RoundingMode::RoundHalfEven => {
293 let rem_doubled = rem << 1;
298 if rem_doubled > divisor
299 || rem_doubled == divisor && !quot.is_even()
300 {
301 return quot + one;
302 }
303 }
304 RoundingMode::RoundHalfUp => {
305 let rem_doubled = rem << 1;
310 if rem_doubled > divisor || rem_doubled == divisor && quot >= zero {
311 return quot + one;
312 }
313 }
314 RoundingMode::RoundUp => {
315 if quot >= zero {
318 return quot + one;
319 }
320 }
321 }
322 quot
324}
325
#[cfg(test)]
mod rounding_mode_tests {
    use super::*;

    /// Checks that the default rounding mode is `RoundHalfEven`, that
    /// `set_default(mode)` makes `mode` the default, and that setting
    /// `RoundHalfEven` again restores it.
    fn check_set_default(mode: RoundingMode) {
        let initial = RoundingMode::RoundHalfEven;
        assert_eq!(RoundingMode::default(), initial);
        RoundingMode::set_default(mode);
        assert_eq!(RoundingMode::default(), mode);
        RoundingMode::set_default(initial);
        assert_eq!(RoundingMode::default(), initial);
    }

    #[test]
    fn test1() {
        check_set_default(RoundingMode::RoundUp);
    }

    #[test]
    fn test2() {
        check_set_default(RoundingMode::RoundHalfUp);
    }
}
348
#[cfg(test)]
mod helper_tests {
    use super::*;

    /// Test cases: (dividend, divisor, rounding mode, expected rounded
    /// quotient).
    const TESTDATA: &[(i128, i128, RoundingMode, i128)] = &[
        (17, 5, RoundingMode::Round05Up, 3),
        (27, 5, RoundingMode::Round05Up, 6),
        (-17, 5, RoundingMode::Round05Up, -3),
        (-27, 5, RoundingMode::Round05Up, -6),
        (17, 5, RoundingMode::RoundCeiling, 4),
        (15, 5, RoundingMode::RoundCeiling, 3),
        (-17, 5, RoundingMode::RoundCeiling, -3),
        (-15, 5, RoundingMode::RoundCeiling, -3),
        (19, 5, RoundingMode::RoundDown, 3),
        (15, 5, RoundingMode::RoundDown, 3),
        (-18, 5, RoundingMode::RoundDown, -3),
        (-15, 5, RoundingMode::RoundDown, -3),
        (19, 5, RoundingMode::RoundFloor, 3),
        (15, 5, RoundingMode::RoundFloor, 3),
        (-18, 5, RoundingMode::RoundFloor, -4),
        (-15, 5, RoundingMode::RoundFloor, -3),
        (19, 2, RoundingMode::RoundHalfDown, 9),
        (15, 4, RoundingMode::RoundHalfDown, 4),
        (-19, 2, RoundingMode::RoundHalfDown, -9),
        (-15, 4, RoundingMode::RoundHalfDown, -4),
        (19, 2, RoundingMode::RoundHalfEven, 10),
        (15, 4, RoundingMode::RoundHalfEven, 4),
        (-225, 50, RoundingMode::RoundHalfEven, -4),
        (-15, 4, RoundingMode::RoundHalfEven, -4),
        (
            u64::MAX as i128,
            i64::MIN as i128 * 10,
            RoundingMode::RoundHalfEven,
            0,
        ),
        (19, 2, RoundingMode::RoundHalfUp, 10),
        (10802, 4321, RoundingMode::RoundHalfUp, 2),
        (-19, 2, RoundingMode::RoundHalfUp, -10),
        (-10802, 4321, RoundingMode::RoundHalfUp, -2),
        (19, 2, RoundingMode::RoundUp, 10),
        (10802, 4321, RoundingMode::RoundUp, 3),
        (-19, 2, RoundingMode::RoundUp, -10),
        (-10802, 4321, RoundingMode::RoundUp, -3),
        (i32::MAX as i128, 1, RoundingMode::RoundUp, i32::MAX as i128),
    ];

    #[test]
    fn test_div_rounded() {
        for &(dividend, divisor, mode, expected) in TESTDATA {
            // The message identifies which of the table's cases failed.
            assert_eq!(
                div_i128_rounded(dividend, divisor, Some(mode)),
                expected,
                "dividing {} by {} with {:?}",
                dividend,
                divisor,
                mode
            );
        }
    }
}
403
// Tests for `Round` and `RoundInto` on `Decimal`. Unless stated otherwise they
// use the thread's default rounding mode, which is `RoundHalfEven` unless a
// test changes it.
#[cfg(test)]
mod round_decimal_tests {
    use super::*;

    // Generates one test per (precision, name) pair. Each test checks that
    // rounding `Decimal::<P>::MIN` to P or more fractional digits leaves the
    // coefficient unchanged, for both `round` and `checked_round`.
    macro_rules! test_decimal_round_no_op {
        ($(($p:expr, $func:ident)),*) => {
            $(
                #[test]
                fn $func() {
                    let x = Decimal::<$p>::MIN;
                    let y = x.round($p);
                    assert_eq!(x.coeff, y.coeff);
                    let y = x.round($p + 2);
                    assert_eq!(x.coeff, y.coeff);
                    let y = x.checked_round($p).unwrap();
                    assert_eq!(x.coeff, y.coeff);
                    let y = x.checked_round($p + 2).unwrap();
                    assert_eq!(x.coeff, y.coeff);
                }
            )*
        }
    }

    test_decimal_round_no_op!(
        (0, test_decimal0_round_no_op),
        (1, test_decimal1_round_no_op),
        (2, test_decimal2_round_no_op),
        (3, test_decimal3_round_no_op),
        (4, test_decimal4_round_no_op),
        (5, test_decimal5_round_no_op),
        (6, test_decimal6_round_no_op),
        (7, test_decimal7_round_no_op),
        (8, test_decimal8_round_no_op),
        (9, test_decimal9_round_no_op)
    );

    // Generates one test per (precision, name) pair. Each test checks that
    // dropping 39 or more digits from `Decimal::<P>::MIN` yields zero under
    // the default rounding mode, for both `round` and `checked_round`.
    macro_rules! test_decimal_round_result_zero {
        ($(($p:expr, $func:ident)),*) => {
            $(
                #[test]
                fn $func() {
                    let x = Decimal::<$p>::MIN;
                    let y = x.round($p - 39);
                    assert_eq!(y.coeff, 0);
                    let y = x.round($p - 42);
                    assert_eq!(y.coeff, 0);
                    let y = x.checked_round($p - 39).unwrap();
                    assert_eq!(y.coeff, 0);
                    let y = x.checked_round($p - 42).unwrap();
                    assert_eq!(y.coeff, 0);
                }
            )*
        }
    }

    test_decimal_round_result_zero!(
        (0, test_decimal0_round_result_zero),
        (1, test_decimal1_round_result_zero),
        (2, test_decimal2_round_result_zero),
        (3, test_decimal3_round_result_zero),
        (4, test_decimal4_round_result_zero),
        (5, test_decimal5_round_result_zero),
        (6, test_decimal6_round_result_zero),
        (7, test_decimal7_round_result_zero),
        (8, test_decimal8_round_result_zero),
        (9, test_decimal9_round_result_zero)
    );

    // Rounding to fewer fractional digits, including negative digit counts
    // (rounding to tens/hundreds) and ties resolved to even.
    #[test]
    fn test_decimal_round() {
        let d = Decimal::<0>::new_raw(12345);
        assert_eq!(d.round(-1).coeff, 12340);
        let d = Decimal::<0>::new_raw(1285);
        assert_eq!(d.round(-2).coeff, 1300);
        let d = Decimal::<1>::new_raw(12345);
        assert_eq!(d.round(0).coeff, 12340);
        let d = Decimal::<2>::new_raw(1285);
        assert_eq!(d.round(0).coeff, 1300);
        let d = Decimal::<7>::new_raw(12345678909876543);
        assert_eq!(d.round(0).coeff, 12345678910000000);
        let d = Decimal::<9>::new_raw(123455);
        assert_eq!(d.round(8).coeff, 123460);
    }

    // `Decimal::<8>::MAX` rounded to 0 fractional digits is larger than the
    // maximum representable coefficient, so `round` is expected to panic.
    #[test]
    #[should_panic]
    fn test_decimal_round_overflow() {
        let d = Decimal::<8>::MAX;
        let _ = d.round(0);
    }

    // Same cases as `test_decimal_round`, plus values whose rounded result
    // does not fit, for which `checked_round` must return `None`.
    #[test]
    fn test_decimal_checked_round() {
        let d = Decimal::<0>::new_raw(12345);
        assert_eq!(d.checked_round(-1).unwrap().coeff, 12340);
        let d = Decimal::<0>::new_raw(1285);
        assert_eq!(d.checked_round(-2).unwrap().coeff, 1300);
        let d = Decimal::<1>::new_raw(12345);
        assert_eq!(d.checked_round(0).unwrap().coeff, 12340);
        let d = Decimal::<2>::new_raw(1285);
        assert_eq!(d.checked_round(0).unwrap().coeff, 1300);
        let d = Decimal::<7>::new_raw(12345678909876543);
        assert_eq!(d.checked_round(0).unwrap().coeff, 12345678910000000);
        let d = Decimal::<9>::new_raw(123455);
        assert_eq!(d.checked_round(8).unwrap().coeff, 123460);
        let d = Decimal::<0>::MAX;
        let res = d.checked_round(-1);
        assert!(res.is_none());
        let d = Decimal::<7>::MAX;
        let res = d.checked_round(4);
        assert!(res.is_none());
    }

    // Rounding a `Decimal` to an integer, including a tie (1234.5 -> 1234).
    #[test]
    fn test_round_into_i128() {
        let d = Decimal::<4>::new_raw(12345000);
        let i: i128 = d.round_into();
        assert_eq!(i, 1234);
        let d = Decimal::<4>::new_raw(12345678);
        let i: i128 = d.round_into();
        assert_eq!(i, 1235);
        let d = Decimal::<2>::new_raw(12345678);
        let i: i128 = d.round_into();
        assert_eq!(i, 123457);
    }

    // Rounding a `Decimal` into a `Decimal` of lower precision, including
    // the maximum coefficient.
    #[test]
    fn test_round_into_decimal() {
        let d = Decimal::<4>::new_raw(12345000);
        let r: Decimal<0> = d.round_into();
        assert_eq!(r.coeff, 1234);
        let d = Decimal::<4>::new_raw(12345678);
        let r: Decimal<0> = d.round_into();
        assert_eq!(r.coeff, 1235);
        let d = Decimal::<4>::new_raw(12345678);
        let r: Decimal<2> = d.round_into();
        assert_eq!(r.coeff, 123457);
        let d = Decimal::<7>::MAX;
        let r: Decimal<2> = d.round_into();
        assert_eq!(r.coeff, 1701411834604692317316873037158841_i128);
    }
}