scirs2_core/random/
distributions_unified.rs

1//! Unified distribution interface for the SciRS2 ecosystem
2//!
3//! This module provides a consistent interface for all statistical distributions,
4//! ensuring compatibility across the entire SciRS2 ecosystem including ToRSh, SkleaRS, etc.
5//!
6//! ## Design Philosophy
7//!
8//! 1. **Zero Breaking Changes**: All existing code continues to work
9//! 2. **Full Compatibility**: Direct access to all rand_distr distributions
10//! 3. **Enhanced Functionality**: Additional scientific computing features
11//! 4. **Type Safety**: Unified trait system for distribution operations
12//!
13//! ## Usage Examples
14//!
15//! ```rust
16//! use scirs2_core::random::distributions_unified::*;
17//! use scirs2_core::random::thread_rng;
18//!
19//! // Create distributions with unified interface
20//! let normal = UnifiedNormal::new(0.0, 1.0).expect("Operation failed");
21//! let beta = UnifiedBeta::new(2.0, 5.0).expect("Operation failed");
22//! let student_t = UnifiedStudentT::new(10.0).expect("Operation failed");
23//!
24//! let mut rng = thread_rng();
25//! let sample = normal.sample_unified(&mut rng);
26//! ```
27
28use crate::random::core::Random;
29use ::ndarray::{Array1, ArrayD, Dimension, IxDyn};
30use rand::Rng;
31use rand_distr::Distribution;
32use std::fmt;
33
/// Error type for unified distribution operations.
///
/// Every fallible constructor in this module returns this type. Its `Display`
/// output prefixes `InvalidParameter` and `ConstructionFailed` messages with a
/// label, while `Other` messages are shown verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnifiedDistributionError {
    /// Invalid parameter value
    InvalidParameter(String),
    /// Construction failed
    ConstructionFailed(String),
    /// Generic error
    Other(String),
}

