// safe_num/lib.rs

use std::fmt;
use std::ops::Add;

/// The checked arithmetic operation recorded in an [`ArithmeticError`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Operation {
    Add,
    Sub,
    Mul,
    Div,
    Sqrt,
}

impl fmt::Display for Operation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Operation::Add => "addition",
            Operation::Sub => "subtraction",
            Operation::Mul => "multiplication",
            Operation::Div => "division",
            Operation::Sqrt => "square root",
        })
    }
}

/// Error produced when a checked operation on [`Safe`] values has no valid
/// result. It keeps both operands and the operation for diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ArithmeticError<T> {
    /// Left-hand operand of the failed operation.
    pub left: T,
    /// Right-hand operand of the failed operation.
    pub right: T,
    /// Which operation failed.
    pub op: Operation,
}

impl<T: fmt::Display> fmt::Display for ArithmeticError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "checked {} failed (left: {}, right: {})",
            self.op, self.left, self.right
        )
    }
}

impl<T: fmt::Debug + fmt::Display> std::error::Error for ArithmeticError<T> {}

/// A number whose operators use checked arithmetic.
///
/// Operators on `Safe` values (e.g. `Safe<u64> + u64`) return a [`SafeResult`]
/// instead of overflowing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Safe<T>(T);

impl<T> Safe<T> {
    /// Wraps `value` in a [`Safe`].
    ///
    /// The tuple field is private, so this (or `From`) is how code outside
    /// this crate constructs a `Safe`.
    pub const fn new(value: T) -> Self {
        Safe(value)
    }

    /// Consumes the wrapper and returns the plain value.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> From<T> for Safe<T> {
    fn from(value: T) -> Self {
        Safe(value)
    }
}

/// Outcome of a checked operation: either a [`Safe`] value or the
/// [`ArithmeticError`] describing why the operation failed.
///
/// `SafeResult` itself implements the operators. Expressions such as
/// `(a + b) + c` can therefore be chained, and an earlier error is carried
/// through to the final result.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SafeResult<T>(Result<Safe<T>, ArithmeticError<T>>);

impl<T> SafeResult<T> {
    /// Returns `true` if the operation produced a value.
    pub fn is_ok(&self) -> bool {
        self.0.is_ok()
    }

    /// Returns `true` if the operation failed.
    pub fn is_err(&self) -> bool {
        self.0.is_err()
    }

    /// Converts into a standard `Result`, e.g. to use `?` or `match` on it.
    ///
    /// The tuple field is private, so this is how code outside this crate
    /// inspects the outcome.
    pub fn into_result(self) -> Result<Safe<T>, ArithmeticError<T>> {
        self.0
    }
}

impl<T> From<SafeResult<T>> for Result<Safe<T>, ArithmeticError<T>> {
    fn from(result: SafeResult<T>) -> Self {
        result.0
    }
}

25/// Marker trait to identify a Safe type
26pub trait IsSafe {}
27impl<T> IsSafe for Safe<T> {}
28
29/// Trait with all operations over a safe number
30pub trait SafeNum<T>: IsSafe + Add<T> + Add<Safe<T>> + Add<SafeResult<T>> {}
31impl<T> SafeNum<T> for T
32    where T: IsSafe
33        + Add<T>
34        + Add<Safe<T>>
35        + Add<SafeResult<T>>
36{}
37
38impl Add<u64> for Safe<u64> {
39    type Output = SafeResult<u64>;
40
41    fn add(self, other: u64) -> Self::Output {
42        match self.0.checked_add(other) {
43            Some(r) => SafeResult(Ok(Safe(r))),
44            None => SafeResult(Err(
45                ArithmeticError {
46                    left: self.0,
47                    right: other,
48                    op: Operation::Add,
49                }
50            ))
51        }
52    }
53}
54
55impl Add for Safe<u64> {
56    type Output = SafeResult<u64>;
57
58    fn add(self, other: Safe<u64>) -> Self::Output {
59        self + other.0
60    }
61}
62
63impl Add<SafeResult<u64>> for Safe<u64> {
64    type Output = SafeResult<u64>;
65
66    fn add(self, other: SafeResult<u64>) -> Self::Output {
67        match other.0 {
68            Ok(other) => self + other,
69            Err(e) => SafeResult(Err(e))
70        }
71    }
72}
73
74impl Add<u64> for SafeResult<u64> {
75    type Output = SafeResult<u64>;
76
77    fn add(self, other: u64) -> Self::Output {
78        match self.0 {
79            Ok(this) => this + other,
80            Err(e) => SafeResult(Err(e))
81        }
82    }
83}
84
85impl Add<Safe<u64>> for SafeResult<u64> {
86    type Output = SafeResult<u64>;
87
88    fn add(self, other: Safe<u64>) -> Self::Output {
89        match self.0 {
90            Ok(this) => this + other,
91            Err(e) => SafeResult(Err(e))
92        }
93    }
94}
95
96impl Add for SafeResult<u64> {
97    type Output = SafeResult<u64>;
98
99    fn add(self, other: SafeResult<u64>) -> Self::Output {
100        match self.0 {
101            Ok(this) => this + other,
102            Err(e) => SafeResult(Err(e))
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Computes `(x + y) + (x + y)` through the checked operators. This
    /// exercises `Safe + Safe` and `SafeResult + SafeResult`.
    fn doubled_sum(x: Safe<u64>, y: Safe<u64>) -> SafeResult<u64> {
        (x + y) + (x + y)
    }

    /// Returns the raw value of a successful result; fails the test otherwise.
    fn expect_value(r: SafeResult<u64>) -> u64 {
        match r.0 {
            Ok(Safe(v)) => v,
            Err(_) => panic!("expected the checked operation to succeed"),
        }
    }

    /// Returns the error of a failed result; fails the test otherwise.
    fn expect_error(r: SafeResult<u64>) -> ArithmeticError<u64> {
        match r.0 {
            Ok(_) => panic!("expected the checked operation to fail"),
            Err(e) => e,
        }
    }

    #[test]
    fn test_doubled_sum() {
        assert_eq!(expect_value(doubled_sum(Safe(4u64), Safe(5u64))), 18);
    }

    #[test]
    fn test_add_overflow_reports_operands() {
        let e = expect_error(Safe(u64::MAX) + 1u64);
        assert_eq!(e.left, u64::MAX);
        assert_eq!(e.right, 1);
        assert!(matches!(e.op, Operation::Add));
    }

    #[test]
    fn test_error_propagates_through_chain() {
        let overflowed = Safe(u64::MAX) + Safe(1u64);
        // Every operator overload must carry the original error forward.
        for e in [
            expect_error(overflowed + 0u64),
            expect_error(overflowed + Safe(0u64)),
            expect_error(overflowed + overflowed),
            expect_error(Safe(0u64) + overflowed),
        ] {
            assert_eq!(e.left, u64::MAX);
            assert_eq!(e.right, 1);
        }
    }
}