1#![allow(non_snake_case)]
2#![allow(clippy::many_single_char_names)]
3use crate::dist_util::cholperm;
9use crate::dist_util::ln_normal_pr;
10use crate::faddeeva::erfc;
11use crate::tilting::TiltingProblem;
12use crate::tilting::TiltingSolution;
13use crate::util;
14use ndarray::azip;
15use ndarray::Zip;
16use ndarray::{s, Axis};
17use ndarray::{Array1, Array2};
18use ndarray_rand::rand_distr::StandardNormal;
19use ndarray_rand::RandomExt;
20use num::traits::FloatConst;
21use rand::distributions::Uniform;
22use rand::Rng;
23use statrs::function::erf::erfc_inv;
24
25fn ntail<R: Rng + ?Sized>(
26 l: &Array1<f64>,
27 u: &Array1<f64>,
28 max_iters: usize,
29 rng: &mut R,
30) -> Array1<f64> {
31 let c = l.map(|x| x.powi(2) / 2.);
38 let f = Zip::from(&c)
39 .and(u)
40 .map_collect(|&c, &u| (c - u * u / 2.).exp_m1());
41 let mut accept_condition = |x: &Array1<f64>, accepted: &mut Array1<bool>, rng: &mut R| {
43 let test_sample: Array1<f64> = Array1::random_using(l.len(), Uniform::new(0., 1.), rng);
44 azip!((x in x, &s in &test_sample, &c in &c, acc in accepted) {
45 if s * s * x < c {
46 *acc = true;
47 }
48 });
49 };
50 let mut proposal_sampler = |rng: &mut R| {
51 let sample = Array1::random_using(l.len(), Uniform::new(0., 1.), rng);
52 &c - (1. + sample * &f).mapv(f64::ln)
53 };
54 let mut output_array =
55 util::rejection_sample(&mut accept_condition, &mut proposal_sampler, max_iters, rng);
56 output_array.mapv_inplace(|x| (2. * x).sqrt());
57 output_array
58}
59
60fn trnd<R: Rng + ?Sized>(
61 l: &Array1<f64>,
62 u: &Array1<f64>,
63 max_iters: usize,
64 rng: &mut R,
65) -> Array1<f64> {
66 let mut accept_condition = |x: &Array1<f64>, accepted: &mut Array1<bool>, _rng: &mut R| {
68 azip!((x in x, l in l, u in u, acc in accepted) {
69 if x > l && x < u {
70 *acc = true;
71 }
72 });
73 };
74 let mut proposal_sampler = |rng: &mut R| Array1::random_using(l.len(), StandardNormal, rng);
75 util::rejection_sample(&mut accept_condition, &mut proposal_sampler, max_iters, rng)
76}
77
78fn tn<R: Rng + ?Sized>(
79 l: &Array1<f64>,
80 u: &Array1<f64>,
81 max_iters: usize,
82 rng: &mut R,
83) -> Array1<f64> {
84 let tol = 2.;
93 let mut coeff = Array1::ones(l.len());
96 let gap = (u - l).map(|x| x.abs());
97 let mut tl = l.clone();
98 let mut tu = u.clone();
99 azip!((gap in &gap, coeff in &mut coeff, tl in &mut tl, tu in &mut tu) if *gap < tol {*coeff = 0.;*tl=f64::NEG_INFINITY;*tu=f64::INFINITY;});
100 let accept_reject = trnd(&tl, &tu, max_iters, rng);
101 let pl = (&tl * f64::FRAC_1_SQRT_2()).map(|x| erfc(*x) / 2.);
103 let pu = (&tu * f64::FRAC_1_SQRT_2()).map(|x| erfc(*x) / 2.);
104 let sample = Array1::random_using(l.len(), Uniform::new(0., 1.), rng);
105
106 let inverse_transform =
107 f64::SQRT_2() * (2. * (&pl - (&pl - &pu) * sample)).map(|x| erfc_inv(*x));
108 let mut result = &coeff * &accept_reject + (1. - &coeff) * &inverse_transform;
109 if result.iter().any(|x| x.is_nan()) {
110 result = coeff
111 .iter()
112 .zip(inverse_transform.iter())
113 .zip(accept_reject.iter())
114 .map(|x| if *x.0 .0 == 0. { *x.0 .1 } else { *x.1 })
115 .collect();
116 }
117 result
118}
119
120pub fn trandn<R: Rng + ?Sized>(
129 l: &Array1<f64>,
130 u: &Array1<f64>,
131 max_iters: usize,
132 rng: &mut R,
133) -> Array1<f64> {
134 let thresh = 0.66; let mut tl = l.clone();
136 let mut tu = u.clone();
137 let mut coeff = Array1::zeros(l.len());
138 azip!((tl in &mut tl, tu in &mut tu, coeff in &mut coeff) {
139 if *tl > thresh {*coeff = 1.}
140 else if *tu < -thresh {*tl = -*tu; *tu = -*tl; *coeff = -1.}
141 else {*tl = -100.; *tu = 100.; *coeff=0.;} });
143 let acc_rej_sample = ntail(&tl, &tu, max_iters, rng);
144 let trunc_norm_sample = tn(l, u, max_iters, rng);
145 &coeff * acc_rej_sample + (1. - &coeff.mapv(f64::abs)) * trunc_norm_sample
146}
147
148fn psy(
149 x: &Array1<f64>,
150 L: &Array2<f64>,
151 l: &Array1<f64>,
152 u: &Array1<f64>,
153 mu: &Array1<f64>,
154) -> f64 {
155 let mut temp = Array1::zeros(x.len() + 1);
157 temp.slice_mut(s![..x.len()]).assign(x);
158 let x = temp;
159 let mut temp = Array1::zeros(mu.len() + 1);
160 temp.slice_mut(s![..mu.len()]).assign(mu);
161 let mu = temp;
162 let c = L.dot(&x);
164 let tl = l - &mu - &c;
165 let tu = u - &mu - &c;
166 (ln_normal_pr(&tl, &tu) + 0.5 * mu.mapv(|x| x * x) - x * mu).sum()
167}
168
169fn mv_normal_pr<R: Rng + ?Sized>(
176 n: usize,
177 L: &Array2<f64>,
178 l: &Array1<f64>,
179 u: &Array1<f64>,
180 mu: &Array1<f64>,
181 max_iters: usize,
182 rng: &mut R,
183) -> (f64, f64) {
184 let d = l.shape()[0];
185 let mut p = Array1::zeros(n);
186 let mut temp = Array1::zeros(mu.shape()[0] + 1);
187 temp.slice_mut(s![..d - 1]).assign(mu);
188 let mu = temp;
189 let mut Z = Array2::zeros((d, n));
190 let mut col;
191 let mut tl;
192 let mut tu;
193 for k in 0..d - 1 {
194 col = L.slice(s![k, ..k]).dot(&Z.slice(s![..k, ..]));
195 tl = l[[k]] - mu[[k]] - &col;
196 tu = u[[k]] - mu[[k]] - col;
197 Z.index_axis_mut(Axis(0), k)
199 .assign(&(mu[[k]] + trandn(&tl, &tu, max_iters, rng)));
200 p = p + ln_normal_pr(&tl, &tu) + 0.5 * mu[[k]].powi(2)
202 - mu[[k]] * &Z.index_axis(Axis(0), k);
203 }
204 col = L.index_axis(Axis(0), d - 1).dot(&Z);
205 tl = l[d - 1] - &col;
206 tu = u[d - 1] - col;
207 p = p + ln_normal_pr(&tl, &tu);
208 p.mapv_inplace(f64::exp);
209 let prob = p.mean().unwrap();
210 debug_assert!(
211 !prob.is_sign_negative(),
212 "Returned invalid probability, {:?}",
213 prob
214 );
215 let rel_err = p.std(0.) / (n as f64).sqrt() / prob;
216 (prob, rel_err)
217}
218
219fn mv_truncnorm_proposal<R: Rng + ?Sized>(
220 L: &Array2<f64>,
221 l: &Array1<f64>,
222 u: &Array1<f64>,
223 mu: &Array1<f64>,
224 n: usize,
225 max_iters: usize,
226 rng: &mut R,
227) -> (Array1<f64>, Array2<f64>) {
228 let d = l.shape()[0];
235 let mut logp = Array1::zeros(n);
236 let mut temp = Array1::zeros(mu.shape()[0] + 1);
237 temp.slice_mut(s![..d - 1]).assign(mu);
238 let mu = temp;
239 let mut Z = Array2::zeros((d, n));
240 let mut col;
241 let mut tl;
242 let mut tu;
243 for k in 0..d {
244 col = L.slice(s![k, ..k]).dot(&Z.slice(s![..k, ..]));
245 tl = l[[k]] - mu[[k]] - &col;
246 tu = u[[k]] - mu[[k]] - col;
247 Z.index_axis_mut(Axis(0), k)
249 .assign(&(mu[[k]] + trandn(&tl, &tu, max_iters, rng)));
250 logp = logp + ln_normal_pr(&tl, &tu) + 0.5 * mu[[k]] * mu[[k]]
252 - mu[[k]] * &Z.index_axis(Axis(0), k);
253 }
254 (logp, Z)
255}
256
257pub fn solved_mv_truncnormal_rand<R: Rng + ?Sized>(
258 tilting_solution: &TiltingSolution,
259 mut l: Array1<f64>,
260 mut u: Array1<f64>,
261 mut sigma: Array2<f64>,
262 n: usize,
263 max_iters: usize,
264 rng: &mut R,
265) -> Array2<f64> {
266 let d = l.len();
267 let (Lfull, perm) = cholperm(&mut sigma, &mut l, &mut u);
268 let D = Lfull.diag().to_owned();
269
270 u /= &D;
271 l /= &D;
272 let L = (&Lfull / &(Array2::<f64>::zeros([D.len(), D.len()]) + &D)) - Array2::<f64>::eye(d);
273
274 let x = &tilting_solution.x;
275 let mu = &tilting_solution.mu;
276 let psi_star = psy(x, &L, &l, &u, mu); let (logp, mut Z) = mv_truncnorm_proposal(&L, &l, &u, mu, n, max_iters, rng);
278
279 let accept_condition = |logp: &Array1<f64>, accepted: &mut Array1<bool>, rng: &mut R| {
280 let test_sample: Array1<f64> = Array1::random_using(logp.len(), Uniform::new(0., 1.), rng);
281 azip!((&s in &test_sample, &logp in logp, acc in accepted) {
282 if -1. * s.ln() > (psi_star - logp) {
283 *acc = true;
284 }
285 });
286 };
287 let mut accepted: Array1<bool> = Array1::from_elem(Z.ncols(), false);
288 accept_condition(&logp, &mut accepted, rng);
289 let mut i = 0;
290 while !accepted.fold(true, |a, b| a && *b) {
291 let (logp, sample) = mv_truncnorm_proposal(&L, &l, &u, mu, n, max_iters, rng);
292 Zip::from(Z.axis_iter_mut(Axis(1)))
293 .and(sample.axis_iter(Axis(1)))
294 .and(&accepted)
295 .for_each(|mut z, s, &acc| {
296 if !acc {
297 z.assign(&s);
298 }
299 });
300 accept_condition(&logp, &mut accepted, rng);
301 i += 1;
302 if i > max_iters {
303 break;
305 }
306 }
307 let mut unperm = perm.into_iter().zip(0..d).collect::<Vec<(usize, usize)>>();
309 unperm.sort_by(|a, b| a.0.cmp(&b.0));
310 let order: Vec<usize> = unperm.iter().map(|x| x.1).collect();
311
312 let mut rv = Lfull.dot(&Z);
314 let unperm_rv = rv.clone();
315 for (i, &ord) in order.iter().enumerate() {
316 rv.row_mut(i).assign(&unperm_rv.row(ord));
317 }
318 rv.reversed_axes()
319}
320
321pub fn mv_truncnormal_rand<R: Rng + ?Sized>(
323 mut l: Array1<f64>,
324 mut u: Array1<f64>,
325 mut sigma: Array2<f64>,
326 n: usize,
327 max_iters: usize,
328 rng: &mut R,
329) -> Array2<f64> {
330 let d = l.len();
331 let (Lfull, perm) = cholperm(&mut sigma, &mut l, &mut u);
332 let D = Lfull.diag().to_owned();
333
334 u /= &D;
335 l /= &D;
336 let L = (&Lfull / &(Array2::<f64>::zeros([D.len(), D.len()]) + &D)) - Array2::<f64>::eye(d);
337
338 let problem = TiltingProblem::new(l.clone(), u.clone(), sigma);
340 let result = problem.solve_optimial_tilting();
341 let x = result.x.slice(s![..d - 1]).to_owned();
343 let mu = result.x.slice(s![d - 1..(2 * (d - 1))]).to_owned();
344 let psi_star = psy(&x, &L, &l, &u, &mu); let (logp, mut Z) = mv_truncnorm_proposal(&L, &l, &u, &mu, n, max_iters, rng);
346
347 let accept_condition = |logp: &Array1<f64>, accepted: &mut Array1<bool>, rng: &mut R| {
348 let test_sample: Array1<f64> = Array1::random_using(logp.len(), Uniform::new(0., 1.), rng);
349 azip!((&s in &test_sample, &logp in logp, acc in accepted) {
350 if -1. * s.ln() > (psi_star - logp) {
351 *acc = true;
352 }
353 });
354 };
355 let mut accepted: Array1<bool> = Array1::from_elem(Z.ncols(), false);
356 accept_condition(&logp, &mut accepted, rng);
357 let mut i = 0;
358 while !accepted.fold(true, |a, b| a && *b) {
359 let (logp, sample) = mv_truncnorm_proposal(&L, &l, &u, &mu, n, max_iters, rng);
360 Zip::from(Z.axis_iter_mut(Axis(1)))
361 .and(sample.axis_iter(Axis(1)))
362 .and(&accepted)
363 .for_each(|mut z, s, &acc| {
364 if !acc {
365 z.assign(&s);
366 }
367 });
368 accept_condition(&logp, &mut accepted, rng);
369 i += 1;
370 if i > max_iters {
371 break;
372 }
373 }
374 let mut unperm = perm.into_iter().zip(0..d).collect::<Vec<(usize, usize)>>();
376 unperm.sort_by(|a, b| a.0.cmp(&b.0));
377 let order: Vec<usize> = unperm.iter().map(|x| x.1).collect();
378
379 let mut rv = Lfull.dot(&Z);
381 let unperm_rv = rv.clone();
382 for (i, &ord) in order.iter().enumerate() {
383 rv.row_mut(i).assign(&unperm_rv.row(ord));
384 }
385 rv.reversed_axes()
386}
387
388pub fn solved_mv_truncnormal_cdf<R: Rng + ?Sized>(
389 tilting_solution: &TiltingSolution,
390 n: usize,
391 max_iters: usize,
392 rng: &mut R,
393) -> (f64, f64, f64) {
394 let (est, rel_err) = mv_normal_pr(
396 n,
397 &tilting_solution.lower_tri,
398 &tilting_solution.lower,
399 &tilting_solution.upper,
400 &tilting_solution.mu,
401 max_iters,
402 rng,
403 );
404 let log_upbnd = psy(
406 &tilting_solution.x,
407 &tilting_solution.lower_tri,
408 &tilting_solution.lower,
409 &tilting_solution.upper,
410 &tilting_solution.mu,
411 );
412 let upbnd = log_upbnd.exp();
420 (est, rel_err, upbnd)
421}
422
423pub fn mv_truncnormal_cdf<R: Rng + ?Sized>(
425 l: Array1<f64>,
426 u: Array1<f64>,
427 sigma: Array2<f64>,
428 n: usize,
429 max_iters: usize,
430 rng: &mut R,
431) -> (f64, f64, f64) {
432 let tilting_solution = TiltingProblem::new(l, u, sigma).solve_optimial_tilting();
433 let (est, rel_err) = mv_normal_pr(
435 n,
436 &tilting_solution.lower_tri,
437 &tilting_solution.lower,
438 &tilting_solution.upper,
439 &tilting_solution.mu,
440 max_iters,
441 rng,
442 );
443 let log_upbnd = psy(
445 &tilting_solution.x,
446 &tilting_solution.lower_tri,
447 &tilting_solution.lower,
448 &tilting_solution.upper,
449 &tilting_solution.mu,
450 );
451 let upbnd = log_upbnd.exp();
459 (est, rel_err, upbnd)
460}
461
#[cfg(test)]
mod tests {
    extern crate ndarray;
    extern crate test;
    use super::*;
    use ndarray::{arr1, arr2};
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::rand_distr::Uniform;
    use test::Bencher;

