variational_regression/
distribution.rs1use crate::error::{RegressionError::InvalidDistribution, RegressionError};
2use serde::{Serialize, Deserialize};
3
/// A probability distribution over a single `f64` value, summarised by its moments.
pub trait ScalarDistribution {
    /// Returns the expected value of the distribution.
    fn mean(&self) -> f64;

    /// Returns the variance of the distribution.
    fn variance(&self) -> f64;

    /// Returns the standard deviation of the distribution.
    ///
    /// The default is the square root of [`ScalarDistribution::variance`]. Implementors only
    /// need to override it when they have a more direct formula.
    fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }
}
24
25#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
29#[non_exhaustive]
30pub struct GammaDistribution {
31 pub shape: f64,
33 pub rate: f64
35}
36
37impl GammaDistribution {
38
39 pub fn new(shape: f64, rate: f64) -> Result<GammaDistribution, RegressionError> {
48 if shape <= 0.0 {
49 Err(InvalidDistribution(format!("Shape parameter must be positive (found {})", shape)))
50 } else if rate <= 0.0 {
51 Err(InvalidDistribution(format!("Rate parameter must be positive (found {})", rate)))
52 } else {
53 Ok(GammaDistribution { shape, rate })
54 }
55 }
56
57 pub (crate) fn vague() -> GammaDistribution {
59 GammaDistribution { shape: 1e-4, rate: 1e-4 }
60 }
61}
62
63impl ScalarDistribution for GammaDistribution {
64
65 #[inline]
66 fn mean(&self) -> f64 {
67 self.shape / self.rate
68 }
69
70 #[inline]
71 fn variance(&self) -> f64 {
72 self.shape / (self.rate * self.rate)
73 }
74
75 #[inline]
76 fn std_dev(&self) -> f64 {
77 self.variance().sqrt()
78 }
79}
80
81#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
85#[non_exhaustive]
86pub struct GaussianDistribution {
87 pub mean: f64,
89 pub variance: f64
91}
92
93impl GaussianDistribution {
94
95 pub fn new(mean: f64, variance: f64) -> Result<GaussianDistribution, RegressionError> {
104 if variance <= 0.0 {
105 Err(InvalidDistribution(format!("Variance must be positive (found {})", variance)))
106 } else {
107 Ok(GaussianDistribution { mean, variance })
108 }
109 }
110}
111
112impl ScalarDistribution for GaussianDistribution {
113
114 #[inline]
115 fn mean(&self) -> f64 {
116 self.mean
117 }
118
119 #[inline]
120 fn variance(&self) -> f64 {
121 self.variance
122 }
123
124 #[inline]
125 fn std_dev(&self) -> f64 {
126 self.variance.sqrt()
127 }
128}
129
130#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
134#[non_exhaustive]
135pub struct BernoulliDistribution {
136 pub p: f64
138}
139
140impl BernoulliDistribution {
141
142 pub fn new(p: f64) -> Result<BernoulliDistribution, RegressionError> {
150 if p >= 0.0 && p <= 1.0 {
151 Ok(BernoulliDistribution { p })
152 } else {
153 Err(InvalidDistribution(format!("Invalid parameter 'p': {}", p)))
154 }
155 }
156}
157
158impl ScalarDistribution for BernoulliDistribution {
159
160 #[inline]
161 fn mean(&self) -> f64 {
162 self.p
163 }
164
165 #[inline]
166 fn variance(&self) -> f64 {
167 self.p * (1.0 - self.p)
168 }
169
170 #[inline]
171 fn std_dev(&self) -> f64 {
172 self.variance().sqrt()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    #[test]
    fn test_gamma() {
        let a = GammaDistribution::new(1.0, 2.0).unwrap();
        assert_eq!(a.shape, 1.0);
        assert_eq!(a.rate, 2.0);
        assert_approx_eq!(a.mean(), 0.5);
        assert_approx_eq!(a.variance(), 0.25);
        assert_approx_eq!(a.std_dev(), 0.5);

        // Both parameters must be strictly positive; exercise each rejection branch.
        assert!(GammaDistribution::new(0.0, 1.0).is_err());
        assert!(GammaDistribution::new(-1.0, 1.0).is_err());
        assert!(GammaDistribution::new(1.0, 0.0).is_err());
        assert!(GammaDistribution::new(1.0, -2.0).is_err());
    }

    #[test]
    fn test_gamma_vague() {
        // shape = rate = 1e-4 gives mean 1 and variance 1e4.
        let v = GammaDistribution::vague();
        assert_approx_eq!(v.mean(), 1.0);
        assert_approx_eq!(v.variance(), 1e4);
    }

    #[test]
    fn test_gaussian() {
        let a = GaussianDistribution::new(1.0, 4.0).unwrap();
        assert_eq!(a.mean, 1.0);
        assert_eq!(a.variance, 4.0);
        assert_approx_eq!(a.mean(), 1.0);
        assert_approx_eq!(a.variance(), 4.0);
        assert_approx_eq!(a.std_dev(), 2.0);

        // A finite negative mean is valid; the variance must be strictly positive.
        assert!(GaussianDistribution::new(-5.0, 0.5).is_ok());
        assert!(GaussianDistribution::new(0.0, 0.0).is_err());
        assert!(GaussianDistribution::new(0.0, -1.0).is_err());
    }

    #[test]
    fn test_bernoulli() {
        let a = BernoulliDistribution::new(0.4).unwrap();
        assert_eq!(a.p, 0.4);
        assert_approx_eq!(a.mean(), 0.4);
        assert_approx_eq!(a.variance(), 0.24);
        assert_approx_eq!(a.std_dev(), 0.24f64.sqrt());

        // The closed interval [0, 1] is accepted, including both endpoints.
        assert!(BernoulliDistribution::new(0.0).is_ok());
        assert!(BernoulliDistribution::new(1.0).is_ok());
        assert!(BernoulliDistribution::new(2.0).is_err());
        assert!(BernoulliDistribution::new(-0.1).is_err());
        assert!(BernoulliDistribution::new(f64::NAN).is_err());
    }
}