impl fmt::Display for UnifiedDistributionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidParameter(msg) => write!(f, "Invalid parameter: {msg}"),
            Self::ConstructionFailed(msg) => write!(f, "Construction failed: {msg}"),
            Self::Other(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for UnifiedDistributionError {}
56
57// Implement From for all rand_distr error types we use
58impl From<rand_distr::NormalError> for UnifiedDistributionError {
59    fn from(e: rand_distr::NormalError) -> Self {
60        Self::ConstructionFailed(format!("Normal distribution error: {:?}", e))
61    }
62}
63
64impl From<rand_distr::BetaError> for UnifiedDistributionError {
65    fn from(e: rand_distr::BetaError) -> Self {
66        Self::ConstructionFailed(format!("Beta distribution error: {:?}", e))
67    }
68}
69
70impl From<rand_distr::CauchyError> for UnifiedDistributionError {
71    fn from(e: rand_distr::CauchyError) -> Self {
72        Self::ConstructionFailed(format!("Cauchy distribution error: {:?}", e))
73    }
74}
75
76impl From<rand_distr::ChiSquaredError> for UnifiedDistributionError {
77    fn from(e: rand_distr::ChiSquaredError) -> Self {
78        Self::ConstructionFailed(format!("ChiSquared distribution error: {:?}", e))
79    }
80}
81
82impl From<rand_distr::FisherFError> for UnifiedDistributionError {
83    fn from(e: rand_distr::FisherFError) -> Self {
84        Self::ConstructionFailed(format!("FisherF distribution error: {:?}", e))
85    }
86}
87
88impl From<rand_distr::ExpError> for UnifiedDistributionError {
89    fn from(e: rand_distr::ExpError) -> Self {
90        Self::ConstructionFailed(format!("Exponential distribution error: {:?}", e))
91    }
92}
93
94impl From<rand_distr::GammaError> for UnifiedDistributionError {
95    fn from(e: rand_distr::GammaError) -> Self {
96        Self::ConstructionFailed(format!("Gamma distribution error: {:?}", e))
97    }
98}
99
100impl From<rand_distr::WeibullError> for UnifiedDistributionError {
101    fn from(e: rand_distr::WeibullError) -> Self {
102        Self::ConstructionFailed(format!("Weibull distribution error: {:?}", e))
103    }
104}
105
106impl From<rand_distr::BinomialError> for UnifiedDistributionError {
107    fn from(e: rand_distr::BinomialError) -> Self {
108        Self::ConstructionFailed(format!("Binomial distribution error: {:?}", e))
109    }
110}
111
112impl From<rand_distr::PoissonError> for UnifiedDistributionError {
113    fn from(e: rand_distr::PoissonError) -> Self {
114        Self::ConstructionFailed(format!("Poisson distribution error: {:?}", e))
115    }
116}
117
118impl From<std::io::Error> for UnifiedDistributionError {
119    fn from(e: std::io::Error) -> Self {
120        Self::Other(e.to_string())
121    }
122}
123
124/// Unified distribution trait for consistent interface across all distributions
125pub trait UnifiedDistribution<T> {
126    /// Sample a single value from the distribution
127    fn sample_unified<R: Rng>(&self, rng: &mut Random<R>) -> T;
128
129    /// Sample multiple values into a vector
130    fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<T>;
131
132    /// Sample into an ndarray
133    fn sample_array<R: Rng>(&self, rng: &mut Random<R>, shape: IxDyn) -> ArrayD<T>
134    where
135        T: Clone;
136
137    /// Get distribution parameters as a string (for debugging/logging)
138    fn parameters_string(&self) -> String;
139
140    /// Validate distribution parameters
141    fn validate(&self) -> Result<(), UnifiedDistributionError>;
142}
143
/// Generates a unified newtype wrapper around a `rand_distr` distribution.
///
/// Invocation: `impl_unified_distribution!(#[attrs]* Name, InnerType, OutputType, params_fn)`.
/// - The leading attributes (typically `///` doc comments) are optional and are
///   forwarded to the generated struct.
/// - `params_fn` is any callable that takes `&InnerType` and returns the `String`
///   reported by `parameters_string`.
/// - A trailing comma is accepted.
///
/// The generated struct gets:
/// - `inner()` / `inner_mut()` accessors to the wrapped distribution,
/// - a `UnifiedDistribution<OutputType>` impl that samples with `rng.sample(&inner)`,
/// - a `Distribution<OutputType>` impl that delegates to the inner value.
///
/// Parameter validation is expected to happen in each wrapper's hand-written
/// `new`, so the generated `validate` always returns `Ok(())`.
macro_rules! impl_unified_distribution {
    ($(#[$meta:meta])* $name:ident, $inner:ty, $output:ty, $params:expr $(,)?) => {
        #[doc = concat!("Unified wrapper around `", stringify!($inner), "`.")]
        #[doc = ""]
        $(#[$meta])*
        #[derive(Debug, Clone)]
        pub struct $name {
            inner: $inner,
        }

        impl $name {
            /// Get a shared reference to the wrapped distribution.
            pub fn inner(&self) -> &$inner {
                &self.inner
            }

            /// Get a mutable reference to the wrapped distribution.
            pub fn inner_mut(&mut self) -> &mut $inner {
                &mut self.inner
            }
        }

        impl UnifiedDistribution<$output> for $name {
            fn sample_unified<R: Rng>(&self, rng: &mut Random<R>) -> $output {
                rng.sample(&self.inner)
            }

            fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<$output> {
                (0..n).map(|_| self.sample_unified(rng)).collect()
            }

            fn sample_array<R: Rng>(&self, rng: &mut Random<R>, shape: IxDyn) -> ArrayD<$output>
            where
                $output: Clone,
            {
                let size = shape.size();
                let values = self.sample_vec(rng, size);
                // `sample_vec` above returns exactly `size` elements, so the shape always fits.
                ArrayD::from_shape_vec(shape, values)
                    .expect("sample_vec produced exactly shape.size() elements")
            }

            fn parameters_string(&self) -> String {
                $params(&self.inner)
            }

            fn validate(&self) -> Result<(), UnifiedDistributionError> {
                // Validation is done during construction
                Ok(())
            }
        }

        impl Distribution<$output> for $name {
            fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $output {
                self.inner.sample(rng)
            }
        }
    };
}
199
200// Continuous distributions
201
202impl_unified_distribution!(
203    UnifiedNormal,
204    rand_distr::Normal<f64>,
205    f64,
206    |d: &rand_distr::Normal<f64>| format!("Normal(mean={}, std={})", d.mean(), d.std_dev())
207);
208
209impl UnifiedNormal {
210    pub fn new(mean: f64, std_dev: f64) -> Result<Self, UnifiedDistributionError> {
211        Ok(Self {
212            inner: rand_distr::Normal::new(mean, std_dev)?,
213        })
214    }
215
216    pub fn mean(&self) -> f64 {
217        self.inner.mean()
218    }
219
220    pub fn std_dev(&self) -> f64 {
221        self.inner.std_dev()
222    }
223}
224
225impl_unified_distribution!(
226    UnifiedBeta,
227    rand_distr::Beta<f64>,
228    f64,
229    |_: &rand_distr::Beta<f64>| "Beta(alpha, beta)".to_string()
230);
231
232impl UnifiedBeta {
233    pub fn new(alpha: f64, beta: f64) -> Result<Self, UnifiedDistributionError> {
234        Ok(Self {
235            inner: rand_distr::Beta::new(alpha, beta)?,
236        })
237    }
238}
239
240impl_unified_distribution!(
241    UnifiedCauchy,
242    rand_distr::Cauchy<f64>,
243    f64,
244    |_: &rand_distr::Cauchy<f64>| "Cauchy(median, scale)".to_string()
245);
246
247impl UnifiedCauchy {
248    pub fn new(median: f64, scale: f64) -> Result<Self, UnifiedDistributionError> {
249        Ok(Self {
250            inner: rand_distr::Cauchy::new(median, scale)?,
251        })
252    }
253}
254
255impl_unified_distribution!(
256    UnifiedChiSquared,
257    rand_distr::ChiSquared<f64>,
258    f64,
259    |_: &rand_distr::ChiSquared<f64>| "ChiSquared(k)".to_string()
260);
261
262impl UnifiedChiSquared {
263    pub fn new(k: f64) -> Result<Self, UnifiedDistributionError> {
264        Ok(Self {
265            inner: rand_distr::ChiSquared::new(k)?,
266        })
267    }
268}
269
270impl_unified_distribution!(
271    UnifiedFisherF,
272    rand_distr::FisherF<f64>,
273    f64,
274    |_: &rand_distr::FisherF<f64>| "FisherF(m, n)".to_string()
275);
276
277impl UnifiedFisherF {
278    pub fn new(m: f64, n: f64) -> Result<Self, UnifiedDistributionError> {
279        Ok(Self {
280            inner: rand_distr::FisherF::new(m, n)?,
281        })
282    }
283}
284
285impl_unified_distribution!(
286    UnifiedStudentT,
287    rand_distr::StudentT<f64>,
288    f64,
289    |_: &rand_distr::StudentT<f64>| "StudentT(n)".to_string()
290);
291
292impl UnifiedStudentT {
293    pub fn new(n: f64) -> Result<Self, UnifiedDistributionError> {
294        Ok(Self {
295            inner: rand_distr::StudentT::new(n)?,
296        })
297    }
298}
299
300impl_unified_distribution!(
301    UnifiedLogNormal,
302    rand_distr::LogNormal<f64>,
303    f64,
304    |_: &rand_distr::LogNormal<f64>| "LogNormal(mean, std)".to_string()
305);
306
307impl UnifiedLogNormal {
308    pub fn new(mean: f64, std_dev: f64) -> Result<Self, UnifiedDistributionError> {
309        Ok(Self {
310            inner: rand_distr::LogNormal::new(mean, std_dev)?,
311        })
312    }
313}
314
315impl_unified_distribution!(
316    UnifiedWeibull,
317    rand_distr::Weibull<f64>,
318    f64,
319    |_: &rand_distr::Weibull<f64>| "Weibull(scale, shape)".to_string()
320);
321
322impl UnifiedWeibull {
323    pub fn new(scale: f64, shape: f64) -> Result<Self, UnifiedDistributionError> {
324        Ok(Self {
325            inner: rand_distr::Weibull::new(scale, shape)?,
326        })
327    }
328}
329
330impl_unified_distribution!(
331    UnifiedGamma,
332    rand_distr::Gamma<f64>,
333    f64,
334    |_: &rand_distr::Gamma<f64>| "Gamma(shape, scale)".to_string()
335);
336
337impl UnifiedGamma {
338    pub fn new(shape: f64, scale: f64) -> Result<Self, UnifiedDistributionError> {
339        Ok(Self {
340            inner: rand_distr::Gamma::new(shape, scale)?,
341        })
342    }
343}
344
345impl_unified_distribution!(
346    UnifiedExp,
347    rand_distr::Exp<f64>,
348    f64,
349    |_: &rand_distr::Exp<f64>| "Exp(lambda)".to_string()
350);
351
352impl UnifiedExp {
353    pub fn new(lambda: f64) -> Result<Self, UnifiedDistributionError> {
354        Ok(Self {
355            inner: rand_distr::Exp::new(lambda)?,
356        })
357    }
358}
359
360// Discrete distributions
361
362impl_unified_distribution!(
363    UnifiedBinomial,
364    rand_distr::Binomial,
365    u64,
366    |_: &rand_distr::Binomial| "Binomial(n, p)".to_string()
367);
368
369impl UnifiedBinomial {
370    pub fn new(n: u64, p: f64) -> Result<Self, UnifiedDistributionError> {
371        Ok(Self {
372            inner: rand_distr::Binomial::new(n, p)?,
373        })
374    }
375}
376
377// Poisson<f64> in rand_distr 0.5 samples f64, not u64
378impl_unified_distribution!(
379    UnifiedPoisson,
380    rand_distr::Poisson<f64>,
381    f64,
382    |_: &rand_distr::Poisson<f64>| "Poisson(lambda)".to_string()
383);
384
385impl UnifiedPoisson {
386    pub fn new(lambda: f64) -> Result<Self, UnifiedDistributionError> {
387        Ok(Self {
388            inner: rand_distr::Poisson::new(lambda)?,
389        })
390    }
391}
392
393// Multivariate distributions
394
395/// Unified Dirichlet distribution
396///
397/// Uses `scirs2_core::random::distributions::Dirichlet` which supports `Vec<f64>`
398/// instead of `rand_distr::Dirichlet` which requires fixed-size arrays `[f64; N]`
399#[derive(Debug, Clone)]
400pub struct UnifiedDirichlet {
401    inner: crate::random::distributions::Dirichlet,
402}
403
404impl UnifiedDirichlet {
405    pub fn new(alpha: Vec<f64>) -> Result<Self, UnifiedDistributionError> {
406        Ok(Self {
407            inner: crate::random::distributions::Dirichlet::new(alpha).map_err(|e| {
408                UnifiedDistributionError::ConstructionFailed(format!("Dirichlet error: {}", e))
409            })?,
410        })
411    }
412
413    /// Sample from Dirichlet distribution
414    pub fn sample<R: Rng>(&self, rng: &mut Random<R>) -> Vec<f64> {
415        self.inner.sample(rng)
416    }
417
418    /// Sample into an Array1
419    pub fn sample_array<R: Rng>(&self, rng: &mut Random<R>) -> Array1<f64> {
420        Array1::from_vec(self.sample(rng))
421    }
422
423    /// Sample multiple times
424    pub fn sample_multiple<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<Vec<f64>> {
425        (0..n).map(|_| self.sample(rng)).collect()
426    }
427
428    /// Get alpha parameters
429    pub fn alphas(&self) -> &[f64] {
430        self.inner.alphas()
431    }
432}
433
434impl Distribution<Vec<f64>> for UnifiedDirichlet {
435    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
436        // Sample from constituent Gamma distributions manually
437        // This avoids the Random<R> vs &mut R type mismatch
438        use rand_distr::Gamma;
439        let gamma_samples: Vec<f64> = self
440            .inner
441            .alphas()
442            .iter()
443            .map(|&alpha| {
444                let gamma = Gamma::new(alpha, 1.0).expect("Operation failed");
445                rng.sample(gamma)
446            })
447            .collect();
448
449        // Normalize to get Dirichlet sample
450        let sum: f64 = gamma_samples.iter().sum();
451        gamma_samples.into_iter().map(|x| x / sum).collect()
452    }
453}
454
455impl UnifiedDistribution<Vec<f64>> for UnifiedDirichlet {
456    fn sample_unified<R: Rng>(&self, rng: &mut Random<R>) -> Vec<f64> {
457        self.sample(rng)
458    }
459
460    fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, n: usize) -> Vec<Vec<f64>> {
461        self.sample_multiple(rng, n)
462    }
463
464    fn sample_array<R: Rng>(&self, rng: &mut Random<R>, shape: IxDyn) -> ArrayD<Vec<f64>> {
465        let size = shape.size();
466        let values = self.sample_vec(rng, size);
467        ArrayD::from_shape_vec(shape, values).expect("Operation failed")
468    }
469
470    fn parameters_string(&self) -> String {
471        format!("Dirichlet(alpha=[{} values])", self.alphas().len())
472    }
473
474    fn validate(&self) -> Result<(), UnifiedDistributionError> {
475        Ok(())
476    }
477}
478
479/// Convenience functions for creating distributions with default parameters
480pub mod defaults {
481    use super::*;
482
483    /// Create a standard normal distribution (mean=0, std=1)
484    pub fn standard_normal() -> UnifiedNormal {
485        UnifiedNormal::new(0.0, 1.0).expect("Operation failed")
486    }
487
488    /// Create a uniform distribution on [0, 1)
489    pub fn uniform_01() -> rand_distr::Uniform<f64> {
490        rand_distr::Uniform::new(0.0, 1.0).expect("Operation failed")
491    }
492
493    /// Create a standard exponential distribution (lambda=1)
494    pub fn standard_exponential() -> UnifiedExp {
495        UnifiedExp::new(1.0).expect("Operation failed")
496    }
497
498    /// Create a standard gamma distribution (shape=1, scale=1)
499    pub fn standard_gamma() -> UnifiedGamma {
500        UnifiedGamma::new(1.0, 1.0).expect("Operation failed")
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::thread_rng;

    #[test]
    fn test_unified_normal() {
        let dist = UnifiedNormal::new(0.0, 1.0).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample_unified(&mut rng);
        assert!(sample.is_finite());

        let samples = dist.sample_vec(&mut rng, 100);
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_unified_normal_rejects_negative_std_dev() {
        // rand_distr errors are converted into `ConstructionFailed`.
        let result = UnifiedNormal::new(0.0, -1.0);
        assert!(matches!(
            result,
            Err(UnifiedDistributionError::ConstructionFailed(_))
        ));
    }

    #[test]
    fn test_unified_normal_parameters_string() {
        let dist = UnifiedNormal::new(0.0, 1.0).expect("Operation failed");
        assert_eq!(dist.parameters_string(), "Normal(mean=0, std=1)");
    }

    #[test]
    fn test_sample_array_shape() {
        let dist = UnifiedNormal::new(0.0, 1.0).expect("Operation failed");
        let mut rng = thread_rng();

        let arr = dist.sample_array(&mut rng, IxDyn(&[2, 3]));
        assert_eq!(arr.shape(), &[2, 3]);
        assert!(arr.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_unified_beta() {
        let dist = UnifiedBeta::new(2.0, 5.0).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample_unified(&mut rng);
        assert!((0.0..=1.0).contains(&sample));
    }

    #[test]
    fn test_unified_poisson() {
        let dist = UnifiedPoisson::new(5.0).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample_unified(&mut rng);
        assert!(sample >= 0.0);
    }

    #[test]
    fn test_unified_dirichlet() {
        let dist = UnifiedDirichlet::new(vec![1.0, 2.0, 3.0]).expect("Operation failed");
        let mut rng = thread_rng();

        let sample = dist.sample(&mut rng);
        assert_eq!(sample.len(), 3);

        let sum: f64 = sample.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_dirichlet_distribution_trait() {
        // Exercises the plain-`rand` `Distribution` impl, which uses its own
        // Gamma-normalisation path rather than the inherent `sample`.
        let dist = UnifiedDirichlet::new(vec![0.5, 1.5, 2.5]).expect("Operation failed");
        let mut rng = rand::rng();

        let sample: Vec<f64> = rng.sample(&dist);
        assert_eq!(sample.len(), 3);
        assert!(sample.iter().all(|x| (0.0..=1.0).contains(x)));

        let sum: f64 = sample.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_distribution_trait() {
        let dist = UnifiedNormal::new(0.0, 1.0).expect("Operation failed");
        // `rand::rng()` replaces the deprecated `rand::thread_rng()` in rand 0.9.
        let mut rng = rand::rng();

        // Test that Distribution trait works
        let sample: f64 = rng.sample(&dist);
        assert!(sample.is_finite());
    }
}