variational_regression/logistic.rs

use std::f64::consts::PI;
use nalgebra::{Cholesky, DVector, DMatrix};
use special::Gamma;
use serde::{Serialize, Deserialize};

use crate::{BinaryLabels, Features, design_vector, Standardizer, VariationalRegression, get_weights, get_bias};
use crate::distribution::{GammaDistribution, BernoulliDistribution, ScalarDistribution};
use crate::error::RegressionError;
use crate::math::{LN_2PI, logistic, trace_of_product, scale_rows};

type DenseVector = DVector<f64>;
type DenseMatrix = DMatrix<f64>;


///
/// Specifies configurable options for training a
/// variational logistic regression model
///
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogisticTrainConfig {
    /// Prior distribution over the precision of the model weights
    pub weight_precision_prior: GammaDistribution,
    /// Whether or not to include a bias term
    pub use_bias: bool,
    /// Whether or not to standardize the features
    pub standardize: bool,
    /// Maximum number of training iterations
    pub max_iter: usize,
    /// Convergence criteria threshold
    pub tolerance: f64,
    /// Indicates whether or not to print training info
    pub verbose: bool
}

impl Default for LogisticTrainConfig {
    fn default() -> Self {
        LogisticTrainConfig {
            weight_precision_prior: GammaDistribution::vague(),
            use_bias: true,
            standardize: true,
            max_iter: 1000,
            tolerance: 1e-4,
            verbose: true
        }
    }
}

///
/// Represents a logistic regression model trained via variational inference
///
#[derive(Clone, Serialize, Deserialize)]
pub struct VariationalLogisticRegression {
    /// Learned model parameters
    params: DenseVector,
    /// Covariance matrix
    covariance: DenseMatrix,
    /// Whether the model was trained with a bias term or not
    includes_bias: bool,
    /// Optional feature standardizer
    standardizer: Option<Standardizer>,
    /// Variational lower bound
    pub bound: f64
}

impl VariationalLogisticRegression {

    ///
    /// Trains the model on the provided data
    ///
    /// # Arguments
    ///
    /// `features` - The feature values (in row-major orientation)
    /// `labels` - The vector of corresponding labels
    /// `config` - The training configuration
    ///
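    /// # Example
    ///
    /// A minimal usage sketch, mirroring the unit tests at the bottom of this file.
    /// The glob import assumes these types are re-exported from the crate root.
    ///
    /// ```no_run
    /// use variational_regression::*;
    ///
    /// // a couple of illustrative rows; real training data would be larger
    /// let features: Vec<Vec<f64>> = vec![
    ///     vec![-0.2, -0.9, -0.5, 0.3],
    ///     vec![0.6, 0.3, 0.3, -0.4],
    /// ];
    /// let labels: Vec<bool> = vec![true, false];
    /// let config = LogisticTrainConfig::default();
    /// let model = VariationalLogisticRegression::train(&features, &labels, &config)
    ///     .expect("training failed");
    /// let weights: &[f64] = model.weights();
    /// let bias: Option<f64> = model.bias();
    /// ```
    ///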
    pub fn train(
        features: impl Features,
        labels: impl BinaryLabels,
        config: &LogisticTrainConfig
    ) -> Result<VariationalLogisticRegression, RegressionError> {
        // precompute required values
        let mut problem = Problem::new(features, labels, config);
        // optimize the variational lower bound until convergence
        for iter in 0..config.max_iter {
            q_theta(&mut problem)?; // model parameters
            q_alpha(&mut problem)?; // weight precisions
            update_zeta(&mut problem)?; // local variational parameters
            let new_bound = lower_bound(&problem)?;
            if config.verbose {
                println!("Iteration {}, Lower Bound = {}", iter + 1, new_bound);
            }
            if (new_bound - problem.bound) / problem.bound.abs() <= config.tolerance {
                return Ok(VariationalLogisticRegression {
                    params: problem.theta,
                    covariance: problem.s,
                    includes_bias: config.use_bias,
                    standardizer: problem.standardizer,
                    bound: new_bound
                })
            } else {
                problem.bound = new_bound;
            }
        }
        // admit defeat
        Err(RegressionError::ConvergenceFailure(config.max_iter))
    }
}

impl VariationalRegression<BernoulliDistribution> for VariationalLogisticRegression {

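    // Predictive mean: with q(theta) = N(theta | m, S), E[logistic(x^T theta)] is
    // approximated by logistic(k * mu), where mu = x^T m, s = x^T S x, and
    // k = 1 / sqrt(1 + pi * s / 8), so the point prediction is moderated by the
    // posterior variance along x.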
    fn predict(&self, features: &[f64]) -> Result<BernoulliDistribution, RegressionError> {
        let mut x = design_vector(features, self.includes_bias);
        if let Some(std) = &self.standardizer {
            std.transform_vector(&mut x);
        }
        let mu = x.dot(&self.params);
        let s = (&self.covariance * &x).dot(&x);
        let k = 1.0 / (1.0 + (PI * s) / 8.0).sqrt();
        let p = logistic(k * mu);
        BernoulliDistribution::new(p)
    }

    fn weights(&self) -> &[f64] {
        get_weights(self.includes_bias, &self.params)
    }

    fn bias(&self) -> Option<f64> {
        get_bias(self.includes_bias, &self.params)
    }
}

// Defines the regression problem
struct Problem {
    pub x: DenseMatrix, // features
    pub y: DenseVector, // labels
    pub xtr: DenseVector, // t(x) * (y - 0.5)
    pub theta: DenseVector, // parameters (bias & weights)
    pub s: DenseMatrix, // covariance
    pub alpha: Vec<GammaDistribution>, // parameter precisions
    pub zeta: DenseVector, // variational parameters
    pub bpp: Option<GammaDistribution>, // bias prior precision
    pub wpp: GammaDistribution, // weight prior precision
    pub n: usize, // number of training examples
    pub d: usize, // feature dimensionality (including bias)
    pub bound: f64, // variational lower bound
    pub standardizer: Option<Standardizer> // feature standardizer
}

impl Problem {

    fn new(
        features: impl Features,
        labels: impl BinaryLabels,
        config: &LogisticTrainConfig
    ) -> Problem {
        let mut x = features.into_matrix(config.use_bias);
        let standardizer = if config.standardize {
            Some(Standardizer::fit(&x))
        } else {
            None
        };
        if let Some(std) = &standardizer {
            std.transform_matrix(&mut x);
        }
        let n = x.nrows();
        let d = x.ncols();
        let y = labels.into_vector();
        let xtr = x.tr_mul(&y.map(|v| v - 0.5));
        let wpp = config.weight_precision_prior;
        let bpp = if config.use_bias {
            Some(GammaDistribution::vague())
        } else {
            None
        };
        let mut alpha = vec![wpp; x.ncols()];
        if let Some(pp) = bpp {
            alpha[0] = pp;
        }
        let zeta = DenseVector::from_element(n, 1.0);
        let bound = f64::NEG_INFINITY;
        let theta = DenseVector::zeros(d);
        let s = DenseMatrix::zeros(d, d);
        Problem { x, y, xtr, theta, s, alpha, zeta, bpp, wpp, n, d, bound, standardizer }
    }

    fn param_precision_prior(&self, ind: usize) -> GammaDistribution {
        match (ind, self.bpp) {
            (0, Some(bpp)) => bpp,
            _ => self.wpp
        }
    }
}

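// Jaakkola-Jordan bound coefficient lambda(zeta) = (logistic(zeta) - 1/2) / (2 * zeta),
// used to construct a quadratic lower bound on the logistic log-likelihood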
fn lambda(val: f64) -> f64 {
    (1.0 / (2.0 * val)) * (logistic(val) - 0.5)
}

// Factorized distribution for parameter values
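// q(theta) = N(theta | m, S), where S^-1 = diag(E[alpha]) + 2 * X^T diag(lambda(zeta)) X
// and m = S * X^T (y - 1/2)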
fn q_theta(prob: &mut Problem) -> Result<(), RegressionError> {
    let a = DenseVector::from(prob.alpha.iter().map(|alpha| alpha.mean()).collect::<Vec<f64>>());
    let mut s_inv = DenseMatrix::from_diagonal(&a);
    let lambdas = prob.zeta.map(|z| lambda(z));
    s_inv += prob.x.tr_mul(&scale_rows(&prob.x, &lambdas)) * 2.0;
    prob.s = Cholesky::new(s_inv)
        .ok_or(RegressionError::CholeskyFailure)?
        .inverse();
    prob.theta = &prob.s * &prob.xtr;
    Ok(())
}

// Factorized distribution for parameter precisions
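// q(alpha_i) = Gamma(shape_0 + 1/2, rate_0 + (theta_i^2 + S_ii) / 2), where
// (shape_0, rate_0) come from the corresponding prior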
fn q_alpha(prob: &mut Problem) -> Result<(), RegressionError> {
    for i in 0..prob.d {
        let pp = prob.param_precision_prior(i);
        let inv_scale = pp.rate + 0.5 * (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]);
        prob.alpha[i] = GammaDistribution::new(pp.shape + 0.5, inv_scale)?;
    }
    Ok(())
}

// Update zeta values
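// Optimal local variational parameters: zeta_i = sqrt(x_i^T (S + theta * theta^T) x_i)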
fn update_zeta(prob: &mut Problem) -> Result<(), RegressionError> {
    let a = &prob.s + (&prob.theta * prob.theta.transpose());
    let iter = prob.x.row_iter().map(|xi| {
        (&xi * &a).dot(&xi).sqrt()
    });
    prob.zeta = DenseVector::from_iterator(prob.n, iter);
    Ok(())
}

// Variational lower bound given current model parameters
fn lower_bound(prob: &Problem) -> Result<f64, RegressionError> {
    Ok(expect_ln_p_y(prob)? +
    expect_ln_p_theta(prob)? +
    expect_ln_p_alpha(prob)? -
    expect_ln_q_theta(prob)? -
    expect_ln_q_alpha(prob)?)
}

// E[ln p(y|theta)]
fn expect_ln_p_y(prob: &Problem) -> Result<f64, RegressionError> {
    let part1 = prob.zeta.map(lambda);
    let part2 = prob.zeta.map(|z| logistic(z).ln()).sum();
    let part3 = &prob.x * &prob.theta;
    let part4 = (&part3.transpose() * prob.y.map(|y| y - 0.5)).sum();
    let part5 = prob.zeta.sum() / 2.0;
    let part6 = (part3.map(|v| v * v).transpose() * &part1).sum();
    let part7 = trace_of_product(&scale_rows(&(&prob.x * &prob.s), &part1), &prob.x.transpose());
    let part8 = part1.component_mul(&prob.zeta.map(|z| z * z)).sum();
    Ok(part2 + part4 - part5 - part6 - part7 + part8)
}

// E[ln p(theta|alpha)]
fn expect_ln_p_theta(prob: &Problem) -> Result<f64, RegressionError> {
    let init = (prob.theta.len() as f64 * -0.5) * LN_2PI;
    prob.alpha.iter().enumerate().try_fold(init, |sum, (i, a)| {
        let am = a.mean();
        let part1 = Gamma::digamma(a.shape) - a.rate.ln();
        let part2 = (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]) * am;
        Ok(sum + 0.5 * (part1 - part2))
    })
}

// E[ln p(alpha)]
fn expect_ln_p_alpha(prob: &Problem) -> Result<f64, RegressionError> {
    prob.alpha.iter().enumerate().try_fold(0.0, |sum, (i, a)| {
        let am = a.mean();
        let pp = prob.param_precision_prior(i);
        let term1 = pp.shape * pp.rate.ln();
        let term2 = (pp.shape - 1.0) * (Gamma::digamma(a.shape) - a.rate.ln());
        let term3 = (pp.rate * am) + Gamma::ln_gamma(prob.wpp.shape).0;
        Ok(sum + term1 + term2 - term3)
    })
}

// E[ln q(theta)]
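// Negative entropy of q(theta) = N(m, S): -((1/2) * ln|S| + (d/2) * (1 + ln(2 * pi))),
// with ln|S| accumulated from the diagonal of the Cholesky factor of S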
fn expect_ln_q_theta(prob: &Problem) -> Result<f64, RegressionError> {
    let m = prob.s.shape().0;
    let chol = Cholesky::new(prob.s.clone())
        .ok_or(RegressionError::CholeskyFailure)?
        .l();
    let mut ln_det = 0.0;
    for i in 0..prob.s.ncols() {
        ln_det += chol[(i, i)].ln();
    }
    ln_det *= 2.0;
    Ok(-(0.5 * ln_det + (m as f64 / 2.0) * (1.0 + LN_2PI)))
}

// E[ln q(alpha)]
fn expect_ln_q_alpha(prob: &Problem) -> Result<f64, RegressionError> {
    prob.alpha.iter().try_fold(0.0, |sum, a| {
        let part1 = Gamma::ln_gamma(a.shape).0;
        let part2 = (a.shape - 1.0) * Gamma::digamma(a.shape);
        let part3 = a.shape - a.rate.ln();
        Ok(sum - (part1 - part2 + part3))
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    const FEATURES: [[f64; 4]; 10] = [
        [-0.2, -0.9, -0.5, 0.3],
        [0.6, 0.3, 0.3, -0.4],
        [0.9, -0.4, -0.5, -0.6],
        [-0.7, 0.8, 0.3, -0.3],
        [-0.5, -0.7, -0.1, 0.8],
        [0.5, 0.5, 0.0, 0.1],
        [0.1, -0.0, 0.0, -0.2],
        [0.4, 0.0, 0.2, 0.0],
        [-0.2, 0.9, -0.1, -0.9],
        [0.1, 0.4, -0.5, 0.9],
    ];

    const LABELS: [bool; 10] = [
        true, false, true, false, true, false, true, false, true, false
    ];

    #[test]
    fn test_train_with_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig::default();
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.bias().unwrap(), 0.00951244801510034);
        assert_approx_eq!(model.weights()[0], -0.19303165213334386);
        assert_approx_eq!(model.weights()[1], -1.2534945326354745);
        assert_approx_eq!(model.weights()[2], -0.6963518106208433);
        assert_approx_eq!(model.weights()[3], -0.8508100398896856);
    }

    #[test]
    fn test_train_with_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.bias().unwrap(), 0.0043520654824470515);
        assert_approx_eq!(model.weights()[0], -0.10946450049722892);
        assert_approx_eq!(model.weights()[1], -1.6472155373009127);
        assert_approx_eq!(model.weights()[2], -1.215877178138718);
        assert_approx_eq!(model.weights()[3], -0.7679465673373882);
    }

    #[test]
    fn test_train_no_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.weights()[0], -0.22479264662358672);
        assert_approx_eq!(model.weights()[1], -1.194338553263914);
        assert_approx_eq!(model.weights()[2], -0.6763443319536045);
        assert_approx_eq!(model.weights()[3], -0.793934474799946);
    }

    #[test]
    fn test_train_no_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.weights()[0], -0.11478846445757208);
        assert_approx_eq!(model.weights()[1], -1.6111314555274376);
        assert_approx_eq!(model.weights()[2], -1.0489256680896761);
        assert_approx_eq!(model.weights()[3], -0.6788653466293544);
    }

    #[test]
    fn test_predict_with_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig::default();
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.27380317759208006);
    }

    #[test]
    fn test_predict_with_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.2956358962602995);
    }

    #[test]
    fn test_predict_no_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.275642768184428);
    }

    #[test]
    fn test_predict_no_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.29090997574190514);
    }
}
429}