    #[test]
    fn manual_rand_scale_test() {
        let l = arr1(&[f64::NEG_INFINITY, f64::NEG_INFINITY]);
        let u = arr1(&[-7.75, 9.11]);
        let sigma = arr2(&[[10., -10.], [-10., 11.]]);
        let mut rng = rand::thread_rng();
        let n = 10;
        let max_iters = 10;
        let samples = mv_truncnormal_rand(l, u, sigma, n, max_iters, &mut rng);
        // One row per sample, one column per dimension.
        assert_eq!(samples.dim(), (n, 2));
    }

    #[bench]
    fn bench_ln_normal_pr(bench: &mut Bencher) {
        let normal = Normal::new(0., 1.).unwrap();
        let uniform = Uniform::new(0., 2.);
        let a = Array2::random((1000, 1), normal);
        let b = Array2::random((1000, 1), uniform);
        let c = a.clone() + b;
        bench.iter(|| test::black_box(ln_normal_pr(&a, &c)));
    }

    #[bench]
    fn bench_cholperm(b: &mut Bencher) {
        let n = 10;
        let mut sigma = Array2::eye(n);
        sigma.row_mut(2).map_inplace(|x| *x += 0.01);
        sigma.column_mut(2).map_inplace(|x| *x += 0.01);
        let uniform = Uniform::new(0., 1.);
        let mut l = Array1::random(n, uniform);
        let mut u = 2. * &l;
        b.iter(|| test::black_box(cholperm(&mut sigma, &mut l, &mut u)));
    }

