1use crate::Decimal;
2use rand_0_9::{
3 distr::{
4 uniform::{SampleBorrow, SampleUniform, UniformInt, UniformSampler},
5 Distribution, StandardUniform,
6 },
7 Rng,
8};
9
10impl Distribution<Decimal> for StandardUniform {
11 fn sample<R>(&self, rng: &mut R) -> Decimal
12 where
13 R: Rng + ?Sized,
14 {
15 Decimal::from_parts(
16 rng.next_u32(),
17 rng.next_u32(),
18 rng.next_u32(),
19 rng.random(),
20 rng.random_range(0..=Decimal::MAX_SCALE),
21 )
22 }
23}
24
25impl SampleUniform for Decimal {
26 type Sampler = DecimalSampler;
27}
28
29#[derive(Clone, Copy, Debug, PartialEq)]
30pub struct DecimalSampler {
31 mantissa_sampler: UniformInt<i128>,
32 scale: u32,
33}
34
35impl UniformSampler for DecimalSampler {
36 type X = Decimal;
37
38 #[inline]
56 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
57 where
58 B1: SampleBorrow<Self::X> + Sized,
59 B2: SampleBorrow<Self::X> + Sized,
60 {
61 let (low, high) = sync_scales(*low.borrow(), *high.borrow());
62 let high = Decimal::from_i128_with_scale(high.mantissa() - 1, high.scale());
63 UniformSampler::new_inclusive(low, high)
64 }
65
66 #[inline]
84 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
85 where
86 B1: SampleBorrow<Self::X> + Sized,
87 B2: SampleBorrow<Self::X> + Sized,
88 {
89 let (low, high) = sync_scales(*low.borrow(), *high.borrow());
90
91 Ok(Self {
94 mantissa_sampler: UniformInt::new_inclusive(low.mantissa(), high.mantissa())?,
95 scale: low.scale(),
96 })
97 }
98
99 #[inline]
100 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
101 let mantissa = self.mantissa_sampler.sample(rng);
102 Decimal::from_i128_with_scale(mantissa, self.scale)
103 }
104}
105
106#[inline]
108fn sync_scales(mut a: Decimal, mut b: Decimal) -> (Decimal, Decimal) {
109 if a.scale() == b.scale() {
110 return (a, b);
111 }
112
113 a.rescale(a.scale().max(b.scale()));
116 b.rescale(a.scale().max(b.scale()));
117
118 if a.scale() != b.scale() {
122 a.rescale(a.scale().min(b.scale()));
123 b.rescale(a.scale().min(b.scale()));
124 }
125
126 (a, b)
127}
128
#[cfg(test)]
mod rand_tests {
    use super::*;

    /// Parses the literal tokens as an exact `Decimal`, keeping the scale as written.
    macro_rules! dec {
        ($e:expr) => {
            Decimal::from_str_exact(stringify!($e)).unwrap()
        };
    }

    #[test]
    fn has_random_decimal_instances() {
        let mut rng = rand_0_9::rng();
        let samples: [Decimal; 32] = rng.random();
        // Thirty-two identical draws in a row would mean the distribution ignores the RNG.
        let all_identical = samples.windows(2).all(|pair| pair[0] == pair[1]);
        assert!(!all_identical);
    }

    #[test]
    fn generates_within_range() {
        let mut rng = rand_0_9::rng();
        let (low, high) = (dec!(1.00), dec!(1.05));
        (0..128)
            .map(|_| rng.random_range(low..high))
            .for_each(|value| {
                assert!(value < high);
                assert!(value >= low);
            });
    }

    #[test]
    fn generates_within_inclusive_range() {
        let mut rng = rand_0_9::rng();
        let (low, high) = (dec!(1.00), dec!(1.01));
        let (mut saw_low, mut saw_high) = (false, false);
        for _ in 0..256 {
            let value = rng.random_range(low..=high);
            assert!(value == low || value == high);
            // Both endpoints must be reachable in an inclusive range.
            saw_low |= value == low;
            saw_high |= value != low;
        }
        assert!(saw_low && saw_high);
    }

    #[test]
    fn test_edge_case_scales_match() {
        // The large integer cannot be widened to the small value's scale without overflowing,
        // so this exercises the narrowing fallback.
        let precise = dec!(1.000_000_000_000_000_000_01);
        let huge = dec!(100_000_000_000_000_000_001);
        let (low, high) = sync_scales(precise, huge);
        assert_eq!(low.scale(), high.scale());
    }
}