use nalgebra::{Cholesky, DVector, DMatrix};
use special::Gamma;
use serde::{Serialize, Deserialize};

use crate::{RealLabels, Features, design_vector, Standardizer, VariationalRegression, get_weights, get_bias};
use crate::error::RegressionError;
use crate::distribution::{GammaDistribution, GaussianDistribution, ScalarDistribution};
use crate::math::LN_2PI;

type DenseVector = DVector<f64>;
type DenseMatrix = DMatrix<f64>;

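/// Configuration options for training a variational linear regression model.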
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTrainConfig {
    /// Prior distribution over the precision of the model weights
    pub weight_precision_prior: GammaDistribution,
    /// Prior distribution over the precision of the observation noise
    pub noise_precision_prior: GammaDistribution,
    /// Whether to include a bias (intercept) term
    pub use_bias: bool,
    /// Whether to standardize the features before training
    pub standardize: bool,
    /// Maximum number of training iterations
    pub max_iter: usize,
    /// Relative change in the lower bound below which training is considered converged
    pub tolerance: f64,
    /// Whether to log the lower bound after each iteration
    pub verbose: bool
}

impl Default for LinearTrainConfig {
    fn default() -> Self {
        LinearTrainConfig {
            weight_precision_prior: GammaDistribution::vague(),
            noise_precision_prior: GammaDistribution::new(1.0, 1e-4).unwrap(),
            use_bias: true,
            standardize: true,
            max_iter: 1000,
            tolerance: 1e-4,
            verbose: true
        }
    }
}

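/// A trained variational linear regression model.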
#[derive(Clone, Serialize, Deserialize)]
pub struct VariationalLinearRegression {
    /// Posterior mean of the model parameters (bias first, if included)
    params: DenseVector,
    /// Posterior covariance of the model parameters
    covariance: DenseMatrix,
    /// Whether the model includes a bias term
    includes_bias: bool,
    /// Feature standardizer fit on the training data, if standardization was used
    standardizer: Option<Standardizer>,
    /// Variational posterior over the noise precision
    pub noise_precision: GammaDistribution,
    /// Final value of the variational lower bound
    pub bound: f64
}

impl VariationalLinearRegression {

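    /// Trains the model on the provided data, iterating the variational updates
    /// until the lower bound converges or `max_iter` is reached.
    ///
    /// # Arguments
    ///
    /// * `features` - the training features, e.g. one `Vec<f64>` of feature values per record
    /// * `labels` - the real-valued labels to predict
    /// * `config` - the training configuration
    ///
    /// Example usage (an illustrative sketch, mirroring the tests below):
    ///
    /// ```ignore
    /// let x = vec![vec![0.1, -0.2], vec![0.3, 0.4]];
    /// let y = vec![0.5, -0.5];
    /// let config = LinearTrainConfig::default();
    /// let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
    /// let prediction = model.predict(&[0.2, 0.1]).unwrap();
    /// ```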
    pub fn train(
        features: impl Features,
        labels: impl RealLabels,
        config: &LinearTrainConfig
    ) -> Result<VariationalLinearRegression, RegressionError> {
        let mut problem = Problem::new(features, labels, config);
        for iter in 0..config.max_iter {
            q_theta(&mut problem)?;
            q_alpha(&mut problem)?;
            q_beta(&mut problem)?;
            let new_bound = lower_bound(&problem)?;
            if config.verbose {
                println!("Iteration {}, Lower Bound = {}", iter + 1, new_bound);
            }
            if (new_bound - problem.bound) / problem.bound.abs() <= config.tolerance {
                return Ok(VariationalLinearRegression {
                    params: problem.theta,
                    covariance: problem.s,
                    includes_bias: config.use_bias,
                    standardizer: problem.standardizer,
                    noise_precision: problem.beta,
                    bound: new_bound
                })
            } else {
                problem.bound = new_bound;
            }
        }
        Err(RegressionError::ConvergenceFailure(config.max_iter))
    }
}

impl VariationalRegression<GaussianDistribution> for VariationalLinearRegression {

    fn predict(&self, features: &[f64]) -> Result<GaussianDistribution, RegressionError> {
        let mut x = design_vector(features, self.includes_bias);
        if let Some(std) = &self.standardizer {
            std.transform_vector(&mut x);
        }
        let npm = self.noise_precision.mean();
        let pred_mean = x.dot(&self.params);
        // predictive variance = expected noise variance + parameter uncertainty
        let pred_var = (1.0 / npm) + (&self.covariance * &x).dot(&x);
        GaussianDistribution::new(pred_mean, pred_var)
    }

    fn weights(&self) -> &[f64] {
        get_weights(self.includes_bias, &self.params)
    }

    fn bias(&self) -> Option<f64> {
        get_bias(self.includes_bias, &self.params)
    }
}

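/// Internal state for a training run: sufficient statistics of the data
/// plus the current variational posteriors and priors.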
struct Problem {
    pub xtx: DenseMatrix,                    // X'X
    pub xty: DenseVector,                    // X'y
    pub yty: f64,                            // y'y
    pub theta: DenseVector,                  // posterior mean of the parameters
    pub s: DenseMatrix,                      // posterior covariance of the parameters
    pub alpha: Vec<GammaDistribution>,       // posteriors over the parameter precisions
    pub beta: GammaDistribution,             // posterior over the noise precision
    pub bpp: Option<GammaDistribution>,      // bias precision prior (if a bias is used)
    pub wpp: GammaDistribution,              // weight precision prior
    pub npp: GammaDistribution,              // noise precision prior
    pub n: usize,                            // number of training records
    pub d: usize,                            // number of parameters (including any bias)
    pub bound: f64,                          // current value of the lower bound
    pub standardizer: Option<Standardizer>   // feature standardizer (if standardization is used)
}

impl Problem {

    fn new(
        features: impl Features,
        labels: impl RealLabels,
        config: &LinearTrainConfig
    ) -> Problem {
        let mut x = features.into_matrix(config.use_bias);
        let standardizer = if config.standardize {
            Some(Standardizer::fit(&x))
        } else {
            None
        };
        if let Some(std) = &standardizer {
            std.transform_matrix(&mut x);
        }
        let n = x.nrows();
        let d = x.ncols();
        let y = labels.into_vector();
        let xtx = x.tr_mul(&x);
        let xty = x.tr_mul(&y);
        let yty = y.dot(&y);
        let bpp = if config.use_bias {
            Some(GammaDistribution::vague())
        } else {
            None
        };
        let wpp = config.weight_precision_prior;
        let npp = config.noise_precision_prior;
        let mut alpha = vec![wpp; x.ncols()];
        if let Some(pp) = bpp {
            alpha[0] = pp;
        }
        let beta = npp;
        let bound = f64::NEG_INFINITY;
        let theta = DenseVector::zeros(d);
        let s = DenseMatrix::zeros(d, d);
        Problem { xtx, xty, yty, theta, s, alpha, beta, bpp, wpp, npp, n, d, bound, standardizer }
    }

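    /// Returns the precision prior for parameter index `ind`; index 0 gets the bias
    /// prior when a bias term is included, all other indices get the weight prior.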
    fn param_precision_prior(&self, ind: usize) -> GammaDistribution {
        match (ind, self.bpp) {
            (0, Some(bpp)) => bpp,
            _ => self.wpp
        }
    }
}

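/// Variational update for q(theta): recomputes the posterior covariance `s`
/// and mean `theta` from the current expected precisions.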
fn q_theta(prob: &mut Problem) -> Result<(), RegressionError> {
    let mut s_inv = &prob.xtx * prob.beta.mean();
    for i in 0..prob.d {
        let a = prob.alpha[i].mean();
        s_inv[(i, i)] += a;
    }
    prob.s = Cholesky::new(s_inv)
        .ok_or(RegressionError::CholeskyFailure)?
        .inverse();
    prob.theta = (&prob.s * prob.beta.mean()) * &prob.xty;
    Ok(())
}

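/// Variational update for q(alpha): refreshes the Gamma posterior over each
/// parameter's precision using the current estimate of that parameter.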
fn q_alpha(prob: &mut Problem) -> Result<(), RegressionError> {
    for i in 0..prob.d {
        let pp = prob.param_precision_prior(i);
        let inv_scale = pp.rate + 0.5 * (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]);
        prob.alpha[i] = GammaDistribution::new(pp.shape + 0.5, inv_scale)?;
    }
    Ok(())
}

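/// Variational update for q(beta): refreshes the Gamma posterior over the
/// noise precision from the expected residual sum of squares.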
fn q_beta(prob: &mut Problem) -> Result<(), RegressionError> {
    let shape = prob.npp.shape + (prob.n as f64 / 2.0);
    let t = (&prob.xtx * (&prob.theta * prob.theta.transpose() + &prob.s)).trace();
    let inv_scale = prob.npp.rate + 0.5 * (prob.yty - 2.0 * prob.theta.dot(&prob.xty) + t);
    prob.beta = GammaDistribution::new(shape, inv_scale)?;
    Ok(())
}

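/// Computes the variational lower bound (ELBO) for the current posteriors.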
fn lower_bound(prob: &Problem) -> Result<f64, RegressionError> {
    Ok(expect_ln_p_y(prob)? +
        expect_ln_p_theta(prob)? +
        expect_ln_p_alpha(prob)? +
        expect_ln_p_beta(prob)? -
        expect_ln_q_theta(prob)? -
        expect_ln_q_alpha(prob)? -
        expect_ln_q_beta(prob)?)
}

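/// Expected log likelihood of the labels under the variational posteriors.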
fn expect_ln_p_y(prob: &Problem) -> Result<f64, RegressionError> {
    let bm = prob.beta.mean();
    let tc = &prob.theta * prob.theta.transpose();
    let part1 = prob.xty.len() as f64 * 0.5;
    let part2 = Gamma::digamma(prob.beta.shape) - prob.beta.rate.ln() - LN_2PI;
    let part3 = (bm * 0.5) * prob.yty;
    let part4 = bm * prob.theta.dot(&prob.xty);
    let part5 = (bm * 0.5) * (&prob.xtx * (tc + &prob.s)).trace();
    Ok(part1 * part2 - part3 + part4 - part5)
}

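/// Expected log prior over the model parameters, E[ln p(theta | alpha)].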
fn expect_ln_p_theta(prob: &Problem) -> Result<f64, RegressionError> {
    let init = (prob.theta.len() as f64 * -0.5) * LN_2PI;
    prob.alpha.iter().enumerate().try_fold(init, |sum, (i, a)| {
        let am = a.mean();
        let part1 = Gamma::digamma(a.shape) - a.rate.ln();
        let part2 = (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]) * am;
        Ok(sum + 0.5 * (part1 - part2))
    })
}

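/// Expected log prior over the parameter precisions, E[ln p(alpha)].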
fn expect_ln_p_alpha(prob: &Problem) -> Result<f64, RegressionError> {
    prob.alpha.iter().enumerate().try_fold(0.0, |sum, (i, a)| {
        let am = a.mean();
        let pp = prob.param_precision_prior(i);
        let term1 = pp.shape * pp.rate.ln();
        let term2 = (pp.shape - 1.0) * (Gamma::digamma(a.shape) - a.rate.ln());
        let term3 = (pp.rate * am) + Gamma::ln_gamma(pp.shape).0;
        Ok(sum + term1 + term2 - term3)
    })
}

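/// Expected log prior over the noise precision, E[ln p(beta)].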
fn expect_ln_p_beta(prob: &Problem) -> Result<f64, RegressionError> {
    let part1 = prob.npp.shape * prob.npp.rate.ln();
    let part2 = (prob.npp.shape - 1.0) * (Gamma::digamma(prob.beta.shape) - prob.beta.rate.ln());
    let part3 = (prob.npp.rate * prob.beta.mean()) + Gamma::ln_gamma(prob.npp.shape).0;
    Ok(part1 + part2 - part3)
}

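/// Expected log of the variational density over the parameters, E[ln q(theta)]
/// (the negative entropy of the Gaussian posterior).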
fn expect_ln_q_theta(prob: &Problem) -> Result<f64, RegressionError> {
    let m = prob.s.shape().0;
    let chol = Cholesky::new(prob.s.clone())
        .ok_or(RegressionError::CholeskyFailure)?
        .l();
    let mut ln_det = 0.0;
    for i in 0..prob.s.ncols() {
        ln_det += chol[(i, i)].ln();
    }
    ln_det *= 2.0;
    Ok(-(0.5 * ln_det + (m as f64 / 2.0) * (1.0 + LN_2PI)))
}

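/// Expected log of the variational density over the parameter precisions, E[ln q(alpha)].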
fn expect_ln_q_alpha(prob: &Problem) -> Result<f64, RegressionError> {
    prob.alpha.iter().try_fold(0.0, |sum, a| {
        let part1 = Gamma::ln_gamma(a.shape).0;
        let part2 = (a.shape - 1.0) * Gamma::digamma(a.shape);
        let part3 = a.shape - a.rate.ln();
        Ok(sum - (part1 - part2 + part3))
    })
}

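/// Expected log of the variational density over the noise precision, E[ln q(beta)].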
fn expect_ln_q_beta(prob: &Problem) -> Result<f64, RegressionError> {
    Ok(-(Gamma::ln_gamma(prob.beta.shape).0 -
        (prob.beta.shape - 1.0) * Gamma::digamma(prob.beta.shape) -
        prob.beta.rate.ln() +
        prob.beta.shape))
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    const FEATURES: [[f64; 4]; 10] = [
        [-0.2, -0.9, -0.5, 0.3],
        [0.6, 0.3, 0.3, -0.4],
        [0.9, -0.4, -0.5, -0.6],
        [-0.7, 0.8, 0.3, -0.3],
        [-0.5, -0.7, -0.1, 0.8],
        [0.5, 0.5, 0.0, 0.1],
        [0.1, -0.0, 0.0, -0.2],
        [0.4, 0.0, 0.2, 0.0],
        [-0.2, 0.9, -0.1, -0.9],
        [0.1, 0.4, -0.5, 0.9],
    ];

    const LABELS: [f64; 10] = [
        -0.4, 0.1, -0.8, 0.5, 0.6, -0.2, 0.0, 0.7, -0.3, 0.2
    ];

    #[test]
    fn test_train_with_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.bias().unwrap(), 0.009795973392064526);
        assert_approx_eq!(model.weights()[0], -0.053736076572620695);
        assert_approx_eq!(model.weights()[1], 0.002348926942734912);
        assert_approx_eq!(model.weights()[2], 0.36479166380848826);
        assert_approx_eq!(model.weights()[3], 0.2995772527448547);
    }

    #[test]
    fn test_train_with_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            standardize: false,
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.bias().unwrap(), 0.14022283613177447);
        assert_approx_eq!(model.weights()[0], -0.08826080780896867);
        assert_approx_eq!(model.weights()[1], 0.003684347234472394);
        assert_approx_eq!(model.weights()[2], 1.1209335465339734);
        assert_approx_eq!(model.weights()[3], 0.5137103057008632);
    }

    #[test]
    fn test_train_no_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            use_bias: false,
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.weights()[0], -0.0536007042304908);
        assert_approx_eq!(model.weights()[1], 0.0024537840396777044);
        assert_approx_eq!(model.weights()[2], 0.3649008472250164);
        assert_approx_eq!(model.weights()[3], 0.2997887456881104);
    }

    #[test]
    fn test_train_no_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            use_bias: false,
            standardize: false,
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.weights()[0], -0.0362564312306051);
        assert_approx_eq!(model.weights()[1], 0.021598779423334057);
        assert_approx_eq!(model.weights()[2], 0.9458928058270641);
        assert_approx_eq!(model.weights()[3], 0.4751696529319309);
    }

    #[test]
    fn test_predict_with_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap();
        assert_approx_eq!(p.mean(), -0.1601830957057508);
        assert_approx_eq!(p.variance(), 0.0421041223659715);
    }

    #[test]
    fn test_predict_with_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            standardize: false,
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap();
        assert_approx_eq!(p.mean(), -0.1495143747869945);
        assert_approx_eq!(p.variance(), 0.047374206616233275);
    }

    #[test]
    fn test_predict_no_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            use_bias: false,
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap();
        assert_approx_eq!(p.mean(), -0.16990565682335487);
        assert_approx_eq!(p.variance(), 0.0409272332865222);
    }

    #[test]
    fn test_predict_no_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            use_bias: false,
            standardize: false,
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        let model = VariationalLinearRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap();
        assert_approx_eq!(p.mean(), -0.2307380822928);
        assert_approx_eq!(p.variance(), 0.07177809358927849);
    }
}