1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
use num_rational::Ratio;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::ops::Mul;

/// A probability: a [rational number (ℚ)](https://en.wikipedia.org/wiki/Rational_number)
/// in the closed interval from 0 to 1.
///
/// A `p` obtained from [`Probability::new`], [`Probability::try_new`] or `From<Ratio<u64>>`
/// always satisfies `0 <= p && p <= 1`; those constructors reject any other value.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` builds the struct
/// field-by-field and does not appear to run the `ratio <= 1` check performed by
/// `From<Ratio<u64>>` — confirm whether deserialized input needs separate validation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Probability {
    ratio: Ratio<u64>,
}

impl Display for Probability {
    /// Formats the probability exactly like its inner `Ratio<u64>` (e.g. `1/3`).
    ///
    /// The caller's formatter is forwarded to the inner ratio instead of going through
    /// `write!(f, "{}", ..)`, which would create a fresh formatter and silently drop
    /// width/fill/alignment flags such as `{:>8}`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.ratio, f)
    }
}

impl Default for Probability {
    fn default() -> Self {
        PROBABILITY_ZERO
    }
}

impl Mul<Probability> for Probability {
    type Output = Probability;

    fn mul(self, rhs: Probability) -> Self::Output {
        // Multiplication between two probabilities is always safe to be in bounds!
        Self {
            ratio: self.ratio * rhs.ratio,
        }
    }
}

impl From<Ratio<u64>> for Probability {
    /// Creates a new `Probability` from the given ratio.
    ///
    /// # Panics
    ///
    /// - if ratio > 1 ⇒ value out of bounds!
    fn from(ratio: Ratio<u64>) -> Self {
        assert!(
            ratio <= Ratio::from(1),
            "ratio is not in the bounds of 0 and 1"
        );

        Self { ratio }
    }
}

impl Probability {
    /// Creates a new `Probability` equal to `numerator / denominator`.
    ///
    /// The fraction is reduced, so `Probability::new(2, 4) == Probability::new(1, 2)`.
    ///
    /// For a panic-free alternative, use [`Probability::try_new`].
    ///
    /// # Panics
    ///
    /// - if `numerator > denominator` ⇒ ratio > 1 ⇒ value out of bounds!
    /// - if `denominator == 0` ⇒ impossible value!
    ///
    /// # Example
    ///
    /// ```
    /// use num_rational::Ratio;
    /// use stochasta::Probability;
    ///
    /// let p = Probability::new(1, 3);
    /// assert_eq!(p.ratio(), &Ratio::new(1, 3));
    /// ```
    #[must_use]
    pub fn new(numerator: u64, denominator: u64) -> Self {
        Self::from(Ratio::new(numerator, denominator))
    }

    /// Tries to create a new `Probability` equal to `numerator / denominator`.
    ///
    /// This is the panic-free counterpart of [`Probability::new`].
    ///
    /// # Errors
    ///
    /// - [`ProbabilityRatioError::DenominatorZero`] if `denominator == 0`. This check runs
    ///   first, so `try_new(0, 0)` reports this error.
    /// - [`ProbabilityRatioError::RatioGreaterOne`] if `numerator > denominator` ⇒ ratio > 1.
    ///
    /// # Example
    ///
    /// ```
    /// use stochasta::{Probability, ProbabilityRatioError};
    ///
    /// assert!(Probability::try_new(1, 2).is_ok());
    /// assert_eq!(Probability::try_new(3, 6), Ok(Probability::new(1, 2)));
    /// assert_eq!(Probability::try_new(1, 0), Err(ProbabilityRatioError::DenominatorZero));
    /// assert_eq!(Probability::try_new(2, 1), Err(ProbabilityRatioError::RatioGreaterOne));
    /// ```
    pub fn try_new(numerator: u64, denominator: u64) -> Result<Self, ProbabilityRatioError> {
        if denominator == 0 {
            Err(ProbabilityRatioError::DenominatorZero)
        } else if numerator > denominator {
            Err(ProbabilityRatioError::RatioGreaterOne)
        } else {
            // Both error conditions are excluded, so `Ratio::new` cannot panic and the
            // resulting ratio lies in `[0, 1]`.
            Ok(Self {
                ratio: Ratio::new(numerator, denominator),
            })
        }
    }

    /// Returns a reference to the inner ratio.
    #[must_use]
    pub fn ratio(&self) -> &Ratio<u64> {
        &self.ratio
    }

    /// Returns the complementary probability: `1 - self`.
    ///
    /// Because `self <= 1`, the subtraction never underflows.
    ///
    /// # Example
    ///
    /// ```
    /// use stochasta::Probability;
    ///
    /// let one_third = Probability::new(1, 3);
    /// let two_third = one_third.complementary();
    /// assert_eq!(two_third, Probability::new(2, 3));
    /// ```
    #[must_use]
    pub fn complementary(&self) -> Self {
        Self {
            ratio: RATIO_ONE - self.ratio,
        }
    }
}

const RATIO_ZERO: Ratio<u64> = Ratio::new_raw(0, 1);

const RATIO_ONE: Ratio<u64> = Ratio::new_raw(1, 1);

/// The probability 0 (0 %).
///
/// An event with this probability never occurs.
pub const PROBABILITY_ZERO: Probability = Probability { ratio: RATIO_ZERO };

/// The probability 1 (100 %).
///
/// An event with this probability always occurs.
pub const PROBABILITY_ONE: Probability = Probability { ratio: RATIO_ONE };

