rust_decimal/
rand_0_9.rs

1use crate::Decimal;
2use rand_0_9::{
3    distr::{
4        uniform::{SampleBorrow, SampleUniform, UniformInt, UniformSampler},
5        Distribution, StandardUniform,
6    },
7    Rng,
8};
9
10impl Distribution<Decimal> for StandardUniform {
11    fn sample<R>(&self, rng: &mut R) -> Decimal
12    where
13        R: Rng + ?Sized,
14    {
15        Decimal::from_parts(
16            rng.next_u32(),
17            rng.next_u32(),
18            rng.next_u32(),
19            rng.random(),
20            rng.random_range(0..=Decimal::MAX_SCALE),
21        )
22    }
23}
24
25impl SampleUniform for Decimal {
26    type Sampler = DecimalSampler;
27}
28
29#[derive(Clone, Copy, Debug, PartialEq)]
30pub struct DecimalSampler {
31    mantissa_sampler: UniformInt<i128>,
32    scale: u32,
33}
34
35impl UniformSampler for DecimalSampler {
36    type X = Decimal;
37
38    /// Creates a new sampler that will yield random decimal objects between `low` and `high`.
39    ///
40    /// The sampler will always provide decimals at the same scale as the inputs; if the inputs
41    /// have different scales, the higher scale is used.
42    ///
43    /// # Example
44    ///
45    /// ```
46    /// # use rand_0_9 as rand;
47    /// # use rand::Rng;
48    /// # use rust_decimal_macros::dec;
49    /// let mut rng = rand::rng();
50    /// let random = rng.random_range(dec!(1.00)..dec!(2.00));
51    /// assert!(random >= dec!(1.00));
52    /// assert!(random < dec!(2.00));
53    /// assert_eq!(random.scale(), 2);
54    /// ```
55    #[inline]
56    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
57    where
58        B1: SampleBorrow<Self::X> + Sized,
59        B2: SampleBorrow<Self::X> + Sized,
60    {
61        let (low, high) = sync_scales(*low.borrow(), *high.borrow());
62        let high = Decimal::from_i128_with_scale(high.mantissa() - 1, high.scale());
63        UniformSampler::new_inclusive(low, high)
64    }
65
66    /// Creates a new sampler that will yield random decimal objects between `low` and `high`.
67    ///
68    /// The sampler will always provide decimals at the same scale as the inputs; if the inputs
69    /// have different scales, the higher scale is used.
70    ///
71    /// # Example
72    ///
73    /// ```
74    /// # use rand_0_9 as rand;
75    /// # use rand::Rng;
76    /// # use rust_decimal_macros::dec;
77    /// let mut rng = rand::rng();
78    /// let random = rng.random_range(dec!(1.00)..=dec!(2.00));
79    /// assert!(random >= dec!(1.00));
80    /// assert!(random <= dec!(2.00));
81    /// assert_eq!(random.scale(), 2);
82    /// ```
83    #[inline]
84    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
85    where
86        B1: SampleBorrow<Self::X> + Sized,
87        B2: SampleBorrow<Self::X> + Sized,
88    {
89        let (low, high) = sync_scales(*low.borrow(), *high.borrow());
90
91        // Return our sampler, which contains an underlying i128 sampler so we
92        // outsource the actual randomness implementation.
93        Ok(Self {
94            mantissa_sampler: UniformInt::new_inclusive(low.mantissa(), high.mantissa())?,
95            scale: low.scale(),
96        })
97    }
98
99    #[inline]
100    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
101        let mantissa = self.mantissa_sampler.sample(rng);
102        Decimal::from_i128_with_scale(mantissa, self.scale)
103    }
104}
105
106/// Return equivalent Decimal objects with the same scale as one another.
107#[inline]
108fn sync_scales(mut a: Decimal, mut b: Decimal) -> (Decimal, Decimal) {
109    if a.scale() == b.scale() {
110        return (a, b);
111    }
112
113    // Set scales to match one another, because we are relying on mantissas'
114    // being comparable in order outsource the actual sampling implementation.
115    a.rescale(a.scale().max(b.scale()));
116    b.rescale(a.scale().max(b.scale()));
117
118    // Edge case: If the values have _wildly_ different scales, the values may not have rescaled far enough to match one another.
119    //
120    // In this case, we accept some precision loss because the randomization approach we are using assumes that the scales will necessarily match.
121    if a.scale() != b.scale() {
122        a.rescale(a.scale().min(b.scale()));
123        b.rescale(a.scale().min(b.scale()));
124    }
125
126    (a, b)
127}
128
#[cfg(test)]
mod rand_tests {
    use std::collections::HashSet;

    use super::*;

    /// Parses a decimal literal exactly, preserving its written scale. This is a local stand-in
    /// for `rust_decimal_macros::dec!`.
    macro_rules! dec {
        ($e:expr) => {
            Decimal::from_str_exact(stringify!($e)).unwrap()
        };
    }

    #[test]
    fn has_random_decimal_instances() {
        let mut rng = rand_0_9::rng();
        let samples: [Decimal; 32] = rng.random();
        // If every adjacent pair were equal, the generator would not be producing variety.
        let all_identical = samples.windows(2).all(|pair| pair[0] == pair[1]);
        assert!(!all_identical);
    }

    #[test]
    fn generates_within_range() {
        let mut rng = rand_0_9::rng();
        let (low, high) = (dec!(1.00), dec!(1.05));
        for _ in 0..128 {
            let value = rng.random_range(low..high);
            assert!(value >= low);
            assert!(value < high);
        }
    }

    #[test]
    fn generates_within_inclusive_range() {
        let mut rng = rand_0_9::rng();
        let (low, high) = (dec!(1.00), dec!(1.01));
        let seen: HashSet<Decimal> = (0..256)
            .map(|_| {
                let value = rng.random_range(low..=high);
                // At scale 2, the endpoints are the only values in this range.
                assert!(value == low || value == high);
                value
            })
            .collect();
        // Fails only if all 256 draws hit the same endpoint: probability 2^-255.
        assert_eq!(seen.len(), 2);
    }

    #[test]
    fn test_edge_case_scales_match() {
        let tiny = dec!(1.000_000_000_000_000_000_01);
        let huge = dec!(100_000_000_000_000_000_001);
        let (low, high) = sync_scales(tiny, huge);
        assert_eq!(low.scale(), high.scale());
    }
}