// truncnorm/truncnorm.rs

1#![allow(non_snake_case)]
2#![allow(clippy::many_single_char_names)]
3//! Rust re-write of [Truncated Normal and Student's t-distribution toolbox](https://www.mathworks.com/matlabcentral/fileexchange/53796-truncated-normal-and-student-s-t-distribution-toolbox)
4//!
5//! Reference: Z. I. Botev (2017), _The Normal Law Under Linear Restrictions:
6//! Simulation and Estimation via Minimax Tilting_, Journal of the Royal
7//! Statistical Society, Series B, Volume 79, Part 1, pp. 1-24
8use crate::dist_util::cholperm;
9use crate::dist_util::ln_normal_pr;
10use crate::faddeeva::erfc;
11use crate::tilting::TiltingProblem;
12use crate::tilting::TiltingSolution;
13use crate::util;
14use ndarray::azip;
15use ndarray::Zip;
16use ndarray::{s, Axis};
17use ndarray::{Array1, Array2};
18use ndarray_rand::rand_distr::StandardNormal;
19use ndarray_rand::RandomExt;
20use num::traits::FloatConst;
21use rand::distributions::Uniform;
22use rand::Rng;
23use statrs::function::erf::erfc_inv;
24
25fn ntail<R: Rng + ?Sized>(
26    l: &Array1<f64>,
27    u: &Array1<f64>,
28    max_iters: usize,
29    rng: &mut R,
30) -> Array1<f64> {
31    /*
32    % samples a vector from the standard normal
33    % distribution truncated over the region [l,u], where l>0 and
34    % l and u are column vectors; uses acceptance-rejection from Rayleigh
35    % distr. Similar to Marsaglia (1964)
36    */
37    let c = l.map(|x| x.powi(2) / 2.);
38    let f = Zip::from(&c)
39        .and(u)
40        .map_collect(|&c, &u| (c - u * u / 2.).exp_m1());
41    // use rejection sample pattern
42    let mut accept_condition = |x: &Array1<f64>, accepted: &mut Array1<bool>, rng: &mut R| {
43        let test_sample: Array1<f64> = Array1::random_using(l.len(), Uniform::new(0., 1.), rng);
44        azip!((x in x, &s in &test_sample, &c in &c, acc in accepted) {
45            if s * s * x < c {
46            *acc = true;
47            }
48        });
49    };
50    let mut proposal_sampler = |rng: &mut R| {
51        let sample = Array1::random_using(l.len(), Uniform::new(0., 1.), rng);
52        &c - (1. + sample * &f).mapv(f64::ln)
53    };
54    let mut output_array =
55        util::rejection_sample(&mut accept_condition, &mut proposal_sampler, max_iters, rng);
56    output_array.mapv_inplace(|x| (2. * x).sqrt());
57    output_array
58}
59
60fn trnd<R: Rng + ?Sized>(
61    l: &Array1<f64>,
62    u: &Array1<f64>,
63    max_iters: usize,
64    rng: &mut R,
65) -> Array1<f64> {
66    // use accept-reject pattern to sample from truncated N(0,1)
67    let mut accept_condition = |x: &Array1<f64>, accepted: &mut Array1<bool>, _rng: &mut R| {
68        azip!((x in x, l in l, u in u, acc in accepted) {
69            if x > l && x < u {
70            *acc = true;
71            }
72        });
73    };
74    let mut proposal_sampler = |rng: &mut R| Array1::random_using(l.len(), StandardNormal, rng);
75    util::rejection_sample(&mut accept_condition, &mut proposal_sampler, max_iters, rng)
76}
77
78fn tn<R: Rng + ?Sized>(
79    l: &Array1<f64>,
80    u: &Array1<f64>,
81    max_iters: usize,
82    rng: &mut R,
83) -> Array1<f64> {
84    /*
85    % samples a column vector of length=length(l)=length(u)
86    % from the standard multivariate normal distribution,
87    % truncated over the region [l,u], where -a<l<u<a for some
88    % 'a' and l and u are column vectors;
89    % uses acceptance rejection and inverse-transform method;
90    */
91    // controls switch between methods
92    let tol = 2.;
93    // threshold can be tuned for maximum speed for each platform
94    // case: abs(u-l)>tol, uses accept-reject from randn
95    let mut coeff = Array1::ones(l.len());
96    let gap = (u - l).map(|x| x.abs());
97    let mut tl = l.clone();
98    let mut tu = u.clone();
99    azip!((gap in &gap, coeff in &mut coeff, tl in &mut tl, tu in &mut tu) if *gap < tol {*coeff = 0.;*tl=f64::NEG_INFINITY;*tu=f64::INFINITY;});
100    let accept_reject = trnd(&tl, &tu, max_iters, rng);
101    // case: abs(u-l)<tol, uses inverse-transform
102    let pl = (&tl * f64::FRAC_1_SQRT_2()).map(|x| erfc(*x) / 2.);
103    let pu = (&tu * f64::FRAC_1_SQRT_2()).map(|x| erfc(*x) / 2.);
104    let sample = Array1::random_using(l.len(), Uniform::new(0., 1.), rng);
105
106    let inverse_transform =
107        f64::SQRT_2() * (2. * (&pl - (&pl - &pu) * sample)).map(|x| erfc_inv(*x));
108    let mut result = &coeff * &accept_reject + (1. - &coeff) * &inverse_transform;
109    if result.iter().any(|x| x.is_nan()) {
110        result = coeff
111            .iter()
112            .zip(inverse_transform.iter())
113            .zip(accept_reject.iter())
114            .map(|x| if *x.0 .0 == 0. { *x.0 .1 } else { *x.1 })
115            .collect();
116    }
117    result
118}
119
120/// fast truncated normal generator
121///
122///  Infinite values for 'u' and 'l' are accepted;
123///
124///  If you wish to simulate a random variable
125/// 'Z' from the non-standard Gaussian $N(\mu,\sigma^2)$
126///  conditional on $l<Z<u$, first simulate
127///  $X=trandn((l-m)/s,(u-m)/s)$ and set $Z=\mu+\sigma X$;
128pub fn trandn<R: Rng + ?Sized>(
129    l: &Array1<f64>,
130    u: &Array1<f64>,
131    max_iters: usize,
132    rng: &mut R,
133) -> Array1<f64> {
134    let thresh = 0.66; // tunable threshold to choose method
135    let mut tl = l.clone();
136    let mut tu = u.clone();
137    let mut coeff = Array1::zeros(l.len());
138    azip!((tl in &mut tl, tu in &mut tu, coeff in &mut coeff) {
139        if *tl > thresh {*coeff = 1.}
140        else if *tu < -thresh {*tl = -*tu; *tu = -*tl; *coeff = -1.}
141        else {*tl = -100.; *tu = 100.; *coeff=0.;} // sample from another method, set params to always accept
142    });
143    let acc_rej_sample = ntail(&tl, &tu, max_iters, rng);
144    let trunc_norm_sample = tn(l, u, max_iters, rng);
145    &coeff * acc_rej_sample + (1. - &coeff.mapv(f64::abs)) * trunc_norm_sample
146}
147
148fn psy(
149    x: &Array1<f64>,
150    L: &Array2<f64>,
151    l: &Array1<f64>,
152    u: &Array1<f64>,
153    mu: &Array1<f64>,
154) -> f64 {
155    // implements psi(x,mu); assumes scaled 'L' without diagonal;
156    let mut temp = Array1::zeros(x.len() + 1);
157    temp.slice_mut(s![..x.len()]).assign(x);
158    let x = temp;
159    let mut temp = Array1::zeros(mu.len() + 1);
160    temp.slice_mut(s![..mu.len()]).assign(mu);
161    let mu = temp;
162    // compute now ~l and ~u
163    let c = L.dot(&x);
164    let tl = l - &mu - &c;
165    let tu = u - &mu - &c;
166    (ln_normal_pr(&tl, &tu) + 0.5 * mu.mapv(|x| x * x) - x * mu).sum()
167}
168
169/*
170% computes P(l<X<u), where X is normal with
171% 'Cov(X)=L*L' and zero mean vector;
172% exponential tilting uses parameter 'mu';
173% Monte Carlo uses 'n' samples;
174*/
fn mv_normal_pr<R: Rng + ?Sized>(
    n: usize,
    L: &Array2<f64>,
    l: &Array1<f64>,
    u: &Array1<f64>,
    mu: &Array1<f64>,
    max_iters: usize,
    rng: &mut R,
) -> (f64, f64) {
    // Sequential importance-sampling estimate of P(l < X < u) (see the
    // module comment above); returns (probability estimate, relative error).
    let d = l.shape()[0];
    // log importance weight of each of the n Monte Carlo samples
    let mut p = Array1::zeros(n);
    // pad `mu` with a trailing zero so indexing mu[[k]] below stays in bounds
    // (assumes mu.len() == d - 1 — TODO confirm against callers)
    let mut temp = Array1::zeros(mu.shape()[0] + 1);
    temp.slice_mut(s![..d - 1]).assign(mu);
    let mu = temp;
    // Z[k, j] holds coordinate k of sample j
    let mut Z = Array2::zeros((d, n));
    let mut col;
    let mut tl;
    let mut tu;
    for k in 0..d - 1 {
        // conditional shift induced by the already-sampled coordinates
        col = L.slice(s![k, ..k]).dot(&Z.slice(s![..k, ..]));
        tl = l[[k]] - mu[[k]] - &col;
        tu = u[[k]] - mu[[k]] - col;
        // simulate N(mu, 1) conditional on [tl,tu]
        Z.index_axis_mut(Axis(0), k)
            .assign(&(mu[[k]] + trandn(&tl, &tu, max_iters, rng)));
        // update likelihood ratio
        p = p + ln_normal_pr(&tl, &tu) + 0.5 * mu[[k]].powi(2)
            - mu[[k]] * &Z.index_axis(Axis(0), k);
    }
    // last coordinate is untilted; only its interval probability contributes
    col = L.index_axis(Axis(0), d - 1).dot(&Z);
    tl = l[d - 1] - &col;
    tu = u[d - 1] - col;
    p = p + ln_normal_pr(&tl, &tu);
    // exponentiate the log weights; their mean estimates the probability
    p.mapv_inplace(f64::exp);
    let prob = p.mean().unwrap();
    debug_assert!(
        !prob.is_sign_negative(),
        "Returned invalid probability, {:?}",
        prob
    );
    // relative Monte Carlo standard error of the mean (population std, ddof = 0)
    let rel_err = p.std(0.) / (n as f64).sqrt() / prob;
    (prob, rel_err)
}
218
fn mv_truncnorm_proposal<R: Rng + ?Sized>(
    L: &Array2<f64>,
    l: &Array1<f64>,
    u: &Array1<f64>,
    mu: &Array1<f64>,
    n: usize,
    max_iters: usize,
    rng: &mut R,
) -> (Array1<f64>, Array2<f64>) {
    /*
    % generates the proposals from the exponentially tilted
    % sequential importance sampling pdf;
    % output:    'p', log-likelihood of sample
    %             Z, random sample
    */
    let d = l.shape()[0];
    // per-column log-likelihood of each of the n proposals
    let mut logp = Array1::zeros(n);
    // pad `mu` with a trailing zero: the last coordinate is untilted
    // (assumes mu.len() == d - 1 — TODO confirm against callers)
    let mut temp = Array1::zeros(mu.shape()[0] + 1);
    temp.slice_mut(s![..d - 1]).assign(mu);
    let mu = temp;
    // Z[k, j] holds coordinate k of proposal j
    let mut Z = Array2::zeros((d, n));
    let mut col;
    let mut tl;
    let mut tu;
    for k in 0..d {
        // conditional shift induced by previously sampled coordinates
        col = L.slice(s![k, ..k]).dot(&Z.slice(s![..k, ..]));
        tl = l[[k]] - mu[[k]] - &col;
        tu = u[[k]] - mu[[k]] - col;
        // simulate N(mu, 1) conditional on [tl,tu]
        Z.index_axis_mut(Axis(0), k)
            .assign(&(mu[[k]] + trandn(&tl, &tu, max_iters, rng)));
        // update likelihood ratio
        logp = logp + ln_normal_pr(&tl, &tu) + 0.5 * mu[[k]] * mu[[k]]
            - mu[[k]] * &Z.index_axis(Axis(0), k);
    }
    (logp, Z)
}
256
/// Draws `n` samples from the zero-mean normal N(0, sigma) truncated to
/// [l, u], reusing a pre-computed tilting solution; returns an (n x d)
/// matrix with one sample per row.
pub fn solved_mv_truncnormal_rand<R: Rng + ?Sized>(
    tilting_solution: &TiltingSolution,
    mut l: Array1<f64>,
    mut u: Array1<f64>,
    mut sigma: Array2<f64>,
    n: usize,
    max_iters: usize,
    rng: &mut R,
) -> Array2<f64> {
    let d = l.len();
    // reorder variables for numerical stability; `perm` records the permutation
    let (Lfull, perm) = cholperm(&mut sigma, &mut l, &mut u);
    let D = Lfull.diag().to_owned();

    // rescale bounds so the Cholesky factor has a unit diagonal
    u /= &D;
    l /= &D;
    // NOTE(review): `zeros + &D` broadcasts D across rows, so this divides
    // Lfull column-wise; the MATLAB reference (Lfull./D with D a column
    // vector) and the disabled grad_psi test (which applies `.t()`) divide
    // row-wise — verify the intended broadcasting direction.
    let L = (&Lfull / &(Array2::<f64>::zeros([D.len(), D.len()]) + &D)) - Array2::<f64>::eye(d);

    let x = &tilting_solution.x;
    let mu = &tilting_solution.mu;
    let psi_star = psy(x, &L, &l, &u, mu); // compute psi star
    let (logp, mut Z) = mv_truncnorm_proposal(&L, &l, &u, mu, n, max_iters, rng);

    // accept a proposal column with probability exp(logp - psi_star);
    // -ln(U) is a standard exponential draw
    let accept_condition = |logp: &Array1<f64>, accepted: &mut Array1<bool>, rng: &mut R| {
        let test_sample: Array1<f64> = Array1::random_using(logp.len(), Uniform::new(0., 1.), rng);
        azip!((&s in &test_sample, &logp in logp, acc in accepted) {
            if -1. * s.ln() > (psi_star - logp) {
            *acc = true;
            }
        });
    };
    let mut accepted: Array1<bool> = Array1::from_elem(Z.ncols(), false);
    accept_condition(&logp, &mut accepted, rng);
    let mut i = 0;
    // re-propose only the still-rejected columns until every column is accepted
    while !accepted.fold(true, |a, b| a && *b) {
        let (logp, sample) = mv_truncnorm_proposal(&L, &l, &u, mu, n, max_iters, rng);
        Zip::from(Z.axis_iter_mut(Axis(1)))
            .and(sample.axis_iter(Axis(1)))
            .and(&accepted)
            .for_each(|mut z, s, &acc| {
                if !acc {
                    z.assign(&s);
                }
            });
        accept_condition(&logp, &mut accepted, rng);
        i += 1;
        if i > max_iters {
            // Ran out of accept-reject rounds; unaccepted columns keep their
            // last proposal
            break;
        }
    }
    // postprocess samples
    // invert the cholperm permutation to restore the caller's variable order
    let mut unperm = perm.into_iter().zip(0..d).collect::<Vec<(usize, usize)>>();
    unperm.sort_by(|a, b| a.0.cmp(&b.0));
    let order: Vec<usize> = unperm.iter().map(|x| x.1).collect();

    // reverse scaling of L
    let mut rv = Lfull.dot(&Z);
    let unperm_rv = rv.clone();
    for (i, &ord) in order.iter().enumerate() {
        rv.row_mut(i).assign(&unperm_rv.row(ord));
    }
    // transpose so each row is one d-dimensional sample
    rv.reversed_axes()
}
320
/// truncated multivariate normal generator
///
/// Draws `n` samples from N(0, sigma) truncated to [l, u], computing its own
/// tilting solution (cf. [`solved_mv_truncnormal_rand`], which reuses one);
/// returns an (n x d) matrix with one sample per row.
pub fn mv_truncnormal_rand<R: Rng + ?Sized>(
    mut l: Array1<f64>,
    mut u: Array1<f64>,
    mut sigma: Array2<f64>,
    n: usize,
    max_iters: usize,
    rng: &mut R,
) -> Array2<f64> {
    let d = l.len();
    // reorder variables for numerical stability; `perm` records the permutation
    let (Lfull, perm) = cholperm(&mut sigma, &mut l, &mut u);
    let D = Lfull.diag().to_owned();

    // rescale bounds so the Cholesky factor has a unit diagonal
    u /= &D;
    l /= &D;
    // NOTE(review): `zeros + &D` broadcasts D across rows, dividing Lfull
    // column-wise; the MATLAB reference and the disabled grad_psi test
    // (which applies `.t()`) divide row-wise — verify broadcasting direction.
    let L = (&Lfull / &(Array2::<f64>::zeros([D.len(), D.len()]) + &D)) - Array2::<f64>::eye(d);

    // find optimal tilting parameter via non-linear equation solver
    // NOTE(review): `sigma` was permuted in place by cholperm above —
    // confirm TiltingProblem expects the permuted covariance.
    let problem = TiltingProblem::new(l.clone(), u.clone(), sigma);
    let result = problem.solve_optimial_tilting();
    // assign saddlepoint x* and mu*
    let x = result.x.slice(s![..d - 1]).to_owned();
    let mu = result.x.slice(s![d - 1..(2 * (d - 1))]).to_owned();
    let psi_star = psy(&x, &L, &l, &u, &mu); // compute psi star
    let (logp, mut Z) = mv_truncnorm_proposal(&L, &l, &u, &mu, n, max_iters, rng);

    // accept a proposal column with probability exp(logp - psi_star);
    // -ln(U) is a standard exponential draw
    let accept_condition = |logp: &Array1<f64>, accepted: &mut Array1<bool>, rng: &mut R| {
        let test_sample: Array1<f64> = Array1::random_using(logp.len(), Uniform::new(0., 1.), rng);
        azip!((&s in &test_sample, &logp in logp, acc in accepted) {
            if -1. * s.ln() > (psi_star - logp) {
            *acc = true;
            }
        });
    };
    let mut accepted: Array1<bool> = Array1::from_elem(Z.ncols(), false);
    accept_condition(&logp, &mut accepted, rng);
    let mut i = 0;
    // re-propose only the still-rejected columns until every column is accepted
    while !accepted.fold(true, |a, b| a && *b) {
        let (logp, sample) = mv_truncnorm_proposal(&L, &l, &u, &mu, n, max_iters, rng);
        Zip::from(Z.axis_iter_mut(Axis(1)))
            .and(sample.axis_iter(Axis(1)))
            .and(&accepted)
            .for_each(|mut z, s, &acc| {
                if !acc {
                    z.assign(&s);
                }
            });
        accept_condition(&logp, &mut accepted, rng);
        i += 1;
        if i > max_iters {
            // out of accept-reject rounds; unaccepted columns keep their last proposal
            break;
        }
    }
    // postprocess samples
    // invert the cholperm permutation to restore the caller's variable order
    let mut unperm = perm.into_iter().zip(0..d).collect::<Vec<(usize, usize)>>();
    unperm.sort_by(|a, b| a.0.cmp(&b.0));
    let order: Vec<usize> = unperm.iter().map(|x| x.1).collect();

    // reverse scaling of L
    let mut rv = Lfull.dot(&Z);
    let unperm_rv = rv.clone();
    for (i, &ord) in order.iter().enumerate() {
        rv.row_mut(i).assign(&unperm_rv.row(ord));
    }
    // transpose so each row is one d-dimensional sample
    rv.reversed_axes()
}
387
388pub fn solved_mv_truncnormal_cdf<R: Rng + ?Sized>(
389    tilting_solution: &TiltingSolution,
390    n: usize,
391    max_iters: usize,
392    rng: &mut R,
393) -> (f64, f64, f64) {
394    // compute psi star
395    let (est, rel_err) = mv_normal_pr(
396        n,
397        &tilting_solution.lower_tri,
398        &tilting_solution.lower,
399        &tilting_solution.upper,
400        &tilting_solution.mu,
401        max_iters,
402        rng,
403    );
404    // calculate an upper bound
405    let log_upbnd = psy(
406        &tilting_solution.x,
407        &tilting_solution.lower_tri,
408        &tilting_solution.lower,
409        &tilting_solution.upper,
410        &tilting_solution.mu,
411    );
412    /*
413    if log_upbnd < -743. {
414        panic!(
415        "Natural log of upbnd probability is less than -743, yielding 0 after exponentiation!"
416        )
417    }
418    */
419    let upbnd = log_upbnd.exp();
420    (est, rel_err, upbnd)
421}
422
423/// multivariate normal cumulative distribution
424pub fn mv_truncnormal_cdf<R: Rng + ?Sized>(
425    l: Array1<f64>,
426    u: Array1<f64>,
427    sigma: Array2<f64>,
428    n: usize,
429    max_iters: usize,
430    rng: &mut R,
431) -> (f64, f64, f64) {
432    let tilting_solution = TiltingProblem::new(l, u, sigma).solve_optimial_tilting();
433    // compute psi star
434    let (est, rel_err) = mv_normal_pr(
435        n,
436        &tilting_solution.lower_tri,
437        &tilting_solution.lower,
438        &tilting_solution.upper,
439        &tilting_solution.mu,
440        max_iters,
441        rng,
442    );
443    // calculate an upper bound
444    let log_upbnd = psy(
445        &tilting_solution.x,
446        &tilting_solution.lower_tri,
447        &tilting_solution.lower,
448        &tilting_solution.upper,
449        &tilting_solution.mu,
450    );
451    /*
452    if log_upbnd < -743. {
453        panic!(
454        "Natural log of upbnd probability is less than -743, yielding 0 after exponentiation!"
455        )
456    }
457    */
458    let upbnd = log_upbnd.exp();
459    (est, rel_err, upbnd)
460}
461
#[cfg(test)]
mod tests {
    extern crate ndarray;
    extern crate test;
    use super::*;
    use ndarray::{arr1, arr2};
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::rand_distr::Uniform;
    use test::Bencher;

