scirs2_stats/distributions/
uniform.rs1use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use crate::traits::{ContinuousDistribution, Distribution};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::{Float, NumCast};
10use scirs2_core::random::{Distribution as RandDistribution, Uniform as RandUniform};
11
12pub struct Uniform<F: Float> {
14 pub low: F,
16 pub high: F,
18 rand_distr: RandUniform<f64>,
20}
21
22impl<F: Float + NumCast + std::fmt::Display> Uniform<F> {
23 pub fn new(low: F, high: F) -> StatsResult<Self> {
42 if low >= high {
43 return Err(StatsError::DomainError(
44 "Lower bound must be less than upper bound".to_string(),
45 ));
46 }
47
48 let low_f64 = <f64 as NumCast>::from(low).expect("Operation failed");
50 let high_f64 = <f64 as NumCast>::from(high).expect("Operation failed");
51
52 match RandUniform::new_inclusive(low_f64, high_f64) {
53 Ok(rand_distr) => Ok(Uniform {
54 low,
55 high,
56 rand_distr,
57 }),
58 Err(_) => Err(StatsError::ComputationError(
59 "Failed to create uniform distribution".to_string(),
60 )),
61 }
62 }
63
64 pub fn pdf(&self, x: F) -> F {
84 if x >= self.low && x < self.high {
86 F::one() / (self.high - self.low)
87 } else {
88 F::zero()
89 }
90 }
91
92 pub fn cdf(&self, x: F) -> F {
112 if x <= self.low {
113 F::zero()
114 } else if x >= self.high {
115 F::one()
116 } else {
117 (x - self.low) / (self.high - self.low)
118 }
119 }
120
121 pub fn ppf(&self, p: F) -> StatsResult<F> {
141 if p < F::zero() || p > F::one() {
142 return Err(StatsError::DomainError(
143 "Probability must be between 0 and 1".to_string(),
144 ));
145 }
146
147 Ok(self.low + p * (self.high - self.low))
149 }
150
151 pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
171 let mut rng = scirs2_core::random::thread_rng();
172 let mut samples = Vec::with_capacity(size);
173
174 for _ in 0..size {
175 let sample = self.rand_distr.sample(&mut rng);
176 samples.push(F::from(sample).expect("Failed to convert to float"));
177 }
178
179 Ok(Array1::from(samples))
180 }
181}
182
183impl<F: Float + NumCast + std::fmt::Display> Distribution<F> for Uniform<F> {
185 fn mean(&self) -> F {
186 (self.low + self.high) / F::from(2.0).expect("Failed to convert constant to float")
187 }
188
189 fn var(&self) -> F {
190 let range = self.high - self.low;
191 range * range / F::from(12.0).expect("Failed to convert constant to float")
192 }
193
194 fn std(&self) -> F {
195 self.var().sqrt()
196 }
197
198 fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
199 self.rvs(size)
200 }
201
202 fn entropy(&self) -> F {
203 (self.high - self.low).ln()
204 }
205}
206
207impl<F: Float + NumCast + std::fmt::Display> ContinuousDistribution<F> for Uniform<F> {
209 fn pdf(&self, x: F) -> F {
210 self.pdf(x)
211 }
212
213 fn cdf(&self, x: F) -> F {
214 self.cdf(x)
215 }
216
217 fn ppf(&self, p: F) -> StatsResult<F> {
218 self.ppf(p)
219 }
220}
221
222impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for Uniform<F> {
224 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
225 let array = self.rvs(size)?;
226 Ok(array.to_vec())
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_uniform_creation() {
        let standard = Uniform::new(0.0, 1.0).expect("Operation failed");
        assert_eq!((standard.low, standard.high), (0.0, 1.0));

        let symmetric = Uniform::new(-1.0, 1.0).expect("Operation failed");
        assert_eq!((symmetric.low, symmetric.high), (-1.0, 1.0));

        // Degenerate and reversed intervals must be rejected.
        for (lo, hi) in [(0.0, 0.0), (1.0, 0.0)] {
            assert!(Uniform::<f64>::new(lo, hi).is_err());
        }
    }

    #[test]
    fn test_uniform_pdf() {
        let standard = Uniform::new(0.0, 1.0).expect("Operation failed");

        // (x, expected density): interior, lower edge, upper edge (excluded), outside.
        let cases = [(0.5, 1.0), (0.0, 1.0), (1.0, 0.0), (2.0, 0.0)];
        for (x, expected) in cases {
            assert_relative_eq!(standard.pdf(x), expected, epsilon = 1e-10);
        }

        let wide = Uniform::new(0.0, 2.0).expect("Operation failed");
        assert_relative_eq!(wide.pdf(1.0), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_uniform_cdf() {
        let standard = Uniform::new(0.0, 1.0).expect("Operation failed");

        // (x, expected cumulative probability), including both edges and both tails.
        let cases = [(0.5, 0.5), (0.0, 0.0), (1.0, 1.0), (-1.0, 0.0), (2.0, 1.0)];
        for (x, expected) in cases {
            assert_relative_eq!(standard.cdf(x), expected, epsilon = 1e-10);
        }

        let symmetric = Uniform::new(-1.0, 1.0).expect("Operation failed");
        assert_relative_eq!(symmetric.cdf(0.0), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_uniform_ppf() {
        let standard = Uniform::new(0.0, 1.0).expect("Operation failed");

        // On [0, 1] the quantile function is the identity.
        for p in [0.5, 0.75, 0.25] {
            let quantile = standard.ppf(p).expect("Operation failed");
            assert_relative_eq!(quantile, p, epsilon = 1e-10);
        }

        for p in [-0.1, 1.1] {
            assert!(standard.ppf(p).is_err());
        }

        let symmetric = Uniform::new(-1.0, 1.0).expect("Operation failed");
        let centre = symmetric.ppf(0.5).expect("Operation failed");
        assert_relative_eq!(centre, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_uniform_rvs() {
        let standard = Uniform::new(0.0, 1.0).expect("Operation failed");
        let draws = standard.rvs(1000).expect("Operation failed");

        assert_eq!(draws.len(), 1000);

        let sample_mean = draws.iter().sum::<f64>() / 1000.0;
        assert!((sample_mean - 0.5).abs() < 0.1);

        assert!(draws.iter().all(|value| (0.0..=1.0).contains(value)));
    }
}