1use std::ops::Add;
2
/// The arithmetic operation recorded inside an [`ArithmeticError`].
///
/// `Debug`, `PartialEq` and `Eq` are derived so that a failed operation can be
/// printed in diagnostics and compared directly in assertions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Operation {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Square root.
    Sqrt,
}
11
12#[derive(Clone, Copy)]
13pub struct ArithmeticError<T> {
14 pub left: T,
15 pub right: T,
16 pub op: Operation,
17}
18
/// A value of type `T` on which arithmetic is overflow-checked.
///
/// The wrapped value is private; use [`Safe::new`] to wrap a value and
/// [`Safe::into_inner`] to get it back.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Safe<T>(T);

impl<T> Safe<T> {
    /// Wraps `value` so that subsequent arithmetic on it is checked.
    pub fn new(value: T) -> Self {
        Safe(value)
    }

    /// Consumes the wrapper and returns the raw value.
    pub fn into_inner(self) -> T {
        self.0
    }
}
21
22#[derive(Clone, Copy)]
23pub struct SafeResult<T>(Result<Safe<T>, ArithmeticError<T>>);
24
25pub trait IsSafe {}
27impl<T> IsSafe for Safe<T> {}
28
29pub trait SafeNum<T>: IsSafe + Add<T> + Add<Safe<T>> + Add<SafeResult<T>> {}
31impl<T> SafeNum<T> for T
32 where T: IsSafe
33 + Add<T>
34 + Add<Safe<T>>
35 + Add<SafeResult<T>>
36{}
37
38impl Add<u64> for Safe<u64> {
39 type Output = SafeResult<u64>;
40
41 fn add(self, other: u64) -> Self::Output {
42 match self.0.checked_add(other) {
43 Some(r) => SafeResult(Ok(Safe(r))),
44 None => SafeResult(Err(
45 ArithmeticError {
46 left: self.0,
47 right: other,
48 op: Operation::Add,
49 }
50 ))
51 }
52 }
53}
54
55impl Add for Safe<u64> {
56 type Output = SafeResult<u64>;
57
58 fn add(self, other: Safe<u64>) -> Self::Output {
59 self + other.0
60 }
61}
62
63impl Add<SafeResult<u64>> for Safe<u64> {
64 type Output = SafeResult<u64>;
65
66 fn add(self, other: SafeResult<u64>) -> Self::Output {
67 match other.0 {
68 Ok(other) => self + other,
69 Err(e) => SafeResult(Err(e))
70 }
71 }
72}
73
74impl Add<u64> for SafeResult<u64> {
75 type Output = SafeResult<u64>;
76
77 fn add(self, other: u64) -> Self::Output {
78 match self.0 {
79 Ok(this) => this + other,
80 Err(e) => SafeResult(Err(e))
81 }
82 }
83}
84
85impl Add<Safe<u64>> for SafeResult<u64> {
86 type Output = SafeResult<u64>;
87
88 fn add(self, other: Safe<u64>) -> Self::Output {
89 match self.0 {
90 Ok(this) => this + other,
91 Err(e) => SafeResult(Err(e))
92 }
93 }
94}
95
96impl Add for SafeResult<u64> {
97 type Output = SafeResult<u64>;
98
99 fn add(self, other: SafeResult<u64>) -> Self::Output {
100 match self.0 {
101 Ok(this) => this + other,
102 Err(e) => SafeResult(Err(e))
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Chains several checked additions: `(x + y) + (x + y)`.
    fn vector_length(x: Safe<u64>, y: Safe<u64>) -> SafeResult<u64> {
        (x + y) + (x + y)
    }

    /// The non-overflowing path yields the plain arithmetic result.
    #[test]
    fn test_vector_length() {
        let r = vector_length(Safe(4u64), Safe(5u64));
        // The types involved do not implement `Debug`/`PartialEq` here, so the
        // result is inspected by pattern matching rather than `assert_eq!`.
        match r.0 {
            Ok(Safe(sum)) => assert_eq!(sum, 18),
            Err(_) => panic!("(4 + 5) + (4 + 5) must not overflow u64"),
        }
    }

    /// An overflow in the first sub-expression is reported with its operands.
    #[test]
    fn test_vector_length_overflow() {
        let r = vector_length(Safe(u64::MAX), Safe(1u64));
        match r.0 {
            Ok(_) => panic!("u64::MAX + 1 must be reported as an overflow"),
            Err(e) => {
                assert_eq!(e.left, u64::MAX);
                assert_eq!(e.right, 1);
                assert!(matches!(e.op, Operation::Add));
            }
        }
    }
}