    #[test]
    fn manual_rand_scale_test() {
        let l = arr1(&[f64::NEG_INFINITY, f64::NEG_INFINITY]);
        //let u: Array1<f64> = arr1(&[-7.33, 0.1]);
        let u = arr1(&[-7.75, 9.11]);
        let sigma = arr2(&[[10., -10.], [-10., 11.]]);
        let mut rng = rand::thread_rng();
        let n = 10;
        let max_iters = 10;
        // smoke test: must not panic; output is one sample per row (n x d)
        let samples = mv_truncnormal_rand(l, u, sigma, n, max_iters, &mut rng);
        assert_eq!(samples.shape(), &[n, 2]);
    }

    #[bench]
    // with par e0::2.63e3, e1::1.22e4, e2::1.94e4, e3::4.85e4
    fn bench_ln_normal_pr(bench: &mut Bencher) {
        let normal = Normal::new(0., 1.).unwrap();
        let uniform = Uniform::new(0., 2.);
        let a = Array2::random((1000, 1), normal);
        let b = Array2::random((1000, 1), uniform);
        // ensure the upper bound is strictly above the lower bound
        let c = a.clone() + b;
        bench.iter(|| test::black_box(ln_normal_pr(&a, &c)));
    }

    #[bench]
    // e1::1.57e5, e2::2.83e6, e3::4.20e8
    fn bench_cholperm(b: &mut Bencher) {
        let n = 10;
        let mut sigma = Array2::eye(n);
        sigma.row_mut(2).map_inplace(|x| *x += 0.01);
        sigma.column_mut(2).map_inplace(|x| *x += 0.01);
        let uniform = Uniform::new(0., 1.);
        let mut l = Array1::random(n, uniform);
        let mut u = 2. * &l;
        // NOTE: cholperm mutates its arguments, so later iterations operate
        // on already-permuted inputs; acceptable for a rough benchmark.
        b.iter(|| test::black_box(cholperm(&mut sigma, &mut l, &mut u)));
    }

