variational_regression/
distribution.rs

1use crate::error::{RegressionError::InvalidDistribution, RegressionError};
2use serde::{Serialize, Deserialize};
3
///
/// Represents a generic scalar distribution, summarised by its first two moments
///
pub trait ScalarDistribution {

    ///
    /// Computes the mean (expected value) of the distribution
    ///
    fn mean(&self) -> f64;

    ///
    /// Computes the variance of the distribution
    ///
    fn variance(&self) -> f64;

    ///
    /// Computes the standard deviation of the distribution
    ///
    /// The default implementation returns the square root of
    /// [`variance`](ScalarDistribution::variance). Implementors only need to
    /// override it when a cheaper or more accurate closed form is available.
    ///
    fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }
}

25///
26/// Represents a Gamma distribution
27/// 
28#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
29#[non_exhaustive]
30pub struct GammaDistribution {
31    /// Shape parameter
32    pub shape: f64,
33    /// Rate (inverse scale) parameter
34    pub rate: f64
35}
36
37impl GammaDistribution {
38
39    ///
40    /// Constructs a new Gamma distribution from the provided parameters
41    /// 
42    /// # Arguments
43    /// 
44    /// `shape` - The shape parameter
45    /// `rate` - The rate parameter
46    /// 
47    pub fn new(shape: f64, rate: f64) -> Result<GammaDistribution, RegressionError> {
48        if shape <= 0.0 {
49            Err(InvalidDistribution(format!("Shape parameter must be positive (found {})", shape)))
50        } else if rate <= 0.0 {
51            Err(InvalidDistribution(format!("Rate parameter must be positive (found {})", rate)))
52        } else {
53            Ok(GammaDistribution { shape, rate })
54        }
55    }
56
57    // use as vague prior
58    pub (crate) fn vague() -> GammaDistribution {
59        GammaDistribution { shape: 1e-4, rate: 1e-4 }
60    }
61}
62
63impl ScalarDistribution for GammaDistribution {
64
65    #[inline]
66    fn mean(&self) -> f64 {
67        self.shape / self.rate
68    }
69
70    #[inline]
71    fn variance(&self) -> f64 {
72        self.shape / (self.rate * self.rate)
73    }
74
75    #[inline]
76    fn std_dev(&self) -> f64 {
77        self.variance().sqrt()
78    }
79}
80
81///
82/// Represents a Gaussian (Normal) distribution
83/// 
84#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
85#[non_exhaustive]
86pub struct GaussianDistribution {
87    /// Mean parameter
88    pub mean: f64,
89    /// Variance parameter
90    pub variance: f64
91}
92
93impl GaussianDistribution {
94
95    ///
96    /// Constructs a new Gaussian distribution from the provided parameters
97    /// 
98    /// # Arguments
99    /// 
100    /// `mean` - The mean parameter
101    /// `variance` - The variance parameter
102    /// 
103    pub fn new(mean: f64, variance: f64) -> Result<GaussianDistribution, RegressionError> {
104        if variance <= 0.0 {
105            Err(InvalidDistribution(format!("Variance must be positive (found {})", variance)))
106        }  else {
107            Ok(GaussianDistribution { mean, variance })
108        }
109    }
110}
111
112impl ScalarDistribution for GaussianDistribution {
113
114    #[inline]
115    fn mean(&self) -> f64 {
116        self.mean
117    }
118
119    #[inline]
120    fn variance(&self) -> f64 {
121        self.variance
122    }
123
124    #[inline]
125    fn std_dev(&self) -> f64 {
126        self.variance.sqrt()
127    }
128}
129
130///
131/// Represents a Bernoulli distribution
132/// 
133#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
134#[non_exhaustive]
135pub struct BernoulliDistribution {
136    /// Probability of positive event
137    pub p: f64
138}
139
140impl BernoulliDistribution {
141    
142    ///
143    /// Constructs a new Bernoulli distribution from the provided parameter
144    /// 
145    /// # Arguments
146    /// 
147    /// `p` - The probability of a positive event
148    /// 
149    pub fn new(p: f64) -> Result<BernoulliDistribution, RegressionError> {
150        if p >= 0.0 && p <= 1.0 {
151            Ok(BernoulliDistribution { p })
152        } else {
153            Err(InvalidDistribution(format!("Invalid parameter 'p': {}", p)))
154        }
155    }
156}
157
158impl ScalarDistribution for BernoulliDistribution {
159
160    #[inline]
161    fn mean(&self) -> f64 {
162        self.p
163    }
164
165    #[inline]
166    fn variance(&self) -> f64 {
167        self.p * (1.0 - self.p)
168    }
169
170    #[inline]
171    fn std_dev(&self) -> f64 {
172        self.variance().sqrt()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    #[test]
    fn test_gamma() {
        let a = GammaDistribution::new(1.0, 2.0).unwrap();
        assert_eq!(a.shape, 1.0);
        assert_eq!(a.rate, 2.0);
        assert_approx_eq!(a.mean(), 0.5);
        assert_approx_eq!(a.variance(), 0.25);
        assert_approx_eq!(a.std_dev(), 0.5);

        // Both parameters must be strictly positive.
        assert!(GammaDistribution::new(0.0, 1.0).is_err());
        assert!(GammaDistribution::new(-1.0, 1.0).is_err());
        assert!(GammaDistribution::new(1.0, 0.0).is_err());
        assert!(GammaDistribution::new(1.0, -2.0).is_err());
    }

    #[test]
    fn test_gamma_vague() {
        // Gamma(1e-4, 1e-4) has mean 1 and variance 1e4.
        let v = GammaDistribution::vague();
        assert_approx_eq!(v.mean(), 1.0);
        assert_approx_eq!(v.variance(), 1e4);
    }

    #[test]
    fn test_gaussian() {
        let a = GaussianDistribution::new(1.0, 4.0).unwrap();
        assert_eq!(a.mean, 1.0);
        assert_eq!(a.variance, 4.0);
        assert_approx_eq!(a.mean(), 1.0);
        assert_approx_eq!(a.variance(), 4.0);
        assert_approx_eq!(a.std_dev(), 2.0);

        // A negative mean is fine; the variance must be strictly positive.
        assert!(GaussianDistribution::new(-3.0, 0.5).is_ok());
        assert!(GaussianDistribution::new(0.0, 0.0).is_err());
        assert!(GaussianDistribution::new(0.0, -1.0).is_err());
    }

    #[test]
    fn test_bernoulli() {
        let a = BernoulliDistribution::new(0.4).unwrap();
        assert_eq!(a.p, 0.4);
        assert_approx_eq!(a.mean(), 0.4);
        assert_approx_eq!(a.variance(), 0.24);
        assert_approx_eq!(a.std_dev(), 0.24f64.sqrt());

        // The closed interval [0, 1] is accepted, including both endpoints.
        let zero = BernoulliDistribution::new(0.0).unwrap();
        assert_approx_eq!(zero.variance(), 0.0);
        let one = BernoulliDistribution::new(1.0).unwrap();
        assert_approx_eq!(one.variance(), 0.0);

        assert!(BernoulliDistribution::new(2.0).is_err());
        assert!(BernoulliDistribution::new(-0.1).is_err());
        assert!(BernoulliDistribution::new(f64::NAN).is_err());
    }
}