use std::f64::consts::PI;
use nalgebra::{Cholesky, DVector, DMatrix};
use special::Gamma;
use serde::{Serialize, Deserialize};

use crate::{BinaryLabels, Features, design_vector, Standardizer, VariationalRegression, get_weights, get_bias};
use crate::distribution::{GammaDistribution, BernoulliDistribution, ScalarDistribution};
use crate::error::RegressionError;
use crate::math::{LN_2PI, logistic, trace_of_product, scale_rows};

type DenseVector = DVector<f64>;
type DenseMatrix = DMatrix<f64>;

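/// Configurable hyperparameters for training a logistic regression model.
///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); unspecified fields fall back to
/// their defaults via struct update syntax.
///
/// ```ignore
/// let config = LogisticTrainConfig {
///     max_iter: 500,
///     standardize: false,
///     ..LogisticTrainConfig::default()
/// };
/// ```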
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogisticTrainConfig {
    /// Prior distribution over the precision of the model weights
    pub weight_precision_prior: GammaDistribution,
    /// Whether or not the model should include a bias term
    pub use_bias: bool,
    /// Whether or not the features should be standardized before training
    pub standardize: bool,
    /// Maximum number of training iterations
    pub max_iter: usize,
    /// Relative improvement in the lower bound below which training stops
    pub tolerance: f64,
    /// Whether or not the lower bound should be printed at each iteration
    pub verbose: bool
}

impl Default for LogisticTrainConfig {
    fn default() -> Self {
        LogisticTrainConfig {
            weight_precision_prior: GammaDistribution::vague(),
            use_bias: true,
            standardize: true,
            max_iter: 1000,
            tolerance: 1e-4,
            verbose: true
        }
    }
}

/// Represents a logistic regression model trained via variational inference
#[derive(Clone, Serialize, Deserialize)]
pub struct VariationalLogisticRegression {
    /// Learned model parameters (posterior means)
    params: DenseVector,
    /// Covariance of the approximate parameter posterior
    covariance: DenseMatrix,
    /// Whether or not the model includes a bias term
    includes_bias: bool,
    /// Standardizer fit on the training features, if standardization was used
    standardizer: Option<Standardizer>,
    /// Final value of the variational lower bound
    pub bound: f64
}

impl VariationalLogisticRegression {

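    /// Trains a logistic regression model on the provided data.
    ///
    /// # Arguments
    ///
    /// * `features` - The feature values (observations in rows, features in columns)
    /// * `labels` - The binary labels, one per observation
    /// * `config` - The training configuration
    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); the data below is illustrative only.
    ///
    /// ```ignore
    /// let features = vec![
    ///     vec![-0.2, -0.9, -0.5, 0.3],
    ///     vec![0.6, 0.3, 0.3, -0.4],
    ///     vec![0.9, -0.4, -0.5, -0.6],
    /// ];
    /// let labels = vec![true, false, true];
    /// let config = LogisticTrainConfig::default();
    /// let model = VariationalLogisticRegression::train(&features, &labels, &config)?;
    /// let prediction = model.predict(&[0.3, 0.8, -0.1, -0.3])?.mean();
    /// ```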
    pub fn train(
        features: impl Features,
        labels: impl BinaryLabels,
        config: &LogisticTrainConfig
    ) -> Result<VariationalLogisticRegression, RegressionError> {
        let mut problem = Problem::new(features, labels, config);
        for iter in 0..config.max_iter {
            // alternate the mean-field updates, then re-evaluate the lower bound
            q_theta(&mut problem)?;
            q_alpha(&mut problem)?;
            update_zeta(&mut problem)?;
            let new_bound = lower_bound(&problem)?;
            if config.verbose {
                println!("Iteration {}, Lower Bound = {}", iter + 1, new_bound);
            }
            // terminate once the relative improvement in the bound falls below the tolerance
            if (new_bound - problem.bound) / problem.bound.abs() <= config.tolerance {
                return Ok(VariationalLogisticRegression {
                    params: problem.theta,
                    covariance: problem.s,
                    includes_bias: config.use_bias,
                    standardizer: problem.standardizer,
                    bound: new_bound
                })
            } else {
                problem.bound = new_bound;
            }
        }
        Err(RegressionError::ConvergenceFailure(config.max_iter))
    }
}

impl VariationalRegression<BernoulliDistribution> for VariationalLogisticRegression {

    fn predict(&self, features: &[f64]) -> Result<BernoulliDistribution, RegressionError> {
        let mut x = design_vector(features, self.includes_bias);
        if let Some(std) = &self.standardizer {
            std.transform_vector(&mut x);
        }
        // posterior mean and variance of the linear predictor
        let mu = x.dot(&self.params);
        let s = (&self.covariance * &x).dot(&x);
        // moderate the mean with the standard probit-style approximation to the
        // logistic-Gaussian integral, using kappa(s) = (1 + pi * s / 8)^(-1/2)
        let k = 1.0 / (1.0 + (PI * s) / 8.0).sqrt();
        let p = logistic(k * mu);
        BernoulliDistribution::new(p)
    }

    fn weights(&self) -> &[f64] {
        get_weights(self.includes_bias, &self.params)
    }

    fn bias(&self) -> Option<f64> {
        get_bias(self.includes_bias, &self.params)
    }
}

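// Internal state for the mean-field variational optimization. Training alternates the
// closed-form coordinate updates below until the lower bound converges (a sketch of
// what q_theta, q_alpha, and update_zeta compute; sigma denotes the logistic function):
//
//   q(theta) = N(theta | m, S), with
//       S^-1 = diag(E[alpha]) + 2 * sum_n lambda(zeta_n) * x_n x_n^T
//       m    = S * sum_n (y_n - 1/2) * x_n
//   q(alpha_i) = Gamma(shape_0 + 1/2, rate_0 + (m_i^2 + S_ii) / 2)
//   zeta_n^2 = x_n^T (S + m m^T) x_n
//
// where lambda(z) = (sigma(z) - 1/2) / (2z) arises from the local (Jaakkola-Jordan style)
// quadratic bound on the logistic likelihood.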
struct Problem {
    /// Design matrix (standardized if configured)
    pub x: DenseMatrix,
    /// Binary labels encoded as 0.0 / 1.0
    pub y: DenseVector,
    /// Precomputed X^T (y - 1/2)
    pub xtr: DenseVector,
    /// Mean of the approximate posterior over the parameters
    pub theta: DenseVector,
    /// Covariance of the approximate posterior over the parameters
    pub s: DenseMatrix,
    /// Approximate posteriors over the parameter precisions
    pub alpha: Vec<GammaDistribution>,
    /// Local variational parameters, one per observation
    pub zeta: DenseVector,
    /// Prior over the bias precision, if a bias term is used
    pub bpp: Option<GammaDistribution>,
    /// Prior over the weight precisions
    pub wpp: GammaDistribution,
    /// Number of observations
    pub n: usize,
    /// Number of parameters (including the bias, if used)
    pub d: usize,
    /// Current value of the lower bound
    pub bound: f64,
    /// Standardizer fit on the training features, if standardization was used
    pub standardizer: Option<Standardizer>
}

impl Problem {

    fn new(
        features: impl Features,
        labels: impl BinaryLabels,
        config: &LogisticTrainConfig
    ) -> Problem {
        let mut x = features.into_matrix(config.use_bias);
        let standardizer = if config.standardize {
            Some(Standardizer::fit(&x))
        } else {
            None
        };
        if let Some(std) = &standardizer {
            std.transform_matrix(&mut x);
        }
        let n = x.nrows();
        let d = x.ncols();
        let y = labels.into_vector();
        // X^T (y - 1/2) is fixed throughout training, so compute it once
        let xtr = x.tr_mul(&y.map(|v| v - 0.5));
        let wpp = config.weight_precision_prior;
        let bpp = if config.use_bias {
            Some(GammaDistribution::vague())
        } else {
            None
        };
        let mut alpha = vec![wpp; x.ncols()];
        if let Some(pp) = bpp {
            alpha[0] = pp;
        }
        let zeta = DenseVector::from_element(n, 1.0);
        let bound = f64::NEG_INFINITY;
        let theta = DenseVector::zeros(d);
        let s = DenseMatrix::zeros(d, d);
        Problem { x, y, xtr, theta, s, alpha, zeta, bpp, wpp, n, d, bound, standardizer }
    }

    fn param_precision_prior(&self, ind: usize) -> GammaDistribution {
        match (ind, self.bpp) {
            (0, Some(bpp)) => bpp,
            _ => self.wpp
        }
    }
}


// lambda(z) = (sigma(z) - 1/2) / (2z), the coefficient from the local quadratic
// bound on the logistic likelihood
fn lambda(val: f64) -> f64 {
    (1.0 / (2.0 * val)) * (logistic(val) - 0.5)
}

// update the approximate posterior over the parameters, q(theta) = N(m, S), where
// S^-1 = diag(E[alpha]) + 2 X^T diag(lambda(zeta)) X and m = S X^T (y - 1/2)
fn q_theta(prob: &mut Problem) -> Result<(), RegressionError> {
    let a = DenseVector::from(prob.alpha.iter().map(|alpha| alpha.mean()).collect::<Vec<f64>>());
    let mut s_inv = DenseMatrix::from_diagonal(&a);
    let lambdas = prob.zeta.map(|z| lambda(z));
    s_inv += prob.x.tr_mul(&scale_rows(&prob.x, &lambdas)) * 2.0;
    prob.s = Cholesky::new(s_inv)
        .ok_or(RegressionError::CholeskyFailure)?
        .inverse();
    prob.theta = &prob.s * &prob.xtr;
    Ok(())
}

// update the approximate posteriors over the parameter precisions, q(alpha_i)
fn q_alpha(prob: &mut Problem) -> Result<(), RegressionError> {
    for i in 0..prob.d {
        let pp = prob.param_precision_prior(i);
        let inv_scale = pp.rate + 0.5 * (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]);
        prob.alpha[i] = GammaDistribution::new(pp.shape + 0.5, inv_scale)?;
    }
    Ok(())
}

// update the local variational parameters, zeta_n^2 = x_n^T (S + theta theta^T) x_n
fn update_zeta(prob: &mut Problem) -> Result<(), RegressionError> {
    let a = &prob.s + (&prob.theta * prob.theta.transpose());
    let iter = prob.x.row_iter().map(|xi| {
        (&xi * &a).dot(&xi).sqrt()
    });
    prob.zeta = DenseVector::from_iterator(prob.n, iter);
    Ok(())
}

// evaluate the variational lower bound on the log marginal likelihood
fn lower_bound(prob: &Problem) -> Result<f64, RegressionError> {
    Ok(expect_ln_p_y(prob)? +
    expect_ln_p_theta(prob)? +
    expect_ln_p_alpha(prob)? -
    expect_ln_q_theta(prob)? -
    expect_ln_q_alpha(prob)?)
}

// expected log likelihood term, E[ln p(y | theta)], under the local quadratic bound
fn expect_ln_p_y(prob: &Problem) -> Result<f64, RegressionError> {
    let part1 = prob.zeta.map(lambda);
    let part2 = prob.zeta.map(|z| logistic(z).ln()).sum();
    let part3 = &prob.x * &prob.theta;
    let part4 = (&part3.transpose() * prob.y.map(|y| y - 0.5)).sum();
    let part5 = prob.zeta.sum() / 2.0;
    let part6 = (part3.map(|v| v * v).transpose() * &part1).sum();
    let part7 = trace_of_product(&scale_rows(&(&prob.x * &prob.s), &part1), &prob.x.transpose());
    let part8 = part1.component_mul(&prob.zeta.map(|z| z * z)).sum();
    Ok(part2 + part4 - part5 - part6 - part7 + part8)
}

// expected log prior over the parameters, E[ln p(theta | alpha)]
fn expect_ln_p_theta(prob: &Problem) -> Result<f64, RegressionError> {
    let init = (prob.theta.len() as f64 * -0.5) * LN_2PI;
    prob.alpha.iter().enumerate().try_fold(init, |sum, (i, a)| {
        let am = a.mean();
        let part1 = Gamma::digamma(a.shape) - a.rate.ln();
        let part2 = (prob.theta[i] * prob.theta[i] + prob.s[(i, i)]) * am;
        Ok(sum + 0.5 * (part1 - part2))
    })
}

// expected log prior over the parameter precisions, E[ln p(alpha)]
fn expect_ln_p_alpha(prob: &Problem) -> Result<f64, RegressionError> {
    prob.alpha.iter().enumerate().try_fold(0.0, |sum, (i, a)| {
        let am = a.mean();
        let pp = prob.param_precision_prior(i);
        let term1 = pp.shape * pp.rate.ln();
        let term2 = (pp.shape - 1.0) * (Gamma::digamma(a.shape) - a.rate.ln());
        // use the prior for this specific parameter (bias or weight) in the log-gamma term
        let term3 = (pp.rate * am) + Gamma::ln_gamma(pp.shape).0;
        Ok(sum + term1 + term2 - term3)
    })
}

// expected log of the variational posterior over the parameters, E[ln q(theta)]
fn expect_ln_q_theta(prob: &Problem) -> Result<f64, RegressionError> {
    let m = prob.s.shape().0;
    // ln |S| via the Cholesky factor, propagating failure rather than panicking
    let chol = Cholesky::new(prob.s.clone())
        .ok_or(RegressionError::CholeskyFailure)?
        .l();
    let mut ln_det = 0.0;
    for i in 0..prob.s.ncols() {
        ln_det += chol[(i, i)].ln();
    }
    ln_det *= 2.0;
    Ok(-(0.5 * ln_det + (m as f64 / 2.0) * (1.0 + LN_2PI)))
}

// expected log of the variational posteriors over the precisions, E[ln q(alpha)]
fn expect_ln_q_alpha(prob: &Problem) -> Result<f64, RegressionError> {
    prob.alpha.iter().try_fold(0.0, |sum, a| {
        let part1 = Gamma::ln_gamma(a.shape).0;
        let part2 = (a.shape - 1.0) * Gamma::digamma(a.shape);
        let part3 = a.shape - a.rate.ln();
        Ok(sum - (part1 - part2 + part3))
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;

    const FEATURES: [[f64; 4]; 10] = [
        [-0.2, -0.9, -0.5, 0.3],
        [0.6, 0.3, 0.3, -0.4],
        [0.9, -0.4, -0.5, -0.6],
        [-0.7, 0.8, 0.3, -0.3],
        [-0.5, -0.7, -0.1, 0.8],
        [0.5, 0.5, 0.0, 0.1],
        [0.1, -0.0, 0.0, -0.2],
        [0.4, 0.0, 0.2, 0.0],
        [-0.2, 0.9, -0.1, -0.9],
        [0.1, 0.4, -0.5, 0.9],
    ];

    const LABELS: [bool; 10] = [
        true, false, true, false, true, false, true, false, true, false
    ];

    #[test]
    fn test_train_with_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig::default();
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.bias().unwrap(), 0.00951244801510034);
        assert_approx_eq!(model.weights()[0], -0.19303165213334386);
        assert_approx_eq!(model.weights()[1], -1.2534945326354745);
        assert_approx_eq!(model.weights()[2], -0.6963518106208433);
        assert_approx_eq!(model.weights()[3], -0.8508100398896856);
    }

    #[test]
    fn test_train_with_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.bias().unwrap(), 0.0043520654824470515);
        assert_approx_eq!(model.weights()[0], -0.10946450049722892);
        assert_approx_eq!(model.weights()[1], -1.6472155373009127);
        assert_approx_eq!(model.weights()[2], -1.215877178138718);
        assert_approx_eq!(model.weights()[3], -0.7679465673373882);
    }

    #[test]
    fn test_train_no_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.weights()[0], -0.22479264662358672);
        assert_approx_eq!(model.weights()[1], -1.194338553263914);
        assert_approx_eq!(model.weights()[2], -0.6763443319536045);
        assert_approx_eq!(model.weights()[3], -0.793934474799946);
    }

    #[test]
    fn test_train_no_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        assert_approx_eq!(model.weights()[0], -0.11478846445757208);
        assert_approx_eq!(model.weights()[1], -1.6111314555274376);
        assert_approx_eq!(model.weights()[2], -1.0489256680896761);
        assert_approx_eq!(model.weights()[3], -0.6788653466293544);
    }

    #[test]
    fn test_predict_with_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig::default();
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.27380317759208006);
    }

    #[test]
    fn test_predict_with_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.2956358962602995);
    }

    #[test]
    fn test_predict_no_bias_with_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.275642768184428);
    }

    #[test]
    fn test_predict_no_bias_no_standardize() {
        let x = Vec::from(FEATURES.map(Vec::from));
        let y = Vec::from(LABELS);
        let config = LogisticTrainConfig {
            use_bias: false,
            standardize: false,
            ..Default::default()
        };
        let model = VariationalLogisticRegression::train(&x, &y, &config).unwrap();
        let p = model.predict(&vec![0.3, 0.8, -0.1, -0.3]).unwrap().mean();
        assert_approx_eq!(p, 0.29090997574190514);
    }
}