    /*
    #[test]
    fn test_grad_psi() {
        let d = 25;
        let mut l = Array1::ones(d) / 2.;
        let mut u = Array1::ones(d);
        let mut sigma: Array2<f64> =
            Array2::from_elem((25, 25), -0.07692307692307693) + Array2::<f64>::eye(25) * 2.;
        //let y = Array1::ones(d);
        let y = Array::range(0., 2. * (d - 1) as f64, 1.);

        let (mut L, _perm) = cholperm(&mut sigma, &mut l, &mut u);
        let D = L.diag().to_owned();
        u /= &D;
        l /= &D;
        L = (L / (Array2::<f64>::zeros([D.len(), D.len()]) + &D).t()) - Array2::<f64>::eye(d);
        let (residuals, jacobian) = grad_psi(&y, &L, &l, &u);
        println!("{:?}", (residuals, jacobian))
    }
    */

    #[test]
    fn test_mv_normal_cdf() {
        let d = 25;
        let l = Array1::ones(d) / 2.;
        let u = Array1::ones(d);
        let sigma: Array2<f64> =
            Array2::from_elem((25, 25), -0.07692307692307693) + Array2::<f64>::eye(25) * 2.;
        let mut rng = rand::thread_rng();
        let (est, rel_err, upper_bound) = mv_truncnormal_cdf(l, u, sigma, 10000, 10, &mut rng);
        println!("{:?}", (est, rel_err, upper_bound));
        /* Should be close to:
        prob: 2.6853e-53
        relErr: 2.1390e-04
        upbnd: 2.8309e-53
        */
    }

