safe_arithmetic/
cast.rs

1use std::marker::PhantomData;
2
3pub trait Cast
4where
5    Self: Sized + num::ToPrimitive + Copy,
6{
7    /// Checked cast of self to `Target`.
8    ///
9    /// # Errors
10    /// When `self` can not be casted to `Target`.
11    fn cast<Target>(self) -> Result<Target, CastError<Self, Target>>
12    where
13        Target: num::NumCast;
14}
15
16impl<Src> Cast for Src
17where
18    Self: Sized + num::ToPrimitive + Copy,
19{
20    fn cast<Target>(self) -> Result<Target, CastError<Self, Target>>
21    where
22        Target: num::NumCast,
23    {
24        num::NumCast::from(self).ok_or(CastError {
25            src: self,
26            target: PhantomData,
27            cause: None,
28        })
29    }
30}
31
32#[derive(PartialEq, Clone)]
33#[allow(clippy::module_name_repetitions)]
34pub struct CastError<Src, Target> {
35    pub src: Src,
36    pub target: PhantomData<Target>,
37    pub cause: Option<crate::error::Error>,
38}
39
40impl<Src, Target> crate::error::Arithmetic for CastError<Src, Target>
41where
42    Src: crate::Type,
43    Target: crate::Type,
44{
45}
46
47impl<Src, Target> std::error::Error for CastError<Src, Target>
48where
49    Src: std::fmt::Debug + std::fmt::Display,
50{
51    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
52        self.cause.as_deref().map(crate::error::AsErr::as_err)
53    }
54}
55
56impl<Src, Target> std::fmt::Debug for CastError<Src, Target>
57where
58    Src: std::fmt::Debug,
59{
60    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
61        f.debug_struct("CastError")
62            .field("src", &self.src)
63            .field("target", &std::any::type_name::<Target>())
64            .field("cause", &self.cause)
65            .finish()
66    }
67}
68
69impl<Src, Target> std::fmt::Display for CastError<Src, Target>
70where
71    Src: std::fmt::Display,
72{
73    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
74        write!(
75            f,
76            "cannot cast {} of type {} to {}",
77            self.src,
78            std::any::type_name::<Src>(),
79            std::any::type_name::<Target>(),
80        )
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use similar_asserts::assert_eq as sim_assert_eq;

    /// Casts whose value is not representable in the target type must fail,
    /// and the error message must name the value and both types.
    #[test]
    fn invalid_num_cast() {
        sim_assert_eq!(
            &42_000f64.cast::<i8>().unwrap_err().to_string(),
            "cannot cast 42000 of type f64 to i8"
        );
        sim_assert_eq!(
            &(-42f64).cast::<u32>().unwrap_err().to_string(),
            "cannot cast -42 of type f64 to u32"
        );
        sim_assert_eq!(
            &(-42i64).cast::<u32>().unwrap_err().to_string(),
            "cannot cast -42 of type i64 to u32"
        );
        let value = i64::MAX;
        sim_assert_eq!(
            &value.cast::<u32>().unwrap_err().to_string(),
            &format!("cannot cast {} of type i64 to u32", &value)
        );
        let value = i64::MIN;
        sim_assert_eq!(
            &value.cast::<u64>().unwrap_err().to_string(),
            &format!("cannot cast {} of type i64 to u64", &value)
        );

        // `f32::MAX` (~3.4e38) lies far above `u32::MAX`.
        assert!(f32::MAX.cast::<u32>().is_err());
        // Non-finite floats have no integer representation.
        assert!(f64::NAN.cast::<i32>().is_err());
        assert!(f64::INFINITY.cast::<i64>().is_err());
        assert!(f64::NEG_INFINITY.cast::<u8>().is_err());
    }

    /// Representable values convert successfully; float -> int conversion
    /// discards the fractional part (truncation toward zero).
    #[test]
    fn valid_num_cast() {
        sim_assert_eq!(42f64.cast::<f32>().ok(), Some(42f32));
        sim_assert_eq!(42f32.cast::<f64>().ok(), Some(42f64));
        sim_assert_eq!(42u64.cast::<f32>().ok(), Some(42f32));
        sim_assert_eq!(42i64.cast::<f32>().ok(), Some(42f32));
        sim_assert_eq!(42.1f64.cast::<i8>().ok(), Some(42i8));
        sim_assert_eq!(42.6f64.cast::<i8>().ok(), Some(42i8));
        sim_assert_eq!((-42.6f64).cast::<i8>().ok(), Some(-42i8));
        assert!(u32::MAX.cast::<i64>().is_ok());
        assert!(i64::MAX.cast::<u64>().is_ok());
        assert!(i128::MAX.cast::<f64>().is_ok());
        assert!(u128::MAX.cast::<f64>().is_ok());

        // `u32::MAX` is 2^32 - 1: `f32` rounds it up to 2^32 while `f64` holds
        // it exactly, so both land within the epsilon of 2^32.
        assert_abs_diff_eq!(
            u32::MAX.cast::<f32>().unwrap(),
            2f32.powi(32),
            epsilon = 2.0
        );
        assert_abs_diff_eq!(
            u32::MAX.cast::<f64>().unwrap(),
            2f64.powi(32),
            epsilon = 2.0
        );
    }
}