1use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::{Distribution, Gamma, Normal};
9use scirs2_core::validation::*;
10use scirs2_core::Rng;
11use statrs::statistics::Statistics;
12
/// Two-level (hierarchical) linear model fitted by Gibbs sampling.
///
/// Each group `g` owns a coefficient vector of length `n_random_effects`, where
/// `n_random_effects = n_level1_predictors + 1` when `random_slopes` is set and `1` otherwise
/// (an intercept, plus one slope per level-1 predictor). During fitting the group coefficients
/// are given the prior `N(fixed_effects · x_level2[g], random_effects_cov)` and observations the
/// likelihood `N(z · coefficients_g, residual_variance)`.
#[derive(Debug, Clone)]
pub struct HierarchicalLinearModel {
    /// Level-2 regression coefficients, shape `(n_random_effects, n_level2_predictors + 1)`;
    /// row `i` maps a group's level-2 covariates to the prior mean of its `i`-th coefficient.
    pub fixed_effects: Array2<f64>,
    /// Covariance of the group coefficients around their level-2 mean
    /// (`n_random_effects x n_random_effects`; identity until fitted).
    pub random_effects_cov: Array2<f64>,
    /// Observation-level noise variance (1.0 until fitted).
    pub residual_variance: f64,
    /// Group label of every observation from the last `fit_mcmc` call (empty until fitted).
    pub groups: Array1<usize>,
    /// Number of groups.
    pub n_groups: usize,
    /// Number of level-1 (observation-level) predictors.
    pub n_level1_predictors: usize,
    /// Number of level-2 (group-level) predictors; `fixed_effects` has one more column than this.
    pub n_level2_predictors: usize,
    /// Whether each level-1 predictor gets a group-specific slope.
    pub random_slopes: bool,
}
38
39impl HierarchicalLinearModel {
40 pub fn new(
42 n_groups: usize,
43 n_level1_predictors: usize,
44 n_level2_predictors: usize,
45 random_slopes: bool,
46 ) -> Result<Self> {
47 check_positive(n_groups, "n_groups")?;
48 check_positive(n_level1_predictors, "n_level1_predictors")?;
49
50 let n_random_effects = if random_slopes {
51 n_level1_predictors + 1
52 } else {
53 1
54 };
55 let fixed_effects = Array2::zeros((n_random_effects, n_level2_predictors + 1));
56 let random_effects_cov = Array2::eye(n_random_effects);
57
58 Ok(Self {
59 fixed_effects,
60 random_effects_cov,
61 residual_variance: 1.0,
62 groups: Array1::zeros(0),
63 n_groups,
64 n_level1_predictors,
65 n_level2_predictors,
66 random_slopes,
67 })
68 }
69
70 pub fn fit_mcmc<R: Rng + ?Sized>(
72 &mut self,
73 y: ArrayView1<f64>,
74 x_level1: ArrayView2<f64>,
75 x_level2: ArrayView2<f64>,
76 groups: ArrayView1<usize>,
77 n_iter: usize,
78 burnin: usize,
79 rng: &mut R,
80 ) -> Result<HierarchicalModelResults> {
81 checkarray_finite(&y, "y")?;
82 checkarray_finite(&x_level1, "x_level1")?;
83 checkarray_finite(&x_level2, "x_level2")?;
84 check_positive(n_iter, "n_iter")?;
85
86 let n_obs = y.len();
87 if x_level1.nrows() != n_obs {
88 return Err(StatsError::DimensionMismatch(format!(
89 "x_level1 rows ({}) must match y length ({})",
90 x_level1.nrows(),
91 n_obs
92 )));
93 }
94
95 if groups.len() != n_obs {
96 return Err(StatsError::DimensionMismatch(format!(
97 "groups length ({}) must match y length ({})",
98 groups.len(),
99 n_obs
100 )));
101 }
102
103 self.groups = groups.to_owned();
104
105 let n_random_effects = if self.random_slopes {
107 self.n_level1_predictors + 1
108 } else {
109 1
110 };
111 let n_fixed = (self.n_level2_predictors + 1) * n_random_effects;
112
113 let mut fixed_effects_samples = Array2::zeros((n_iter - burnin, n_fixed));
114 let mut random_effects_samples =
115 Array2::zeros((n_iter - burnin, self.n_groups * n_random_effects));
116 let mut variance_samples = Array1::zeros(n_iter - burnin);
117 let mut tau_samples = Array2::zeros((n_iter - burnin, n_random_effects * n_random_effects));
118
119 let mut random_effects = Array2::zeros((self.n_groups, n_random_effects));
121
122 for _iter in 0..n_iter {
124 self.update_random_effects(&y, &x_level1, &x_level2, &mut random_effects, rng)?;
126
127 self.update_fixed_effects(&random_effects, &x_level2, rng)?;
129
130 self.update_residual_variance(&y, &x_level1, &random_effects, rng)?;
132
133 self.update_random_effects_covariance(&random_effects, rng)?;
135
136 if _iter >= burnin {
138 let sample_idx = _iter - burnin;
139
140 let mut fixed_flat = Array1::zeros(n_fixed);
142 let mut idx = 0;
143 for i in 0..self.fixed_effects.nrows() {
144 for j in 0..self.fixed_effects.ncols() {
145 fixed_flat[idx] = self.fixed_effects[[i, j]];
146 idx += 1;
147 }
148 }
149 fixed_effects_samples
150 .row_mut(sample_idx)
151 .assign(&fixed_flat);
152
153 let mut random_flat = Array1::zeros(self.n_groups * n_random_effects);
155 let mut idx = 0;
156 for group in 0..self.n_groups {
157 for effect in 0..n_random_effects {
158 random_flat[idx] = random_effects[[group, effect]];
159 idx += 1;
160 }
161 }
162 random_effects_samples
163 .row_mut(sample_idx)
164 .assign(&random_flat);
165
166 variance_samples[sample_idx] = self.residual_variance;
168
169 let mut tau_flat = Array1::zeros(n_random_effects * n_random_effects);
171 let mut idx = 0;
172 for i in 0..n_random_effects {
173 for j in 0..n_random_effects {
174 tau_flat[idx] = self.random_effects_cov[[i, j]];
175 idx += 1;
176 }
177 }
178 tau_samples.row_mut(sample_idx).assign(&tau_flat);
179 }
180 }
181
182 Ok(HierarchicalModelResults {
183 fixed_effects_samples,
184 random_effects_samples,
185 variance_samples,
186 tau_samples,
187 n_groups: self.n_groups,
188 n_random_effects,
189 n_iter: n_iter - burnin,
190 })
191 }
192
193 fn update_random_effects<R: scirs2_core::random::Rng + ?Sized>(
195 &self,
196 y: &ArrayView1<f64>,
197 x_level1: &ArrayView2<f64>,
198 x_level2: &ArrayView2<f64>,
199 random_effects: &mut Array2<f64>,
200 rng: &mut R,
201 ) -> Result<()> {
202 let n_random_effects = random_effects.ncols();
203
204 for group in 0..self.n_groups {
205 let group_indices: Vec<usize> = self
207 .groups
208 .iter()
209 .enumerate()
210 .filter_map(|(i, &g)| if g == group { Some(i) } else { None })
211 .collect();
212
213 if group_indices.is_empty() {
214 continue;
215 }
216
217 let n_group_obs = group_indices.len();
219 let mut y_group = Array1::zeros(n_group_obs);
220 let mut x_group = Array2::zeros((n_group_obs, self.n_level1_predictors));
221
222 for (i, &obs_idx) in group_indices.iter().enumerate() {
223 y_group[i] = y[obs_idx];
224 x_group.row_mut(i).assign(&x_level1.row(obs_idx));
225 }
226
227 let precision_prior = scirs2_linalg::inv(&self.random_effects_cov.view(), None)
229 .map_err(|e| {
230 StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
231 })?;
232
233 let mut z_group = Array2::zeros((n_group_obs, n_random_effects));
235 z_group.column_mut(0).fill(1.0); if self.random_slopes && n_random_effects > 1 {
237 for i in 1..n_random_effects {
238 z_group.column_mut(i).assign(&x_group.column(i - 1));
239 }
240 }
241
242 let zt_z = z_group.t().dot(&z_group);
243 let precision_posterior = precision_prior.clone() + zt_z / self.residual_variance;
244
245 let covariance_posterior = scirs2_linalg::inv(&precision_posterior.view(), None)
246 .map_err(|e| {
247 StatsError::ComputationError(format!(
248 "Failed to invert posterior precision: {}",
249 e
250 ))
251 })?;
252
253 let group_level2 = if group < x_level2.nrows() {
255 x_level2.row(group).to_owned()
256 } else {
257 Array1::zeros(x_level2.ncols())
258 };
259
260 let mut prior_mean = Array1::zeros(n_random_effects);
261 for i in 0..n_random_effects {
262 prior_mean[i] = self.fixed_effects.row(i).dot(&group_level2);
263 }
264
265 let data_contrib = z_group.t().dot(&y_group) / self.residual_variance;
266 let prior_contrib = precision_prior.dot(&prior_mean);
267 let posterior_mean = covariance_posterior.dot(&(data_contrib + prior_contrib));
268
269 let mvn_sample =
271 sample_multivariate_normal(&posterior_mean, &covariance_posterior, rng)?;
272 random_effects.row_mut(group).assign(&mvn_sample);
273 }
274
275 Ok(())
276 }
277
278 fn update_fixed_effects<R: Rng + ?Sized>(
280 &mut self,
281 random_effects: &Array2<f64>,
282 x_level2: &ArrayView2<f64>,
283 rng: &mut R,
284 ) -> Result<()> {
285 let n_random_effects = self.fixed_effects.nrows();
286 let n_level2_predictors = self.fixed_effects.ncols();
287
288 for i in 0..n_random_effects {
289 let y_i = random_effects.column(i);
291
292 let prior_precision = 1e-6;
294 let prior_mean = 0.0;
295
296 let tau_ii = self.random_effects_cov[[i, i]];
298 let likelihood_precision = 1.0 / tau_ii;
299
300 let xtx = x_level2.t().dot(x_level2);
302 let precision_posterior =
303 Array2::eye(n_level2_predictors) * prior_precision + xtx * likelihood_precision;
304 let covariance_posterior = scirs2_linalg::inv(&precision_posterior.view(), None)
305 .map_err(|e| {
306 StatsError::ComputationError(format!("Failed to invert precision: {}", e))
307 })?;
308
309 let xty = x_level2.t().dot(&y_i);
310 let data_contrib = xty * likelihood_precision;
311 let prior_contrib =
312 Array1::from_elem(n_level2_predictors, prior_mean * prior_precision);
313 let mean_posterior = covariance_posterior.dot(&(data_contrib + prior_contrib));
314
315 let sample = sample_multivariate_normal(&mean_posterior, &covariance_posterior, rng)?;
317 self.fixed_effects.row_mut(i).assign(&sample);
318 }
319
320 Ok(())
321 }
322
323 fn update_residual_variance<R: Rng + ?Sized>(
325 &mut self,
326 y: &ArrayView1<f64>,
327 x_level1: &ArrayView2<f64>,
328 random_effects: &Array2<f64>,
329 rng: &mut R,
330 ) -> Result<()> {
331 let n_obs = y.len();
332
333 let mut residuals_sum_sq = 0.0;
335 for (obs_idx, &group) in self.groups.iter().enumerate() {
336 let y_obs = y[obs_idx];
337 let x_obs = x_level1.row(obs_idx);
338
339 let intercept = random_effects[[group, 0]];
341 let mut y_pred = intercept;
342
343 if self.random_slopes && random_effects.ncols() > 1 {
344 for j in 0..self.n_level1_predictors {
345 y_pred += random_effects[[group, j + 1]] * x_obs[j];
346 }
347 }
348
349 let residual = y_obs - y_pred;
350 residuals_sum_sq += residual * residual;
351 }
352
353 let alpha_prior = 1e-3;
355 let beta_prior = 1e-3;
356
357 let alpha_posterior = alpha_prior + n_obs as f64 / 2.0;
359 let beta_posterior = beta_prior + residuals_sum_sq / 2.0;
360
361 let gamma_dist = Gamma::new(alpha_posterior, 1.0 / beta_posterior).map_err(|e| {
363 StatsError::ComputationError(format!("Failed to create Gamma distribution: {}", e))
364 })?;
365 let precision_sample = gamma_dist.sample(rng);
366 self.residual_variance = 1.0 / precision_sample;
367
368 Ok(())
369 }
370
371 fn update_random_effects_covariance<R: scirs2_core::random::Rng + ?Sized>(
373 &mut self,
374 random_effects: &Array2<f64>,
375 rng: &mut R,
376 ) -> Result<()> {
377 let n_random_effects = random_effects.ncols();
378 let n_groups = random_effects.nrows();
379
380 let mut sum_outer_products = Array2::<f64>::zeros((n_random_effects, n_random_effects));
382
383 for group in 0..n_groups {
384 let _effects = random_effects.row(group);
385 let outer = outer_product(&_effects.to_owned());
386 sum_outer_products = sum_outer_products + outer;
387 }
388
389 let nu_prior = n_random_effects as f64 + 2.0; let psi_prior = Array2::<f64>::eye(n_random_effects) * 0.1; let nu_posterior = nu_prior + n_groups as f64;
395 let psi_posterior = psi_prior + sum_outer_products;
396
397 let mut new_cov = Array2::<f64>::zeros((n_random_effects, n_random_effects));
400
401 for i in 0..n_random_effects {
402 let alpha = nu_posterior / 2.0;
404 let beta = psi_posterior[[i, i]] / 2.0;
405
406 let gamma_dist = Gamma::new(alpha, 1.0 / beta).map_err(|e| {
407 StatsError::ComputationError(format!("Failed to create Gamma distribution: {}", e))
408 })?;
409 let precision = gamma_dist.sample(rng);
410 new_cov[[i, i]] = 1.0 / precision;
411 }
412
413 for i in 0..n_random_effects {
415 for j in (i + 1)..n_random_effects {
416 let val1: f64 = psi_posterior[[i, i]];
417 let val2: f64 = psi_posterior[[j, j]];
418 let denom: f64 = (val1 * val2).sqrt();
419 let correlation: f64 = psi_posterior[[i, j]] / denom;
420 let covariance = correlation * (new_cov[[i, i]] * new_cov[[j, j]]).sqrt();
421 new_cov[[i, j]] = covariance * 0.1; new_cov[[j, i]] = new_cov[[i, j]];
423 }
424 }
425
426 self.random_effects_cov = new_cov;
427 Ok(())
428 }
429
430 pub fn predict(
432 &self,
433 x_level1: ArrayView2<f64>,
434 x_level2: ArrayView2<f64>,
435 groups: ArrayView1<usize>,
436 ) -> Result<Array1<f64>> {
437 checkarray_finite(&x_level1, "x_level1")?;
438 checkarray_finite(&x_level2, "x_level2")?;
439
440 let n_obs = x_level1.nrows();
441 let mut predictions = Array1::zeros(n_obs);
442
443 for (obs_idx, &group) in groups.iter().enumerate() {
444 if group >= self.n_groups {
445 return Err(StatsError::InvalidArgument(format!(
446 "Group {} exceeds number of groups {}",
447 group, self.n_groups
448 )));
449 }
450
451 let x_obs = x_level1.row(obs_idx);
452
453 let zeros_array = Array1::zeros(x_level2.ncols());
455 let group_level2 = if group < x_level2.nrows() {
456 x_level2.row(group)
457 } else {
458 zeros_array.view()
460 };
461
462 let intercept = self.fixed_effects.row(0).dot(&group_level2);
464 let mut y_pred = intercept;
465
466 if self.random_slopes && self.fixed_effects.nrows() > 1 {
468 for j in 0..self.n_level1_predictors {
469 let slope = self.fixed_effects.row(j + 1).dot(&group_level2);
470 y_pred += slope * x_obs[j];
471 }
472 }
473
474 predictions[obs_idx] = y_pred;
475 }
476
477 Ok(predictions)
478 }
479}
480
/// Post-burn-in posterior draws from `HierarchicalLinearModel::fit_mcmc`.
///
/// Every sample matrix has one row per stored draw.
#[derive(Debug, Clone)]
pub struct HierarchicalModelResults {
    /// Draws of `fixed_effects`, each flattened row by row.
    pub fixed_effects_samples: Array2<f64>,
    /// Draws of the group coefficients, flattened group by group
    /// (`n_groups * n_random_effects` columns).
    pub random_effects_samples: Array2<f64>,
    /// Residual-variance draw for each stored iteration.
    pub variance_samples: Array1<f64>,
    /// Draws of `random_effects_cov`, each flattened row by row
    /// (`n_random_effects * n_random_effects` columns).
    pub tau_samples: Array2<f64>,
    /// Number of groups in the fitted model.
    pub n_groups: usize,
    /// Number of coefficients per group.
    pub n_random_effects: usize,
    /// Number of stored draws (`n_iter - burnin` of the fit).
    pub n_iter: usize,
}
499
500impl HierarchicalModelResults {
501 pub fn fixed_effects_summary(&self) -> Result<Array2<f64>> {
503 let n_params = self.fixed_effects_samples.ncols();
504 let mut summary = Array2::zeros((n_params, 4)); for param in 0..n_params {
507 let samples = self.fixed_effects_samples.column(param);
508 let mean = samples.mean();
509 let std = samples.variance().sqrt();
510
511 let mut sorted_samples = samples.to_vec();
512 sorted_samples.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
513
514 let q025_idx = (0.025 * sorted_samples.len() as f64) as usize;
515 let q975_idx = (0.975 * sorted_samples.len() as f64) as usize;
516 let q025 = sorted_samples[q025_idx];
517 let q975 = sorted_samples[q975_idx.min(sorted_samples.len() - 1)];
518
519 summary[[param, 0]] = mean;
520 summary[[param, 1]] = std;
521 summary[[param, 2]] = q025;
522 summary[[param, 3]] = q975;
523 }
524
525 Ok(summary)
526 }
527
528 pub fn random_effects_variance_summary(&self) -> Result<Array2<f64>> {
530 let n_params = self.n_random_effects * self.n_random_effects;
531 let mut summary = Array2::zeros((n_params, 4));
532
533 for param in 0..n_params {
534 let samples = self.tau_samples.column(param);
535 let mean = samples.mean();
536 let std = samples.variance().sqrt();
537
538 let mut sorted_samples = samples.to_vec();
539 sorted_samples.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
540
541 let q025_idx = (0.025 * sorted_samples.len() as f64) as usize;
542 let q975_idx = (0.975 * sorted_samples.len() as f64) as usize;
543 let q025 = sorted_samples[q025_idx];
544 let q975 = sorted_samples[q975_idx.min(sorted_samples.len() - 1)];
545
546 summary[[param, 0]] = mean;
547 summary[[param, 1]] = std;
548 summary[[param, 2]] = q025;
549 summary[[param, 3]] = q975;
550 }
551
552 Ok(summary)
553 }
554}
555
/// One-way hierarchical (random-effects) ANOVA fitted by Gibbs sampling.
///
/// During fitting, observations in group `g` are modelled as
/// `N(group_means[g], within_variance)` and the group means as
/// `N(overall_mean, between_variance)`.
#[derive(Debug, Clone)]
pub struct HierarchicalANOVA {
    /// Current mean of each group (length `n_groups`).
    pub group_means: Array1<f64>,
    /// Current grand mean that the group means are drawn around.
    pub overall_mean: f64,
    /// Current variance of the group means around `overall_mean`.
    pub between_variance: f64,
    /// Current variance of observations around their group mean.
    pub within_variance: f64,
    /// Group label of every observation from the last `fit_mcmc` call (empty until fitted).
    pub groups: Array1<usize>,
    /// Number of groups.
    pub n_groups: usize,
}
572
573impl HierarchicalANOVA {
574 pub fn new(n_groups: usize) -> Result<Self> {
576 check_positive(n_groups, "n_groups")?;
577
578 Ok(Self {
579 group_means: Array1::zeros(n_groups),
580 overall_mean: 0.0,
581 between_variance: 1.0,
582 within_variance: 1.0,
583 groups: Array1::zeros(0),
584 n_groups,
585 })
586 }
587
588 pub fn fit_mcmc<R: Rng + ?Sized>(
590 &mut self,
591 y: ArrayView1<f64>,
592 groups: ArrayView1<usize>,
593 n_iter: usize,
594 burnin: usize,
595 rng: &mut R,
596 ) -> Result<HierarchicalANOVAResults> {
597 checkarray_finite(&y, "y")?;
598 check_positive(n_iter, "n_iter")?;
599
600 if y.len() != groups.len() {
601 return Err(StatsError::DimensionMismatch(format!(
602 "y length ({}) must match groups length ({})",
603 y.len(),
604 groups.len()
605 )));
606 }
607
608 self.groups = groups.to_owned();
609
610 let mut group_means_samples = Array2::zeros((n_iter - burnin, self.n_groups));
612 let mut overall_mean_samples_ = Array1::zeros(n_iter - burnin);
613 let mut between_var_samples = Array1::zeros(n_iter - burnin);
614 let mut within_var_samples = Array1::zeros(n_iter - burnin);
615
616 let mut group_counts = vec![0; self.n_groups];
618 let mut group_sums = vec![0.0; self.n_groups];
619
620 for (&obs_group, &obs_y) in groups.iter().zip(y.iter()) {
621 if obs_group >= self.n_groups {
622 return Err(StatsError::InvalidArgument(format!(
623 "Group {} exceeds n_groups {}",
624 obs_group, self.n_groups
625 )));
626 }
627 group_counts[obs_group] += 1;
628 group_sums[obs_group] += obs_y;
629 }
630
631 for _iter in 0..n_iter {
633 for group in 0..self.n_groups {
635 if group_counts[group] > 0 {
636 let prior_precision = 1.0 / self.between_variance;
638 let likelihood_precision = group_counts[group] as f64 / self.within_variance;
639 let posterior_precision = prior_precision + likelihood_precision;
640 let posterior_variance = 1.0 / posterior_precision;
641
642 let prior_mean_contribution = self.overall_mean * prior_precision;
643 let likelihood_mean_contribution = group_sums[group] * likelihood_precision;
644 let posterior_mean = (prior_mean_contribution + likelihood_mean_contribution)
645 / posterior_precision;
646
647 let normal =
649 Normal::new(posterior_mean, posterior_variance.sqrt()).map_err(|e| {
650 StatsError::ComputationError(format!("Failed to create normal: {}", e))
651 })?;
652 self.group_means[group] = normal.sample(rng);
653 } else {
654 let normal = Normal::new(self.overall_mean, self.between_variance.sqrt())
656 .map_err(|e| {
657 StatsError::ComputationError(format!("Failed to create normal: {}", e))
658 })?;
659 self.group_means[group] = normal.sample(rng);
660 }
661 }
662
663 let group_mean_avg = self.group_means.clone().mean();
665 let prior_variance = 10.0; let likelihood_variance = self.between_variance / self.n_groups as f64;
667 let posterior_variance = 1.0 / (1.0 / prior_variance + 1.0 / likelihood_variance);
668 let posterior_mean =
669 (0.0 / prior_variance + group_mean_avg / likelihood_variance) * posterior_variance;
670
671 let normal = Normal::new(posterior_mean, posterior_variance.sqrt()).map_err(|e| {
672 StatsError::ComputationError(format!("Failed to create normal: {}", e))
673 })?;
674 self.overall_mean = normal.sample(rng);
675
676 let sum_sq_deviations: f64 = self
678 .group_means
679 .iter()
680 .map(|&mean| (mean - self.overall_mean).powi(2))
681 .sum();
682
683 let alpha_prior = 1e-3;
684 let beta_prior = 1e-3;
685 let alpha_posterior = alpha_prior + self.n_groups as f64 / 2.0;
686 let beta_posterior = beta_prior + sum_sq_deviations / 2.0;
687
688 let gamma_dist = Gamma::new(alpha_posterior, 1.0 / beta_posterior).map_err(|e| {
689 StatsError::ComputationError(format!("Failed to create Gamma: {}", e))
690 })?;
691 let precision = gamma_dist.sample(rng);
692 self.between_variance = 1.0 / precision;
693
694 let mut within_sum_sq = 0.0;
696 let mut total_obs = 0;
697
698 for (&obs_group, &obs_y) in groups.iter().zip(y.iter()) {
699 let residual = obs_y - self.group_means[obs_group];
700 within_sum_sq += residual * residual;
701 total_obs += 1;
702 }
703
704 let alpha_posterior = alpha_prior + total_obs as f64 / 2.0;
705 let beta_posterior = beta_prior + within_sum_sq / 2.0;
706
707 let gamma_dist = Gamma::new(alpha_posterior, 1.0 / beta_posterior).map_err(|e| {
708 StatsError::ComputationError(format!("Failed to create Gamma: {}", e))
709 })?;
710 let precision = gamma_dist.sample(rng);
711 self.within_variance = 1.0 / precision;
712
713 if _iter >= burnin {
715 let sample_idx = _iter - burnin;
716 group_means_samples
717 .row_mut(sample_idx)
718 .assign(&self.group_means);
719 overall_mean_samples_[sample_idx] = self.overall_mean;
720 between_var_samples[sample_idx] = self.between_variance;
721 within_var_samples[sample_idx] = self.within_variance;
722 }
723 }
724
725 Ok(HierarchicalANOVAResults {
726 group_means_samples,
727 overall_mean_samples_,
728 between_variance_samples: between_var_samples,
729 within_variance_samples: within_var_samples,
730 n_groups: self.n_groups,
731 n_iter: n_iter - burnin,
732 })
733 }
734}
735
/// Post-burn-in posterior draws from `HierarchicalANOVA::fit_mcmc`.
///
/// Every sample array has one entry (or row) per stored draw.
#[derive(Debug, Clone)]
pub struct HierarchicalANOVAResults {
    /// Draws of the group means, one row per draw and one column per group.
    pub group_means_samples: Array2<f64>,
    /// Draws of the overall mean.
    pub overall_mean_samples_: Array1<f64>,
    /// Draws of the between-group variance.
    pub between_variance_samples: Array1<f64>,
    /// Draws of the within-group variance.
    pub within_variance_samples: Array1<f64>,
    /// Number of groups in the fitted model.
    pub n_groups: usize,
    /// Number of stored draws (`n_iter - burnin` of the fit).
    pub n_iter: usize,
}
752
753impl HierarchicalANOVAResults {
754 pub fn icc_samples(&self) -> Array1<f64> {
756 let mut icc = Array1::zeros(self.n_iter);
757 for i in 0..self.n_iter {
758 let between_var = self.between_variance_samples[i];
759 let within_var = self.within_variance_samples[i];
760 icc[i] = between_var / (between_var + within_var);
761 }
762 icc
763 }
764
765 pub fn prob_group_higher(&self, group_i: usize, group_j: usize) -> Result<f64> {
767 if group_i >= self.n_groups || group_j >= self.n_groups {
768 return Err(StatsError::InvalidArgument(
769 "Group indices out of bounds".to_string(),
770 ));
771 }
772
773 let mut count = 0;
774 for iter in 0..self.n_iter {
775 if self.group_means_samples[[iter, group_i]] > self.group_means_samples[[iter, group_j]]
776 {
777 count += 1;
778 }
779 }
780
781 Ok(count as f64 / self.n_iter as f64)
782 }
783}
784
785#[allow(dead_code)]
789fn sample_multivariate_normal<R: Rng + ?Sized>(
790 mean: &Array1<f64>,
791 covariance: &Array2<f64>,
792 rng: &mut R,
793) -> Result<Array1<f64>> {
794 let dim = mean.len();
795 let normal = Normal::new(0.0, 1.0)
796 .map_err(|e| StatsError::ComputationError(format!("Failed to create normal: {}", e)))?;
797
798 let z = Array1::from_shape_fn(dim, |_| normal.sample(rng));
800
801 let mut sample = Array1::zeros(dim);
803 for i in 0..dim {
804 sample[i] = mean[i] + z[i] * covariance[[i, i]].sqrt();
805 }
806
807 Ok(sample)
808}
809
810#[allow(dead_code)]
812fn outer_product(v: &Array1<f64>) -> Array2<f64> {
813 let n = v.len();
814 let mut result = Array2::zeros((n, n));
815 for i in 0..n {
816 for j in 0..n {
817 result[[i, j]] = v[i] * v[j];
818 }
819 }
820 result
821}