variational_regression/
linear.rs

1use nalgebra::{Cholesky, DVector, DMatrix};
2use special::Gamma;
3use serde::{Serialize, Deserialize};
4
5use crate::{RealLabels, Features, design_vector, Standardizer, VariationalRegression, get_weights, get_bias};
6use crate::error::RegressionError;
7use crate::distribution::{GammaDistribution, GaussianDistribution, ScalarDistribution};
8use crate::math::LN_2PI;
9
/// Column vector of `f64` values whose length is chosen at runtime
type DenseVector = DVector<f64>;
/// Matrix of `f64` values whose dimensions are chosen at runtime
type DenseMatrix = DMatrix<f64>;
12
///
/// Specifies configurable options for training a 
/// variational linear regression model
/// 
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearTrainConfig {
    /// Prior distribution over the precision of the model weights
    /// (used for every parameter other than the bias term)
    pub weight_precision_prior: GammaDistribution,
    /// Prior distribution over the precision of the noise term
    pub noise_precision_prior: GammaDistribution,
    /// Whether or not to include a bias term
    pub use_bias: bool,
    /// Whether or not to standardize the features
    /// (the fitted standardizer is kept and re-applied at prediction time)
    pub standardize: bool,
    /// Maximum number of training iterations; training fails with a
    /// convergence error if this many iterations pass without converging
    pub max_iter: usize,
    /// Convergence criteria threshold: training stops once the relative
    /// change in the variational lower bound is at or below this value
    pub tolerance: f64,
    /// Indicates whether or not to print training info
    /// (the lower bound is printed to stdout after each iteration)
    pub verbose: bool
}
34
35impl Default for LinearTrainConfig {
36    fn default() -> Self {
37        LinearTrainConfig {
38            weight_precision_prior: GammaDistribution::vague(),
39            noise_precision_prior: GammaDistribution::new(1.0, 1e-4).unwrap(),
40            use_bias: true,
41            standardize: true,
42            max_iter: 1000, 
43            tolerance: 1e-4,
44            verbose: true
45        }
46    }
47}
48
///
/// Represents a linear regression model trained via variational inference
/// 
#[derive(Clone, Serialize, Deserialize)]
pub struct VariationalLinearRegression {
    /// Learned model weights (mean of the variational posterior over the
    /// parameters; expressed in standardized feature space when a
    /// standardizer is present)
    params: DenseVector,
    /// Covariance matrix of the variational posterior over the parameters
    covariance: DenseMatrix,
    /// Whether the model was trained with a bias term or not
    includes_bias: bool,
    /// Optional feature standardizer, applied to inputs before prediction
    standardizer: Option<Standardizer>,
    /// Noise precision distribution as it stood when training converged
    pub noise_precision: GammaDistribution,
    /// Variational lower bound computed on the final training iteration
    pub bound: f64
}
67
68impl VariationalLinearRegression {
69
70    ///
71    /// Trains the model on the provided data
72    /// 
73    /// # Arguments
74    /// 
75    /// `features` - The feature values (in row-major orientation)
76    /// `labels` - The vector of corresponding labels
77    /// `config` - The training configuration
78    /// 
79    pub fn train(
80        features: impl Features,
81        labels: impl RealLabels,
82        config: &LinearTrainConfig
83    ) -> Result<VariationalLinearRegression, RegressionError> {
84        // precompute required values
85        let mut problem = Problem::new(features, labels, config);
86        // optimize the variational lower bound until convergence
87        for iter in 0..config.max_iter {
88            q_theta(&mut problem)?; // model parameters
89            q_alpha(&mut problem)?; // weight precisions
90            q_beta(&mut problem)?; // noise precision
91            let new_bound = lower_bound(&problem)?;
92            if config.verbose {
93                println!("Iteration {}, Lower Bound = {}", iter + 1, new_bound);
94            }
95            if (new_bound - problem.bound) / problem.bound.abs() <= config.tolerance {
96                return Ok(VariationalLinearRegression {
97                    params: problem.theta, 
98                    covariance: problem.s, 
99                    includes_bias: config.use_bias,
100                    standardizer: problem.standardizer,
101                    noise_precision: problem.beta, 
102                    bound: new_bound
103                })
104            } else {
105                problem.bound = new_bound;
106            }
107        }
108        // admit defeat
109        Err(RegressionError::ConvergenceFailure(config.max_iter))
110    }
111}
112
113impl VariationalRegression<GaussianDistribution> for VariationalLinearRegression {
114
115    fn predict(&self, features: &[f64]) -> Result<GaussianDistribution, RegressionError> {
116        let mut x = design_vector(features, self.includes_bias);
117        if let Some(std) = &self.standardizer {
118            std.transform_vector(&mut x);
119        }
120        let npm = self.noise_precision.mean();
121        let pred_mean = x.dot(&self.params);
122        let pred_var = (1.0 / npm) + (&self.covariance * &x).dot(&x);
123        GaussianDistribution::new(pred_mean, pred_var)
124    }
125
126    fn weights(&self) -> &[f64] {
127        get_weights(self.includes_bias, &self.params)
128    }
129
130    fn bias(&self) -> Option<f64> {
131        get_bias(self.includes_bias, &self.params)
132    }
133}
134
// Defines the regression problem: statistics computed once from the
// training data (xtx, xty, yty, n, d), the priors, and the variational
// state (theta, s, alpha, beta, bound) that is overwritten during training
struct Problem {
    pub xtx: DenseMatrix, // t(x) * x
    pub xty: DenseVector, // t(x) * y (one entry per parameter)
    pub yty: f64, // t(y) * y
    pub theta: DenseVector, // parameters (bias & weights), starts at zero
    pub s: DenseMatrix, // covariance of the parameters, starts at zero
    pub alpha: Vec<GammaDistribution>, // parameter precisions (one per parameter)
    pub beta: GammaDistribution, // noise precision
    pub bpp: Option<GammaDistribution>, // bias prior precision (None without a bias term)
    pub wpp: GammaDistribution, // weight prior precision
    pub npp: GammaDistribution, // noise prior precision
    pub n: usize, // number of training examples
    pub d: usize, // feature dimensionality (including bias)
    pub bound: f64, // variational lower bound from the previous iteration
    pub standardizer: Option<Standardizer> // feature standardizer
}
152
153impl Problem {
154
155    fn new(
156        features: impl Features,
157        labels: impl RealLabels,
158        config: &LinearTrainConfig
159    ) -> Problem {
160        let mut x = features.into_matrix(config.use_bias);
161        let standardizer = if config.standardize {
162            Some(Standardizer::fit(&x))
163        } else {
164            None
165        };
166        if let Some(std) = &standardizer {
167            std.transform_matrix(&mut x);
168        }
169        let n = x.nrows();
170        let d = x.ncols();
171        let y = labels.into_vector();
172        let xtx = x.tr_mul(&x);
173        let xty = x.tr_mul(&y);
174        let yty = y.dot(&y);
175        let bpp = if config.use_bias {
176            Some(GammaDistribution::vague())
177        } else {
178            None
179        };
180        let wpp = config.weight_precision_prior;
181        let npp = config.noise_precision_prior;
182        let mut alpha = vec![wpp; x.ncols()];
183        if let Some(pp) = bpp {
184            alpha[0] = pp;
185        }
186        let beta = npp;
187        let bound = f64::NEG_INFINITY;
188        let theta = DenseVector::zeros(d);
189        let s = DenseMatrix::zeros(d, d);
190        Problem { xtx, xty, yty, theta, s, alpha, beta, bpp, wpp, npp, n, d, bound, standardizer }
191    }
192
193    fn param_precision_prior(&self, ind: usize) -> GammaDistribution {
194        match (ind, self.bpp) {
195            (0, Some(bpp)) => bpp,
196            _ => self.wpp
197        }
198    }
199}
200
201// Factorized distribution for parameter values
202fn q_theta(prob: &mut Problem) -> Result<(), RegressionError> {
203    let mut s_inv = &prob.xtx * prob.beta.mean();
204    for i in 0..prob.d {
205        let a = prob.alpha[i].mean();
206        s_inv[(i, i)] += a;
207    }
208    prob.s = Cholesky::new(s_inv)
209        .ok_or(RegressionError::CholeskyFailure)?
210        .inverse();
211    prob.theta = (&prob.s * prob.beta.mean()) * &prob.xty;
212    Ok(())
213}
214
215// Factorized distribution for parameter precisions
216fn q_alpha(prob: &mut Problem) -> Result<(), RegressionError> {
217    for i in 0..prob.d {
218        let pp = prob.param_precision_prior(i);
219        let inv_scale = pp.rate + 0.5 * (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]);
220        prob.alpha[i] = GammaDistribution::new(pp.shape + 0.5, inv_scale)?;
221    }
222    Ok(())
223}
224
225// Factorized distribution for noise precision
226fn q_beta(prob: &mut Problem) -> Result<(), RegressionError> {
227    let shape = prob.npp.shape + (prob.n as f64 / 2.0);
228    let t = (&prob.xtx * (&prob.theta * prob.theta.transpose() + &prob.s)).trace();
229    let inv_scale = prob.npp.rate + 0.5 * (prob.yty - 2.0 * prob.theta.dot(&prob.xty) + t);
230    prob.beta = GammaDistribution::new(shape, inv_scale)?;
231    Ok(())
232}
233
234// Variational lower bound given current model parameters
235fn lower_bound(prob: &Problem) -> Result<f64, RegressionError> {
236    Ok(expect_ln_p_y(prob)? +
237    expect_ln_p_theta(prob)? +
238    expect_ln_p_alpha(prob)? +
239    expect_ln_p_beta(prob)? -
240    expect_ln_q_theta(prob)? -
241    expect_ln_q_alpha(prob)? -
242    expect_ln_q_beta(prob)?)
243}
244
245// E[ln p(y|theta)]
246fn expect_ln_p_y(prob: &Problem) -> Result<f64, RegressionError> {
247    let bm = prob.beta.mean();
248    let tc = &prob.theta * prob.theta.transpose();
249    let part1 = prob.xty.len() as f64 * 0.5;
250    let part2 = Gamma::digamma(prob.beta.shape) - prob.beta.rate.ln() - LN_2PI;
251    let part3 = (bm * 0.5) * prob.yty;
252    let part4 = bm * prob.theta.dot(&prob.xty);
253    let part5 = (bm * 0.5) * (&prob.xtx * (tc + &prob.s)).trace();
254    Ok(part1 * part2 - part3 + part4 - part5)
255}
256
257// E[ln p(theta|alpha)]
258fn expect_ln_p_theta(prob: &Problem) -> Result<f64, RegressionError> {
259    let init = (prob.theta.len() as f64 * -0.5) * LN_2PI;
260    prob.alpha.iter().enumerate().try_fold(init, |sum, (i, a)| {
261        let am = a.mean();
262        let part1 = Gamma::digamma(a.shape) - a.rate.ln();
263        let part2 = (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]) * am;
264        Ok(sum + 0.5 * (part1 - part2))
265    })
266}
267
268// E[ln p(alpha)]
269fn expect_ln_p_alpha(prob: &Problem) -> Result<f64, RegressionError> {
270    prob.alpha.iter().enumerate().try_fold(0.0, |sum, (i, a)| {
271        let am = a.mean();
272        let pp = prob.param_precision_prior(i);
273        let term1 = pp.shape * pp.rate.ln();
274        let term2 = (pp.shape - 1.0) * (Gamma::digamma(a.shape) - a.rate.ln());
275        let term3 = (pp.rate * am) + Gamma::ln_gamma(pp.shape).0;
276        Ok(sum + term1 + term2 - term3)
277    })
278}
279
280// E[ln p(beta)]
281fn expect_ln_p_beta(prob: &Problem) -> Result<f64, RegressionError> {
282    let part1 = prob.npp.shape * prob.npp.rate.ln();
283    let part2 = (prob.npp.shape - 1.0) * (Gamma::digamma(prob.beta.shape) - prob.beta.rate.ln());
284    let part3 = (prob.npp.rate * prob.beta.mean()) + Gamma::ln_gamma(prob.npp.shape).0;
285    Ok(part1 + part2 - part3)
286}
287
288// E[ln q(theta)]
289fn expect_ln_q_theta(prob: &Problem) -> Result<f64, RegressionError> {
290    let m = prob.s.shape().0;
291    let chol = Cholesky::new(prob.s.clone())
292    .ok_or(RegressionError::CholeskyFailure)?
293    .l();
294    let mut ln_det = 0.0;
295    for i in 0..prob.s.ncols() {
296        ln_det += chol[(i, i)].ln();
297    }
298    ln_det *= 2.0;
299    Ok(-(0.5 * ln_det + (m as f64 / 2.0) * (1.0 + LN_2PI)))
300}
301
302// E[ln q(alpha)]
303fn expect_ln_q_alpha(prob: &Problem) -> Result<f64, RegressionError> {
304    prob.alpha.iter().try_fold(0.0, |sum, a| {
305        let part1 = Gamma::ln_gamma(a.shape).0;
306        let part2 = (a.shape - 1.0) * Gamma::digamma(a.shape);
307        let part3 = a.shape - a.rate.ln();
308        Ok(sum - (part1 - part2 + part3))
309    })
310}
311
312// E[ln q(beta)]
313fn expect_ln_q_beta(prob: &Problem) -> Result<f64, RegressionError> {
314    Ok(-(Gamma::ln_gamma(prob.beta.shape).0 - 
315    (prob.beta.shape - 1.0) * Gamma::digamma(prob.beta.shape) - 
316    prob.beta.rate.ln() + 
317    prob.beta.shape))
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    const FEATURES: [[f64; 4]; 10] = [
        [-0.2, -0.9, -0.5, 0.3],
        [0.6, 0.3, 0.3, -0.4],
        [0.9, -0.4, -0.5, -0.6],
        [-0.7, 0.8, 0.3, -0.3],
        [-0.5, -0.7, -0.1, 0.8],
        [0.5, 0.5, 0.0, 0.1],
        [0.1, -0.0, 0.0, -0.2],
        [0.4, 0.0, 0.2, 0.0],
        [-0.2, 0.9, -0.1, -0.9],
        [0.1, 0.4, -0.5, 0.9],
    ];

    const LABELS: [f64; 10] = [
        -0.4, 0.1, -0.8, 0.5, 0.6, -0.2, 0.0, 0.7, -0.3, 0.2
    ];

    // Feature vector shared by every prediction test
    const QUERY: [f64; 4] = [0.3, 0.8, -0.1, -0.3];

    // Trains a model on the shared fixture with the requested options
    fn fit(use_bias: bool, standardize: bool) -> VariationalLinearRegression {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LinearTrainConfig {
            use_bias,
            standardize,
            noise_precision_prior: GammaDistribution { shape: 1.0001, rate: 1e-4 },
            ..Default::default()
        };
        VariationalLinearRegression::train(&x, &y, &config).unwrap()
    }

    // Checks the first four learned weights against the expected values
    fn assert_weights(model: &VariationalLinearRegression, expected: [f64; 4]) {
        for (i, wanted) in expected.iter().enumerate() {
            assert_approx_eq!(model.weights()[i], *wanted);
        }
    }

    // Checks the predictive mean and variance for the shared query vector
    fn assert_prediction(model: &VariationalLinearRegression, mean: f64, variance: f64) {
        let p = model.predict(&QUERY).unwrap();
        assert_approx_eq!(p.mean(), mean);
        assert_approx_eq!(p.variance(), variance);
    }

    #[test]
    fn test_train_with_bias_with_standardize() {
        let model = fit(true, true);
        assert_approx_eq!(model.bias().unwrap(), 0.009795973392064526);
        assert_weights(&model, [
            -0.053736076572620695,
            0.002348926942734912,
            0.36479166380848826,
            0.2995772527448547,
        ]);
    }

    #[test]
    fn test_train_with_bias_no_standardize() {
        let model = fit(true, false);
        assert_approx_eq!(model.bias().unwrap(), 0.14022283613177447);
        assert_weights(&model, [
            -0.08826080780896867,
            0.003684347234472394,
            1.1209335465339734,
            0.5137103057008632,
        ]);
    }

    #[test]
    fn test_train_no_bias_with_standardize() {
        let model = fit(false, true);
        assert_weights(&model, [
            -0.0536007042304908,
            0.0024537840396777044,
            0.3649008472250164,
            0.2997887456881104,
        ]);
    }

    #[test]
    fn test_train_no_bias_no_standardize() {
        let model = fit(false, false);
        assert_weights(&model, [
            -0.0362564312306051,
            0.021598779423334057,
            0.9458928058270641,
            0.4751696529319309,
        ]);
    }

    #[test]
    fn test_predict_with_bias_with_standardize() {
        let model = fit(true, true);
        assert_prediction(&model, -0.1601830957057508, 0.0421041223659715);
    }

    #[test]
    fn test_predict_with_bias_no_standardize() {
        let model = fit(true, false);
        assert_prediction(&model, -0.1495143747869945, 0.047374206616233275);
    }

    #[test]
    fn test_predict_no_bias_with_standardize() {
        let model = fit(false, true);
        assert_prediction(&model, -0.16990565682335487, 0.0409272332865222);
    }

    #[test]
    fn test_predict_no_bias_no_standardize() {
        let model = fit(false, false);
        assert_prediction(&model, -0.2307380822928, 0.07177809358927849);
    }
}
465        assert_approx_eq!(p.variance(), 0.07177809358927849);
466    }
467}