    #[test]
    fn test_mv_truncnormal_rand() {
        let d = 3;
        let l = Array1::ones(d) / 2.;
        let u = Array1::ones(d);
        let sigma: Array2<f64> =
            Array2::from_elem((d, d), -0.07692307692307693) + Array2::<f64>::eye(d) * 2.;
        println!("l {}", l);
        println!("u {}", u);
        let mut rng = rand::thread_rng();
        // BUGFIX: mv_truncnormal_rand returns only the (n x d) sample matrix;
        // the previous tuple destructuring `let (samples, logp) = ...` did
        // not compile against its `Array2<f64>` return type.
        let samples = mv_truncnormal_rand(l, u, sigma, 5, 10, &mut rng);
        assert_eq!(samples.shape(), &[5, d]);
        println!("{:?}", samples);
    }

    #[bench]
    fn bench_mv_normal_cdf(b: &mut Bencher) {
        let d = 25;
        let l = Array1::ones(d) / 2.;
        let u = Array1::ones(d);
        let sigma: Array2<f64> =
            Array2::from_elem((25, 25), -0.07692307692307693) + Array2::<f64>::eye(25) * 2.;
        let mut rng = rand::thread_rng();
        b.iter(|| {
            test::black_box(mv_truncnormal_cdf(
                l.clone(),
                u.clone(),
                sigma.clone(),
                20000,
                10,
                &mut rng,
            ))
        });
    }
}