/// Errors that may happen when trying to create a probability (see `Probability::try_new`).
///
/// The variant order is significant: it determines the derived `PartialOrd`/`Ord`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ProbabilityRatioError {
    /// The denominator must not be 0. That's a basic math rule!
    DenominatorZero,
    /// The ratio of `Probability` cannot be lower than 0.
    ///
    /// NOTE(review): `Probability::try_new` takes `u64` inputs, which cannot be negative, and
    /// never returns this variant; presumably it is kept for API completeness — confirm.
    RatioLowerZero,
    /// The ratio of `Probability` cannot be higher than 1.
    RatioGreaterOne,
}

impl Display for ProbabilityRatioError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            match self {
                ProbabilityRatioError::DenominatorZero => "The denominator must not be 0.",
                ProbabilityRatioError::RatioLowerZero =>
                    "The ratio of `Probability` cannot be lower than 0.",
                ProbabilityRatioError::RatioGreaterOne =>
                    "The ratio of `Probability` cannot be higher than 1.",
            }
        )
    }
}

impl Error for ProbabilityRatioError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        None
    }

    fn description(&self) -> &str {
        "description() is deprecated; use Display"
    }

    fn cause(&self) -> Option<&dyn Error> {
        self.source()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use num_rational::Ratio;

    #[test]
    fn constants() {
        assert_eq!(PROBABILITY_ZERO.ratio(), &Ratio::new(0, 1));
        assert_eq!(PROBABILITY_ONE.ratio(), &Ratio::new(1, 1));
    }

    #[test]
    fn new_standard() {
        assert_eq!(Probability::new(0, 2).ratio(), &Ratio::new(0, 1));
        assert_eq!(Probability::new(1, 2).ratio(), &Ratio::new(1, 2));
        assert_eq!(Probability::new(2, 2).ratio(), &Ratio::new(1, 1));
    }

    #[test]
    #[should_panic]
    fn new_out_of_bounds() {
        let _ = Probability::new(2, 1);
    }

    #[test]
    #[should_panic]
    fn new_zero_denominator() {
        let _ = Probability::new(1, 0);
    }

    #[test]
    fn try_new_standard() {
        assert_eq!(Probability::try_new(0, 5), Ok(PROBABILITY_ZERO));
        assert_eq!(Probability::try_new(3, 6), Ok(Probability::new(1, 2)));
        assert_eq!(Probability::try_new(7, 7), Ok(PROBABILITY_ONE));
    }

    #[test]
    fn try_new_errors() {
        assert_eq!(
            Probability::try_new(1, 0),
            Err(ProbabilityRatioError::DenominatorZero)
        );
        // The denominator is checked first, so 0/0 is reported as a denominator error.
        assert_eq!(
            Probability::try_new(0, 0),
            Err(ProbabilityRatioError::DenominatorZero)
        );
        assert_eq!(
            Probability::try_new(2, 1),
            Err(ProbabilityRatioError::RatioGreaterOne)
        );
    }

    #[test]
    fn from_ratio_standard() {
        assert_eq!(
            Probability::from(Ratio::new(0, 7)).ratio(),
            &Ratio::new(0, 1)
        );
        assert_eq!(
            Probability::from(Ratio::new(4, 9)).ratio(),
            &Ratio::new(4, 9)
        );
        assert_eq!(
            Probability::from(Ratio::new(9, 9)).ratio(),
            &Ratio::new(1, 1)
        );
    }

    #[test]
    #[should_panic]
    fn from_ratio_out_of_bounds() {
        let _ = Probability::from(Ratio::new(2, 1));
    }

    #[test]
    #[should_panic]
    fn from_ratio_zero_denominator() {
        let _ = Probability::from(Ratio::new(1, 0));
    }

    #[test]
    fn complementary_standard() {
        assert_eq!(
            Probability::new(1, 3).complementary(),
            Probability::new(2, 3)
        );
        assert_eq!(PROBABILITY_ZERO.complementary(), PROBABILITY_ONE);
        assert_eq!(PROBABILITY_ONE.complementary(), PROBABILITY_ZERO);
    }

    #[test]
    fn mul_standard() {
        assert_eq!(
            Probability::new(1, 2) * Probability::new(2, 3),
            Probability::new(1, 3)
        );
        assert_eq!(
            Probability::new(3, 4) * PROBABILITY_ONE,
            Probability::new(3, 4)
        );
        assert_eq!(Probability::new(3, 4) * PROBABILITY_ZERO, PROBABILITY_ZERO);
    }

    #[test]
    fn default_is_zero() {
        assert_eq!(Probability::default(), PROBABILITY_ZERO);
    }

    #[test]
    fn display_standard() {
        assert_eq!(Probability::new(1, 2).to_string(), "1/2");
        assert_eq!(Probability::new(2, 6).to_string(), "1/3");
    }

    #[test]
    fn error_display() {
        assert_eq!(
            ProbabilityRatioError::DenominatorZero.to_string(),
            "The denominator must not be 0."
        );
        assert_eq!(
            ProbabilityRatioError::RatioGreaterOne.to_string(),
            "The ratio of `Probability` cannot be higher than 1."
        );
    }

    #[test]
    fn derive_copy() {
        let x = Probability::new(1, 3);
        let y = x;
        assert_eq!(x, y);
    }

    #[test]
    fn derive_ord() {
        let one_over_three = Probability::new(1, 3);
        let four_over_seven = Probability::new(4, 7);
        let eight_over_nine = Probability::new(8, 9);
        assert!(one_over_three < four_over_seven);
        assert!(four_over_seven < eight_over_nine);
        assert!(one_over_three < eight_over_nine);
    }

    #[test]
    fn derive_eq() {
        let one_over_four = Probability::new(1, 4);
        let two_over_eight = Probability::new(2, 8);
        assert_eq!(one_over_four, two_over_eight);
        assert_ne!(one_over_four, PROBABILITY_ZERO);
        assert_ne!(one_over_four, PROBABILITY_ONE);
    }
}