safe_arithmetic/ops/checked_mul.rs

1use crate::error::{Overflow, Underflow};
2use std::fmt::{self, Debug, Display};
3
/// Multiplication that yields an error instead of an unrepresentable result.
///
/// `Rhs` is the type of the right-hand operand; it defaults to `Self`.
pub trait CheckedMul<Rhs = Self>: Sized {
    /// Type of a successful product.
    type Output;
    /// Type describing why the multiplication failed.
    type Error;

    /// Multiplies `self` by `rhs`, returning the product only if it is representable.
    ///
    /// # Errors
    /// When the result of the multiplication can not be represented (e.g. due to an overflow).
    fn checked_mul(self, rhs: Rhs) -> Result<Self::Output, Self::Error>;
}
17
18macro_rules! impl_unsigned_checked_mul {
19    ( $T:ty ) => {
20        impl CheckedMul for $T {
21            type Output = Self;
22            type Error = MulError<Self, Self>;
23
24            fn checked_mul(self, rhs: Self) -> Result<Self::Output, Self::Error> {
25                num::CheckedMul::checked_mul(&self, &rhs)
26                    .ok_or(rhs.overflows(self))
27                    .map_err(MulError)
28            }
29        }
30    };
31}
32
33impl_unsigned_checked_mul!(u32);
34impl_unsigned_checked_mul!(u64);
35
36macro_rules! impl_signed_checked_mul {
37    ( $T:ty ) => {
38        impl CheckedMul for $T {
39            type Output = Self;
40            type Error = MulError<Self, Self>;
41
42            fn checked_mul(self, rhs: Self) -> Result<Self::Output, Self::Error> {
43                if self.signum() == rhs.signum() {
44                    num::CheckedMul::checked_mul(&self, &rhs)
45                        .ok_or(rhs.overflows(self))
46                        .map_err(MulError)
47                } else {
48                    num::CheckedMul::checked_mul(&self, &rhs)
49                        .ok_or(rhs.underflows(self))
50                        .map_err(MulError)
51                }
52            }
53        }
54    };
55}
56
57impl_signed_checked_mul!(i64);
58
59macro_rules! impl_float_checked_mul {
60    ( $T:ty ) => {
61        impl CheckedMul for $T {
62            type Output = Self;
63            type Error = MulError<Self, Self>;
64
65            fn checked_mul(self, rhs: Self) -> Result<Self::Output, Self::Error> {
66                let result = self * rhs;
67                if result.is_nan() && self.signum() == rhs.signum() {
68                    // overflow
69                    Err(MulError(rhs.overflows(self)))
70                } else if result.is_nan() {
71                    // underflow
72                    Err(MulError(rhs.underflows(self)))
73                } else {
74                    Ok(result)
75                }
76            }
77        }
78    };
79}
80
81impl_float_checked_mul!(f32);
82impl_float_checked_mul!(f64);
83
84#[derive(PartialEq, Clone, Debug)]
85#[allow(clippy::module_name_repetitions)]
86pub struct MulError<Lhs, Rhs>(pub crate::error::Operation<Lhs, Rhs>);
87
88impl<Lhs, Rhs> crate::error::Arithmetic for MulError<Lhs, Rhs>
89where
90    Lhs: crate::Type,
91    Rhs: crate::Type,
92{
93}
94
95impl<Lhs, Rhs> std::error::Error for MulError<Lhs, Rhs>
96where
97    Lhs: Display + Debug,
98    Rhs: Display + Debug,
99{
100    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
101        self.0.cause.as_deref().map(crate::error::AsErr::as_err)
102    }
103}
104
105impl<Lhs, Rhs> Display for MulError<Lhs, Rhs>
106where
107    Lhs: Display,
108    Rhs: Display,
109{
110    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
111        match self.0.kind {
112            Some(kind) => write!(
113                f,
114                "multiplying {} by {} would {} {}",
115                self.0.lhs,
116                self.0.rhs,
117                kind,
118                std::any::type_name::<Lhs>(),
119            ),
120            None => write!(f, "cannot multiply {} by {}", self.0.lhs, self.0.rhs,),
121        }
122    }
123}