1use crate::Decimal;
2use rand_0_9::{
3 distr::{
4 uniform::{SampleBorrow, SampleUniform, UniformInt, UniformSampler},
5 Distribution, StandardUniform,
6 },
7 Rng,
8};
9
10impl Distribution<Decimal> for StandardUniform {
11 fn sample<R>(&self, rng: &mut R) -> Decimal
12 where
13 R: Rng + ?Sized,
14 {
15 Decimal::from_parts(
16 rng.next_u32(),
17 rng.next_u32(),
18 rng.next_u32(),
19 rng.random(),
20 rng.random_range(0..=Decimal::MAX_SCALE),
21 )
22 }
23}
24
25impl SampleUniform for Decimal {
26 type Sampler = DecimalSampler;
27}
28
29#[derive(Clone, Copy, Debug, PartialEq)]
30pub struct DecimalSampler {
31 mantissa_sampler: UniformInt<i128>,
32 scale: u32,
33}
34
35impl UniformSampler for DecimalSampler {
36 type X = Decimal;
37
38 #[inline]
56 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
57 where
58 B1: SampleBorrow<Self::X> + Sized,
59 B2: SampleBorrow<Self::X> + Sized,
60 {
61 let (low, high) = sync_scales(*low.borrow(), *high.borrow());
62 let high = Decimal::from_i128_with_scale(high.mantissa() - 1, high.scale());
63 UniformSampler::new_inclusive(low, high)
64 }
65
66 #[inline]
84 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
85 where
86 B1: SampleBorrow<Self::X> + Sized,
87 B2: SampleBorrow<Self::X> + Sized,
88 {
89 let (low, high) = sync_scales(*low.borrow(), *high.borrow());
90
91 Ok(Self {
94 mantissa_sampler: UniformInt::new_inclusive(low.mantissa(), high.mantissa())?,
95 scale: low.scale(),
96 })
97 }
98
99 #[inline]
100 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
101 let mantissa = self.mantissa_sampler.sample(rng);
102 Decimal::from_i128_with_scale(mantissa, self.scale)
103 }
104}
105
106#[inline]
108fn sync_scales(mut a: Decimal, mut b: Decimal) -> (Decimal, Decimal) {
109 if a.scale() == b.scale() {
110 return (a, b);
111 }
112
113 a.rescale(a.scale().max(b.scale()));
116 b.rescale(a.scale().max(b.scale()));
117
118 if a.scale() != b.scale() {
122 a.rescale(a.scale().min(b.scale()));
123 b.rescale(a.scale().min(b.scale()));
124 }
125
126 (a, b)
127}
128
#[cfg(test)]
mod rand_tests {
    use std::collections::HashSet;

    use super::*;

    macro_rules! dec {
        ($e:expr) => {
            Decimal::from_str_exact(stringify!($e)).unwrap()
        };
    }

    /// The standard distribution must not hand back the same value over and over.
    #[test]
    fn has_random_decimal_instances() {
        let mut rng = rand_0_9::rng();
        let samples: [Decimal; 32] = rng.random();
        let all_identical = samples.windows(2).all(|pair| pair[0] == pair[1]);
        assert!(!all_identical);
    }

    /// Half-open ranges include the lower bound and exclude the upper bound.
    #[test]
    fn generates_within_range() {
        let mut rng = rand_0_9::rng();
        let (low, high) = (dec!(1.00), dec!(1.05));
        for _ in 0..128 {
            let sample = rng.random_range(low..high);
            assert!(sample < high);
            assert!(sample >= low);
        }
    }

    /// A two-value inclusive range yields exactly those two values, and both show up.
    #[test]
    fn generates_within_inclusive_range() {
        let mut rng = rand_0_9::rng();
        let (low, high) = (dec!(1.00), dec!(1.01));
        let seen: HashSet<Decimal> = (0..256)
            .map(|_| rng.random_range(low..=high))
            .inspect(|sample| assert!(*sample == low || *sample == high))
            .collect();
        assert_eq!(seen.len(), 2);
    }

    /// Bounds whose scales cannot both be raised to the finer one still end up matching.
    #[test]
    fn test_edge_case_scales_match() {
        let (low, high) = sync_scales(
            dec!(1.000_000_000_000_000_000_01),
            dec!(100_000_000_000_000_000_001),
        );
        assert_eq!(low.scale(), high.scale());
    }
}