1use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use scirs2_linalg::solve;
13
14#[derive(Debug, Clone)]
20pub struct CountPanelResult<F> {
21 pub coefficients: Array1<F>,
23 pub irr: Array1<F>,
25 pub std_errors: Array1<F>,
27 pub z_stats: Array1<F>,
29 pub log_likelihood: F,
31 pub null_log_likelihood: F,
33 pub lr_stat: F,
35 pub lr_pvalue: F,
37 pub n_obs: usize,
39 pub fitted: Array1<F>,
41 pub pearson_resid: Array1<F>,
43 pub alpha: Option<F>,
45}
46
47#[inline]
53fn softplus<F: Float + FromPrimitive>(x: F) -> F {
54 let one = F::one();
55 let ex = if x > F::from_f64(20.0).unwrap_or(F::one()) {
56 x
57 } else {
58 (F::one() + x.exp()).ln()
59 };
60 ex
61}
62
63#[inline]
65fn log_sum_exp<F: Float + std::iter::Sum>(vals: &[F]) -> F {
66 if vals.is_empty() {
67 return F::zero();
68 }
69 let max = vals
70 .iter()
71 .copied()
72 .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
73 if max.is_infinite() {
74 return F::neg_infinity();
75 }
76 max + vals.iter().map(|&v| (v - max).exp()).sum::<F>().ln()
77}
78
79fn irls_step<F>(
82 x: &Array2<F>,
83 y: &Array1<F>,
84 beta: &Array1<F>,
85 offset: Option<&Array1<F>>,
86 alpha: F, ) -> StatsResult<(Array1<F>, F)>
88where
89 F: Float
90 + std::iter::Sum
91 + std::fmt::Debug
92 + std::fmt::Display
93 + scirs2_core::numeric::NumAssign
94 + scirs2_core::numeric::One
95 + scirs2_core::ndarray::ScalarOperand
96 + FromPrimitive
97 + Send
98 + Sync
99 + 'static,
100{
101 let n = y.len();
102 let k = beta.len();
103 let (nx, kx) = x.dim();
104 if nx != n || kx != k {
105 return Err(StatsError::DimensionMismatch(
106 "IRLS: x, y, beta dimension mismatch".to_string(),
107 ));
108 }
109
110 let mut eta: Array1<F> = Array1::zeros(n); for i in 0..n {
112 for j in 0..k {
113 eta[i] = eta[i] + x[[i, j]] * beta[j];
114 }
115 if let Some(off) = offset {
116 eta[i] = eta[i] + off[i];
117 }
118 }
119
120 let mut mu = Array1::zeros(n);
122 for i in 0..n {
123 mu[i] = eta[i].exp();
124 }
125
126 let one = F::one();
131 let mut s = Array1::zeros(k);
132 let mut h = Array2::<F>::zeros((k, k));
133 let mut ll = F::zero();
134
135 for i in 0..n {
136 let mu_i = mu[i];
137 let v_i = if alpha > F::zero() {
138 mu_i + alpha * mu_i * mu_i
139 } else {
140 mu_i
141 };
142 let w_i = mu_i * mu_i / v_i; let resid_i = y[i] - mu_i;
144
145 if alpha <= F::zero() {
147 if mu_i > F::zero() {
149 ll = ll + y[i] * mu_i.ln() - mu_i;
150 }
151 } else {
152 let r = one / alpha;
154 let rr = r + mu_i;
155 if rr > F::zero() && mu_i > F::zero() {
156 ll =
157 ll + lgamma(y[i] + r) - lgamma(r) + r * (r / rr).ln() + y[i] * (mu_i / rr).ln();
158 }
159 }
160
161 for j in 0..k {
162 s[j] = s[j] + x[[i, j]] * resid_i;
163 for l in 0..k {
164 h[[j, l]] = h[[j, l]] - x[[i, j]] * x[[i, l]] * w_i;
165 }
166 }
167 }
168
169 let neg_h: Array2<F> = h.mapv(|v| -v);
172 let delta = solve(&neg_h.view(), &s.view(), None)
173 .map_err(|e| StatsError::ComputationError(format!("IRLS solve: {e}")))?;
174 let beta_new: Array1<F> = beta
175 .iter()
176 .zip(delta.iter())
177 .map(|(&b, &d)| b + d)
178 .collect();
179
180 Ok((beta_new, ll))
181}
182
183fn lgamma<F: Float + FromPrimitive>(x: F) -> F {
185 if x <= F::zero() {
186 return F::zero();
187 }
188 let two = F::from_f64(2.0).unwrap_or(F::one());
190 let pi = F::from_f64(std::f64::consts::PI).unwrap_or(F::one());
191 let half = F::from_f64(0.5).unwrap_or(F::zero());
192 if x < F::one() {
193 return lgamma(x + F::one()) - x.ln();
195 }
196 half * (two * pi).ln() + (x - half) * x.ln() - x
197}
198
199fn hessian_se<F>(x: &Array2<F>, mu: &Array1<F>, alpha: F) -> StatsResult<Array1<F>>
201where
202 F: Float
203 + std::iter::Sum
204 + std::fmt::Debug
205 + std::fmt::Display
206 + scirs2_core::numeric::NumAssign
207 + scirs2_core::numeric::One
208 + scirs2_core::ndarray::ScalarOperand
209 + FromPrimitive
210 + Send
211 + Sync
212 + 'static,
213{
214 let (n, k) = x.dim();
215 let mut h = Array2::<F>::zeros((k, k));
216 for i in 0..n {
217 let mu_i = mu[i];
218 let v_i = if alpha > F::zero() {
219 mu_i + alpha * mu_i * mu_i
220 } else {
221 mu_i
222 };
223 let w_i = if v_i > F::zero() {
224 mu_i * mu_i / v_i
225 } else {
226 F::zero()
227 };
228 for j in 0..k {
229 for l in 0..k {
230 h[[j, l]] = h[[j, l]] - x[[i, j]] * x[[i, l]] * w_i;
231 }
232 }
233 }
234 let neg_h: Array2<F> = h.mapv(|v| -v);
235 let mut se = Array1::zeros(k);
236 for j in 0..k {
237 let mut ej = Array1::zeros(k);
238 ej[j] = F::one();
239 let vj = solve(&neg_h.view(), &ej.view(), None)
240 .map_err(|e| StatsError::ComputationError(format!("hessian_se solve: {e}")))?;
241 let var_j = vj[j];
242 se[j] = if var_j >= F::zero() {
243 var_j.sqrt()
244 } else {
245 F::zero()
246 };
247 }
248 Ok(se)
249}
250
251fn chi2_pvalue<F: Float + FromPrimitive>(chi2: F, df: usize) -> F {
253 if chi2 <= F::zero() {
254 return F::one();
255 }
256 let k = F::from_usize(df).unwrap_or(F::one());
257 let two = F::from_f64(2.0).unwrap_or(F::one());
258 let nine = F::from_f64(9.0).unwrap_or(F::one());
259 let factor = two / (nine * k);
260 let x_wh = (chi2 / k).cbrt();
261 let mu = F::one() - factor;
262 let sigma = factor.sqrt();
263 let z = (x_wh - mu) / sigma;
264 p_normal_upper(z)
265}
266
267fn p_normal_upper<F: Float + FromPrimitive>(z: F) -> F {
268 let p1 = F::from_f64(0.2316419).unwrap_or(F::zero());
269 let b1 = F::from_f64(0.319381530).unwrap_or(F::zero());
270 let b2 = F::from_f64(-0.356563782).unwrap_or(F::zero());
271 let b3 = F::from_f64(1.781477937).unwrap_or(F::zero());
272 let b4 = F::from_f64(-1.821255978).unwrap_or(F::zero());
273 let b5 = F::from_f64(1.330274429).unwrap_or(F::zero());
274 let sqrt2pi_inv = F::from_f64(0.39894228).unwrap_or(F::zero());
275 let two = F::from_f64(2.0).unwrap_or(F::one());
276
277 let abs_z = if z < F::zero() { -z } else { z };
278 let t = F::one() / (F::one() + p1 * abs_z);
279 let poly = t * (b1 + t * (b2 + t * (b3 + t * (b4 + t * b5))));
280 let phi = sqrt2pi_inv * (-(abs_z * abs_z) / two).exp();
281 let p_upper = (phi * poly).max(F::zero()).min(F::one());
282 if z >= F::zero() {
283 p_upper
284 } else {
285 F::one() - p_upper
286 }
287}
288
289pub struct PoissonFE;
301
302impl PoissonFE {
303 pub fn fit<F>(
312 x: &ArrayView2<F>,
313 y: &ArrayView1<F>,
314 entity: &[usize],
315 max_iter: usize,
316 tol: F,
317 ) -> StatsResult<CountPanelResult<F>>
318 where
319 F: Float
320 + std::iter::Sum
321 + std::fmt::Debug
322 + std::fmt::Display
323 + scirs2_core::numeric::NumAssign
324 + scirs2_core::numeric::One
325 + scirs2_core::ndarray::ScalarOperand
326 + FromPrimitive
327 + Send
328 + Sync
329 + 'static,
330 {
331 let n = y.len();
332 let (nx, k) = x.dim();
333 if nx != n || entity.len() != n {
334 return Err(StatsError::DimensionMismatch(
335 "x, y, entity lengths must match".to_string(),
336 ));
337 }
338 for i in 0..n {
340 if y[i] < F::zero() {
341 return Err(StatsError::InvalidArgument(format!(
342 "PoissonFE: y[{}] = {} is negative",
343 i, y[i]
344 )));
345 }
346 }
347
348 let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
349
350 let mut y_sum = vec![F::zero(); n_entities];
353 for (i, &eid) in entity.iter().enumerate() {
354 y_sum[eid] = y_sum[eid] + y[i];
355 }
356 let offset: Array1<F> = entity
357 .iter()
358 .map(|&eid| {
359 if y_sum[eid] > F::zero() {
360 y_sum[eid].ln()
361 } else {
362 F::zero()
363 }
364 })
365 .collect();
366
367 let x_owned = x.to_owned();
369 let y_owned = y.to_owned();
370 let mut beta = Array1::zeros(k);
371 let mut ll_prev = F::neg_infinity();
372
373 for _iter in 0..max_iter {
374 let (new_beta, ll) = irls_step(&x_owned, &y_owned, &beta, Some(&offset), F::zero())?;
375 let delta = new_beta
376 .iter()
377 .zip(beta.iter())
378 .map(|(&a, &b)| (a - b) * (a - b))
379 .sum::<F>()
380 .sqrt();
381 beta = new_beta;
382 if (ll - ll_prev).abs() < tol {
383 break;
384 }
385 ll_prev = ll;
386 }
387
388 let mut eta: Array1<F> = Array1::zeros(n);
390 for i in 0..n {
391 for j in 0..k {
392 eta[i] = eta[i] + x[[i, j]] * beta[j];
393 }
394 eta[i] = eta[i] + offset[i];
395 }
396 let fitted: Array1<F> = eta.mapv(|e: F| e.exp());
397
398 let std_errors = hessian_se(&x_owned, &fitted, F::zero())?;
400 let z_stats: Array1<F> = beta
401 .iter()
402 .zip(std_errors.iter())
403 .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
404 .collect();
405 let irr: Array1<F> = beta.mapv(|b| b.exp());
406
407 let ll_full: F = (0..n)
409 .map(|i| {
410 if fitted[i] > F::zero() {
411 y[i] * fitted[i].ln() - fitted[i]
412 } else {
413 F::zero()
414 }
415 })
416 .sum();
417
418 let ll_null: F = {
420 let mut ll_n = F::zero();
421 for (i, &eid) in entity.iter().enumerate() {
422 let y_cnt =
423 F::from_usize(entity.iter().filter(|&&e| e == eid).count()).unwrap_or(F::one());
424 let lambda = y_sum[eid] / y_cnt;
425 if lambda > F::zero() {
426 ll_n = ll_n + y[i] * lambda.ln() - lambda;
427 }
428 }
429 ll_n
430 };
431 let two = F::from_f64(2.0).unwrap_or(F::one());
432 let lr_stat = two * (ll_full - ll_null);
433 let lr_pvalue = chi2_pvalue(lr_stat, k);
434
435 let pearson_resid: Array1<F> = (0..n)
436 .map(|i| {
437 let denom = fitted[i].sqrt();
438 if denom > F::zero() {
439 (y[i] - fitted[i]) / denom
440 } else {
441 F::zero()
442 }
443 })
444 .collect();
445
446 Ok(CountPanelResult {
447 coefficients: beta,
448 irr,
449 std_errors,
450 z_stats,
451 log_likelihood: ll_full,
452 null_log_likelihood: ll_null,
453 lr_stat,
454 lr_pvalue,
455 n_obs: n,
456 fitted,
457 pearson_resid,
458 alpha: None,
459 })
460 }
461}
462
463pub struct NegBinomFE;
473
474impl NegBinomFE {
475 pub fn fit<F>(
484 x: &ArrayView2<F>,
485 y: &ArrayView1<F>,
486 entity: &[usize],
487 max_iter: usize,
488 tol: F,
489 ) -> StatsResult<CountPanelResult<F>>
490 where
491 F: Float
492 + std::iter::Sum
493 + std::fmt::Debug
494 + std::fmt::Display
495 + scirs2_core::numeric::NumAssign
496 + scirs2_core::numeric::One
497 + scirs2_core::ndarray::ScalarOperand
498 + FromPrimitive
499 + Send
500 + Sync
501 + 'static,
502 {
503 let n = y.len();
504 let (nx, k) = x.dim();
505 if nx != n || entity.len() != n {
506 return Err(StatsError::DimensionMismatch(
507 "x, y, entity lengths must match".to_string(),
508 ));
509 }
510 for i in 0..n {
511 if y[i] < F::zero() {
512 return Err(StatsError::InvalidArgument(format!(
513 "NegBinomFE: y[{}] = {} is negative",
514 i, y[i]
515 )));
516 }
517 }
518
519 let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
520 let mut y_sum = vec![F::zero(); n_entities];
521 for (i, &eid) in entity.iter().enumerate() {
522 y_sum[eid] = y_sum[eid] + y[i];
523 }
524 let offset: Array1<F> = entity
525 .iter()
526 .map(|&eid| {
527 if y_sum[eid] > F::zero() {
528 y_sum[eid].ln()
529 } else {
530 F::zero()
531 }
532 })
533 .collect();
534
535 let x_owned = x.to_owned();
536 let y_owned = y.to_owned();
537
538 let mut beta = Array1::zeros(k);
540 let mut ll_prev = F::neg_infinity();
541 for _iter in 0..max_iter {
542 let (new_beta, ll) = irls_step(&x_owned, &y_owned, &beta, Some(&offset), F::zero())?;
543 let delta = new_beta
544 .iter()
545 .zip(beta.iter())
546 .map(|(&a, &b)| (a - b).abs())
547 .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
548 beta = new_beta;
549 if (ll - ll_prev).abs() < tol {
550 break;
551 }
552 ll_prev = ll;
553 }
554
555 let mut eta: Array1<F> = Array1::zeros(n);
557 for i in 0..n {
558 for j in 0..k {
559 eta[i] = eta[i] + x[[i, j]] * beta[j];
560 }
561 eta[i] = eta[i] + offset[i];
562 }
563 let mu_pois: Array1<F> = eta.mapv(|e: F| e.exp());
564 let pearson_chi2: F = (0..n)
565 .map(|i| {
566 let diff = y[i] - mu_pois[i];
567 if mu_pois[i] > F::zero() {
568 diff * diff / mu_pois[i]
569 } else {
570 F::zero()
571 }
572 })
573 .sum();
574 let df = if n > k { n - k } else { 1 };
575 let df_f = F::from_usize(df).unwrap_or(F::one());
576 let n_f = F::from_usize(n).unwrap_or(F::one());
577 let mean_mu = mu_pois.iter().copied().sum::<F>() / n_f;
579 let disp = pearson_chi2 / df_f;
580 let alpha_init = if disp > F::one() && mean_mu > F::zero() {
581 (disp - F::one()) / mean_mu
582 } else {
583 F::from_f64(1e-4).unwrap_or(F::zero())
584 };
585
586 let mut alpha = alpha_init;
588 ll_prev = F::neg_infinity();
589 for _iter in 0..max_iter {
590 let (new_beta, ll) = irls_step(&x_owned, &y_owned, &beta, Some(&offset), alpha)?;
591 let delta_b = new_beta
592 .iter()
593 .zip(beta.iter())
594 .map(|(&a, &b)| (a - b).abs())
595 .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
596 beta = new_beta;
597
598 let mut eta2: Array1<F> = Array1::zeros(n);
600 for i in 0..n {
601 for j in 0..k {
602 eta2[i] = eta2[i] + x[[i, j]] * beta[j];
603 }
604 eta2[i] = eta2[i] + offset[i];
605 }
606 let mu2: Array1<F> = eta2.mapv(|e: F| e.exp());
607 let pc: F = (0..n)
608 .map(|i| {
609 let diff = y[i] - mu2[i];
610 if mu2[i] > F::zero() {
611 diff * diff / mu2[i] - F::one()
612 } else {
613 F::zero()
614 }
615 })
616 .sum();
617 let denom_a: F = mu2.iter().map(|&m| m * m).sum::<F>();
618 let new_alpha = if denom_a > F::zero() {
619 let a = pc / denom_a;
620 if a > F::zero() {
621 a
622 } else {
623 F::from_f64(1e-10).unwrap_or(F::zero())
624 }
625 } else {
626 alpha
627 };
628 let delta_a = (new_alpha - alpha).abs();
629 alpha = new_alpha;
630
631 if (ll - ll_prev).abs() < tol && delta_a < tol {
632 break;
633 }
634 ll_prev = ll;
635 }
636
637 let mut eta_f: Array1<F> = Array1::zeros(n);
639 for i in 0..n {
640 for j in 0..k {
641 eta_f[i] = eta_f[i] + x[[i, j]] * beta[j];
642 }
643 eta_f[i] = eta_f[i] + offset[i];
644 }
645 let fitted: Array1<F> = eta_f.mapv(|e: F| e.exp());
646 let std_errors = hessian_se(&x_owned, &fitted, alpha)?;
647 let z_stats: Array1<F> = beta
648 .iter()
649 .zip(std_errors.iter())
650 .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
651 .collect();
652 let irr: Array1<F> = beta.mapv(|b| b.exp());
653
654 let one = F::one();
655 let ll_full: F = (0..n)
656 .map(|i| {
657 let r = one / alpha;
658 let rr = r + fitted[i];
659 if rr > F::zero() && fitted[i] > F::zero() {
660 lgamma(y[i] + r) - lgamma(r) + r * (r / rr).ln() + y[i] * (fitted[i] / rr).ln()
661 } else {
662 F::zero()
663 }
664 })
665 .sum();
666 let ll_null: F = {
667 let mut ll_n = F::zero();
668 for (i, &eid) in entity.iter().enumerate() {
669 let y_cnt =
670 F::from_usize(entity.iter().filter(|&&e| e == eid).count()).unwrap_or(F::one());
671 let lam = y_sum[eid] / y_cnt;
672 if lam > F::zero() {
673 ll_n = ll_n + y[i] * lam.ln() - lam;
674 }
675 }
676 ll_n
677 };
678 let two = F::from_f64(2.0).unwrap_or(F::one());
679 let lr_stat = two * (ll_full - ll_null);
680 let lr_pvalue = chi2_pvalue(lr_stat, k);
681 let pearson_resid: Array1<F> = (0..n)
682 .map(|i| {
683 let v = if alpha > F::zero() {
684 fitted[i] + alpha * fitted[i] * fitted[i]
685 } else {
686 fitted[i]
687 };
688 if v > F::zero() {
689 (y[i] - fitted[i]) / v.sqrt()
690 } else {
691 F::zero()
692 }
693 })
694 .collect();
695
696 Ok(CountPanelResult {
697 coefficients: beta,
698 irr,
699 std_errors,
700 z_stats,
701 log_likelihood: ll_full,
702 null_log_likelihood: ll_null,
703 lr_stat,
704 lr_pvalue,
705 n_obs: n,
706 fitted,
707 pearson_resid,
708 alpha: Some(alpha),
709 })
710 }
711}
712
/// Distribution assumed for the count component of a zero-inflated model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CountDistribution {
    /// Poisson counts (variance equal to the mean).
    Poisson,
    /// Negative-binomial counts with an estimated dispersion parameter.
    NegativeBinomial,
}
723
724pub struct ZeroInflated;
730
731impl ZeroInflated {
732 pub fn fit<F>(
742 x: &ArrayView2<F>,
743 z: &ArrayView2<F>,
744 y: &ArrayView1<F>,
745 dist: CountDistribution,
746 max_iter: usize,
747 tol: F,
748 ) -> StatsResult<ZeroInflatedResult<F>>
749 where
750 F: Float
751 + std::iter::Sum
752 + std::fmt::Debug
753 + std::fmt::Display
754 + scirs2_core::numeric::NumAssign
755 + scirs2_core::numeric::One
756 + scirs2_core::ndarray::ScalarOperand
757 + FromPrimitive
758 + Send
759 + Sync
760 + 'static,
761 {
762 let n = y.len();
763 let (nx, kx) = x.dim();
764 let (nz, kz) = z.dim();
765 if nx != n || nz != n {
766 return Err(StatsError::DimensionMismatch(
767 "x, z, y lengths must match".to_string(),
768 ));
769 }
770 for i in 0..n {
771 if y[i] < F::zero() {
772 return Err(StatsError::InvalidArgument(format!(
773 "ZeroInflated: y[{}] = {} is negative",
774 i, y[i]
775 )));
776 }
777 }
778
779 let x_owned = x.to_owned();
780 let z_owned = z.to_owned();
781 let y_owned = y.to_owned();
782
783 let mut beta_count = Array1::zeros(kx); let mut gamma_inflate = Array1::zeros(kz); let mut alpha = F::from_f64(1e-4).unwrap_or(F::zero()); let mut ll_prev = F::neg_infinity();
789
790 for _iter in 0..max_iter {
791 let mut eta_c: Array1<F> = Array1::zeros(n);
795 for i in 0..n {
796 for j in 0..kx {
797 eta_c[i] = eta_c[i] + x[[i, j]] * beta_count[j];
798 }
799 }
800 let mu: Array1<F> = eta_c.mapv(|e: F| e.exp());
801
802 let mut eta_z: Array1<F> = Array1::zeros(n);
803 for i in 0..n {
804 for j in 0..kz {
805 eta_z[i] = eta_z[i] + z[[i, j]] * gamma_inflate[j];
806 }
807 }
808 let pi: Array1<F> = eta_z.mapv(|e: F| {
809 let ex = e.exp();
810 ex / (F::one() + ex)
811 });
812
813 let p0_count: Array1<F> = (0..n)
815 .map(|i| {
816 if dist == CountDistribution::Poisson {
817 (-mu[i]).exp()
818 } else {
819 let r = F::one() / alpha;
820 let rr = r + mu[i];
821 if rr > F::zero() {
822 (r / rr).powf(r)
823 } else {
824 F::zero()
825 }
826 }
827 })
828 .collect();
829
830 let w: Array1<F> = (0..n)
832 .map(|i| {
833 if y[i] > F::zero() {
834 F::zero()
835 } else {
836 let pi_i = pi[i];
837 let denom = pi_i + (F::one() - pi_i) * p0_count[i];
838 if denom > F::zero() {
839 pi_i / denom
840 } else {
841 F::zero()
842 }
843 }
844 })
845 .collect();
846
847 let (new_gamma, _) = logistic_irls(&z_owned, &w, &gamma_inflate, 5)?;
849 gamma_inflate = new_gamma;
850
851 let yw: Array1<F> = (0..n).map(|i| (F::one() - w[i]) * y[i]).collect();
854 let (new_beta, ll_count) = irls_step(&x_owned, &yw, &beta_count, None, alpha)?;
855 beta_count = new_beta;
856
857 if dist == CountDistribution::NegativeBinomial {
859 let mut eta_new: Array1<F> = Array1::zeros(n);
860 for i in 0..n {
861 for j in 0..kx {
862 eta_new[i] = eta_new[i] + x[[i, j]] * beta_count[j];
863 }
864 }
865 let mu_new: Array1<F> = eta_new.mapv(|e: F| e.exp());
866 let pc: F = (0..n)
867 .map(|i| {
868 let wt = F::one() - w[i];
869 let diff = yw[i] - mu_new[i];
870 if mu_new[i] > F::zero() {
871 wt * (diff * diff / mu_new[i] - F::one())
872 } else {
873 F::zero()
874 }
875 })
876 .sum();
877 let denom_a: F = (0..n)
878 .map(|i| (F::one() - w[i]) * mu_new[i] * mu_new[i])
879 .sum();
880 if denom_a > F::zero() {
881 let a_new = pc / denom_a;
882 if a_new > F::zero() {
883 alpha = a_new;
884 }
885 }
886 }
887
888 let ll: F = (0..n)
890 .map(|i| {
891 let pi_i = pi[i];
892 let mu_i = mu[i];
893 if y[i] > F::zero() {
894 let log_p = if dist == CountDistribution::Poisson {
896 y[i] * mu_i.ln() - mu_i
897 } else {
898 let r = F::one() / alpha;
899 let rr = r + mu_i;
900 lgamma(y[i] + r) - lgamma(r)
901 + r * (r / rr).ln()
902 + y[i] * (mu_i / rr).ln()
903 };
904 (F::one() - pi_i).ln() + log_p
905 } else {
906 let val = pi_i + (F::one() - pi_i) * p0_count[i];
908 if val > F::zero() {
909 val.ln()
910 } else {
911 F::from_f64(-1e10).unwrap_or(F::zero())
912 }
913 }
914 })
915 .sum();
916
917 if (ll - ll_prev).abs() < tol {
918 break;
919 }
920 ll_prev = ll;
921 }
922
923 let mut eta_f: Array1<F> = Array1::zeros(n);
925 for i in 0..n {
926 for j in 0..kx {
927 eta_f[i] = eta_f[i] + x[[i, j]] * beta_count[j];
928 }
929 }
930 let mu_f: Array1<F> = eta_f.mapv(|e: F| e.exp());
931
932 let mut eta_zf: Array1<F> = Array1::zeros(n);
933 for i in 0..n {
934 for j in 0..kz {
935 eta_zf[i] = eta_zf[i] + z[[i, j]] * gamma_inflate[j];
936 }
937 }
938 let pi_f: Array1<F> = eta_zf.mapv(|e| {
939 let ex = e.exp();
940 ex / (F::one() + ex)
941 });
942 let fitted: Array1<F> = (0..n).map(|i| (F::one() - pi_f[i]) * mu_f[i]).collect();
943
944 let se_count = hessian_se(&x.to_owned(), &mu_f, alpha)?;
945 let z_stats_count: Array1<F> = beta_count
946 .iter()
947 .zip(se_count.iter())
948 .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
949 .collect();
950 let irr: Array1<F> = beta_count.mapv(|b| b.exp());
951
952 let ll_full = ll_prev;
953 let ll_null = {
954 let y_mean = y.iter().copied().sum::<F>() / F::from_usize(n).unwrap_or(F::one());
955 if y_mean > F::zero() {
956 let ln_lam = y_mean.ln();
957 (0..n).map(|i| y[i] * ln_lam - y_mean).sum::<F>()
958 } else {
959 F::zero()
960 }
961 };
962 let two = F::from_f64(2.0).unwrap_or(F::one());
963 let lr_stat = two * (ll_full - ll_null);
964 let lr_pvalue = chi2_pvalue(lr_stat, kx + kz);
965
966 let pearson_resid: Array1<F> = (0..n)
967 .map(|i| {
968 let denom = fitted[i].sqrt();
969 if denom > F::zero() {
970 (y[i] - fitted[i]) / denom
971 } else {
972 F::zero()
973 }
974 })
975 .collect();
976
977 Ok(ZeroInflatedResult {
978 count_coefficients: beta_count,
979 inflate_coefficients: gamma_inflate,
980 irr,
981 count_std_errors: se_count,
982 count_z_stats: z_stats_count,
983 log_likelihood: ll_full,
984 null_log_likelihood: ll_null,
985 lr_stat,
986 lr_pvalue,
987 n_obs: n,
988 fitted,
989 pearson_resid,
990 alpha: if dist == CountDistribution::NegativeBinomial {
991 Some(alpha)
992 } else {
993 None
994 },
995 })
996 }
997}
998
999#[derive(Debug, Clone)]
1001pub struct ZeroInflatedResult<F> {
1002 pub count_coefficients: Array1<F>,
1004 pub inflate_coefficients: Array1<F>,
1006 pub irr: Array1<F>,
1008 pub count_std_errors: Array1<F>,
1010 pub count_z_stats: Array1<F>,
1012 pub log_likelihood: F,
1014 pub null_log_likelihood: F,
1016 pub lr_stat: F,
1018 pub lr_pvalue: F,
1020 pub n_obs: usize,
1022 pub fitted: Array1<F>,
1024 pub pearson_resid: Array1<F>,
1026 pub alpha: Option<F>,
1028}
1029
1030fn logistic_irls<F>(
1035 z: &Array2<F>,
1036 w: &Array1<F>, gamma: &Array1<F>,
1038 max_iter: usize,
1039) -> StatsResult<(Array1<F>, F)>
1040where
1041 F: Float
1042 + std::iter::Sum
1043 + std::fmt::Debug
1044 + std::fmt::Display
1045 + scirs2_core::numeric::NumAssign
1046 + scirs2_core::numeric::One
1047 + scirs2_core::ndarray::ScalarOperand
1048 + FromPrimitive
1049 + Send
1050 + Sync
1051 + 'static,
1052{
1053 let n = w.len();
1054 let (nz, kz) = z.dim();
1055 if nz != n || gamma.len() != kz {
1056 return Err(StatsError::DimensionMismatch(
1057 "logistic_irls dimension mismatch".to_string(),
1058 ));
1059 }
1060 let mut g = gamma.to_owned();
1061 let mut ll = F::zero();
1062 for _iter in 0..max_iter {
1063 let mut eta: Array1<F> = Array1::zeros(n);
1065 for i in 0..n {
1066 for j in 0..kz {
1067 eta[i] = eta[i] + z[[i, j]] * g[j];
1068 }
1069 }
1070 let pi: Array1<F> = eta.mapv(|e: F| {
1071 let ex = e.exp();
1072 ex / (F::one() + ex)
1073 });
1074 let mut s: Array1<F> = Array1::zeros(kz);
1076 let mut h = Array2::<F>::zeros((kz, kz));
1077 ll = F::zero();
1078 for i in 0..n {
1079 let pi_i = pi[i];
1080 let resid = w[i] - pi_i;
1081 let w_i = pi_i * (F::one() - pi_i);
1082 for j in 0..kz {
1083 s[j] = s[j] + z[[i, j]] * resid;
1084 for l in 0..kz {
1085 h[[j, l]] = h[[j, l]] - z[[i, j]] * z[[i, l]] * w_i;
1086 }
1087 }
1088 let p_i = if pi_i > F::from_f64(1e-12).unwrap_or(F::zero()) {
1089 pi_i
1090 } else {
1091 F::from_f64(1e-12).unwrap_or(F::zero())
1092 };
1093 let one_p = F::one() - p_i;
1094 ll = ll
1095 + w[i] * p_i.ln()
1096 + (F::one() - w[i]) * one_p.max(F::from_f64(1e-12).unwrap_or(F::zero())).ln();
1097 }
1098 let neg_h: Array2<F> = h.mapv(|v| -v);
1099 let delta = solve(&neg_h.view(), &s.view(), None)
1100 .map_err(|e| StatsError::ComputationError(format!("logistic_irls solve: {e}")))?;
1101 g = g.iter().zip(delta.iter()).map(|(&b, &d)| b + d).collect();
1102 }
1103 Ok((g, ll))
1104}
1105
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a deterministic balanced panel: 10 entities observed 5 times each,
    /// one regressor cycling through 0.5, 1.0, ..., 2.5, and counts obtained by
    /// rounding `(1 + 0.5 x) * entity_effect`.
    fn make_count_panel() -> (Array2<f64>, Array1<f64>, Vec<usize>) {
        let n_ent = 10;
        let t_per = 5;
        let n = n_ent * t_per;
        let eff = [0.5, 0.8, 1.0, 1.2, 1.5, 0.6, 0.9, 1.1, 1.3, 1.6_f64];

        let entity: Vec<usize> = (0..n_ent)
            .flat_map(|e| std::iter::repeat(e).take(t_per))
            .collect();
        let x_vals: Vec<f64> = (0..n).map(|i| (i % t_per) as f64 * 0.5 + 0.5).collect();
        let y_vals: Vec<f64> = entity
            .iter()
            .zip(x_vals.iter())
            .map(|(&eid, &x_v)| {
                let lambda = (1.0 + 0.5 * x_v) * eff[eid];
                lambda.round()
            })
            .collect();

        let x = Array2::from_shape_vec((n, 1), x_vals).unwrap();
        let y = Array1::from(y_vals);
        (x, y, entity)
    }

    /// The Poisson fit succeeds and reports a finite likelihood and one positive IRR.
    #[test]
    fn test_poisson_fe_fit() {
        let (x, y, entity) = make_count_panel();
        let result =
            PoissonFE::fit(&x.view(), &y.view(), &entity, 100, 1e-8).expect("PoissonFE fit failed");
        assert!(result.log_likelihood.is_finite());
        assert_eq!(result.irr.len(), 1);
        assert!(result.irr[0] > 0.0, "IRR should be positive");
    }

    /// The negative-binomial fit succeeds and reports a non-negative dispersion.
    #[test]
    fn test_negbinom_fe_fit() {
        let (x, y, entity) = make_count_panel();
        let result = NegBinomFE::fit(&x.view(), &y.view(), &entity, 50, 1e-6)
            .expect("NegBinomFE fit failed");
        assert!(result.log_likelihood.is_finite());
        assert!(result.alpha.is_some());
        if let Some(alpha) = result.alpha {
            assert!(alpha >= 0.0, "alpha should be non-negative");
        }
    }

    /// The zero-inflated Poisson fit with an intercept-only inflation model succeeds.
    #[test]
    fn test_zero_inflated_poisson() {
        // The entity ids are not needed by the zero-inflated estimator.
        let (x_count, y, _entity) = make_count_panel();
        let z = Array2::<f64>::ones((y.len(), 1));
        let result = ZeroInflated::fit(
            &x_count.view(),
            &z.view(),
            &y.view(),
            CountDistribution::Poisson,
            50,
            1e-6,
        )
        .expect("ZIP fit failed");
        assert!(result.log_likelihood.is_finite());
        assert_eq!(result.irr.len(), 1);
    }

    /// Every incidence-rate ratio is an exponential and therefore positive.
    #[test]
    fn test_irr_positive() {
        let (x, y, entity) = make_count_panel();
        let result =
            PoissonFE::fit(&x.view(), &y.view(), &entity, 100, 1e-8).expect("PoissonFE fit failed");
        for &irr in result.irr.iter() {
            assert!(irr > 0.0, "All IRRs must be positive");
        }
    }
}