    #[test]
    fn test_mv_normal_cdf() {
        let d = 25;
        let l = Array1::ones(d) / 2.;
        let u = Array1::ones(d);
        let sigma: Array2<f64> =
            Array2::from_elem((25, 25), -0.07692307692307693) + Array2::<f64>::eye(25) * 2.;
        let mut rng = rand::thread_rng();
        let (est, rel_err, upper_bound) = mv_truncnormal_cdf(l, u, sigma, 10000, 10, &mut rng);
        println!("{:?}", (est, rel_err, upper_bound));
        // A mean of exponentials cannot be negative; NaN also fails these checks.
        assert!(est >= 0., "invalid probability estimate {}", est);
        assert!(upper_bound >= 0., "invalid upper bound {}", upper_bound);
    }

    #[test]
    fn test_mv_truncnormal_rand() {
        let d = 3;
        let l = Array1::ones(d) / 2.;
        let u = Array1::ones(d);
        let sigma: Array2<f64> =
            Array2::from_elem((d, d), -0.07692307692307693) + Array2::<f64>::eye(d) * 2.;
        println!("l {}", l);
        println!("u {}", u);
        let mut rng = rand::thread_rng();
        // `mv_truncnormal_rand` returns a single `n × d` array, not a tuple.
        let samples = mv_truncnormal_rand(l, u, sigma, 5, 10, &mut rng);
        println!("{:?}", samples);
        assert_eq!(samples.dim(), (5, d));
    }

    #[bench]
    fn bench_mv_normal_cdf(b: &mut Bencher) {
        let d = 25;
        let l = Array1::ones(d) / 2.;
        let u = Array1::ones(d);
        let sigma: Array2<f64> =
            Array2::from_elem((25, 25), -0.07692307692307693) + Array2::<f64>::eye(25) * 2.;
        let mut rng = rand::thread_rng();
        b.iter(|| {
            test::black_box(mv_truncnormal_cdf(
                l.clone(),
                u.clone(),
                sigma.clone(),
                20000,
                10,
                &mut rng,
            ))
        });
    }
}