Skip to main content

scirs2_stats/panel/
count_models.rs

1//! Count Data Panel Models
2//!
3//! Implements:
4//! - `PoissonFE`: Poisson fixed effects (conditional ML)
5//! - `NegBinomFE`: Negative binomial fixed effects
6//! - `ZeroInflated`: Zero-inflated Poisson / NB models
7//! - `CountPanelResult`: IRR (incidence rate ratios), std errors, LR test
8
9use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use scirs2_linalg::solve;
13
14// ──────────────────────────────────────────────────────────────────────────────
15// CountPanelResult
16// ──────────────────────────────────────────────────────────────────────────────
17
/// Result from a count-data panel model.
///
/// Produced by [`PoissonFE::fit`] and [`NegBinomFE::fit`] in this module.
#[derive(Debug, Clone)]
pub struct CountPanelResult<F> {
    /// Coefficient estimates β (log scale)
    pub coefficients: Array1<F>,
    /// Incidence rate ratios, `exp(β)` element-wise
    pub irr: Array1<F>,
    /// Standard errors of the coefficients (log scale)
    pub std_errors: Array1<F>,
    /// z-statistics `β / se` (0 where the standard error is 0)
    pub z_stats: Array1<F>,
    /// Log-likelihood of the fitted model (the fits in this module omit the `−ln y!` constant)
    pub log_likelihood: F,
    /// Log-likelihood of the null model. For the fixed-effects fits in this module this is
    /// the entity-effects-only model (β = 0, fitted values equal to entity means), not an
    /// intercept-only model.
    pub null_log_likelihood: F,
    /// LR test statistic: 2*(LL_full - LL_null)
    pub lr_stat: F,
    /// Approximate p-value of the LR statistic against χ²(K), K = number of coefficients
    pub lr_pvalue: F,
    /// Number of observations
    pub n_obs: usize,
    /// Fitted (expected) counts
    pub fitted: Array1<F>,
    /// Pearson residuals `(y − fitted) / sqrt(V(fitted))`, with `V(μ) = μ` for Poisson fits
    /// and `V(μ) = μ + α μ²` for negative binomial fits (0 where the variance is not positive)
    pub pearson_resid: Array1<F>,
    /// Over-dispersion parameter α (only for NegBinom models)
    pub alpha: Option<F>,
}
46
47// ──────────────────────────────────────────────────────────────────────────────
48// Helpers
49// ──────────────────────────────────────────────────────────────────────────────
50
51/// Softplus to keep values positive: log(1 + exp(x)).
52#[inline]
53fn softplus<F: Float + FromPrimitive>(x: F) -> F {
54    let one = F::one();
55    let ex = if x > F::from_f64(20.0).unwrap_or(F::one()) {
56        x
57    } else {
58        (F::one() + x.exp()).ln()
59    };
60    ex
61}
62
63/// log-sum-exp trick.
64#[inline]
65fn log_sum_exp<F: Float + std::iter::Sum>(vals: &[F]) -> F {
66    if vals.is_empty() {
67        return F::zero();
68    }
69    let max = vals
70        .iter()
71        .copied()
72        .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
73    if max.is_infinite() {
74        return F::neg_infinity();
75    }
76    max + vals.iter().map(|&v| (v - max).exp()).sum::<F>().ln()
77}
78
79/// Newton-Raphson update for Poisson / NegBinom IRLS.
80/// Returns updated beta and current log-likelihood.
81fn irls_step<F>(
82    x: &Array2<F>,
83    y: &Array1<F>,
84    beta: &Array1<F>,
85    offset: Option<&Array1<F>>,
86    alpha: F, // dispersion; 0 = Poisson, >0 = NegBinom
87) -> StatsResult<(Array1<F>, F)>
88where
89    F: Float
90        + std::iter::Sum
91        + std::fmt::Debug
92        + std::fmt::Display
93        + scirs2_core::numeric::NumAssign
94        + scirs2_core::numeric::One
95        + scirs2_core::ndarray::ScalarOperand
96        + FromPrimitive
97        + Send
98        + Sync
99        + 'static,
100{
101    let n = y.len();
102    let k = beta.len();
103    let (nx, kx) = x.dim();
104    if nx != n || kx != k {
105        return Err(StatsError::DimensionMismatch(
106            "IRLS: x, y, beta dimension mismatch".to_string(),
107        ));
108    }
109
110    let mut eta: Array1<F> = Array1::zeros(n); // linear predictor
111    for i in 0..n {
112        for j in 0..k {
113            eta[i] = eta[i] + x[[i, j]] * beta[j];
114        }
115        if let Some(off) = offset {
116            eta[i] = eta[i] + off[i];
117        }
118    }
119
120    // mu = exp(eta)  (Poisson link)
121    let mut mu = Array1::zeros(n);
122    for i in 0..n {
123        mu[i] = eta[i].exp();
124    }
125
126    // Score vector s = X'(y - μ) / V(μ)
127    // Hessian H = -X' W X  where W = diag(μ / V(μ))
128    // Poisson: V(μ) = μ
129    // NegBinom(α): V(μ) = μ + α μ²
130    let one = F::one();
131    let mut s = Array1::zeros(k);
132    let mut h = Array2::<F>::zeros((k, k));
133    let mut ll = F::zero();
134
135    for i in 0..n {
136        let mu_i = mu[i];
137        let v_i = if alpha > F::zero() {
138            mu_i + alpha * mu_i * mu_i
139        } else {
140            mu_i
141        };
142        let w_i = mu_i * mu_i / v_i; // IRLS weight
143        let resid_i = y[i] - mu_i;
144
145        // log-likelihood contribution
146        if alpha <= F::zero() {
147            // Poisson: log p = y log μ - μ - log(y!)
148            if mu_i > F::zero() {
149                ll = ll + y[i] * mu_i.ln() - mu_i;
150            }
151        } else {
152            // NegBinom: log p = log Γ(y+r) - log Γ(r) - log y! + r log(r/(r+μ)) + y log(μ/(r+μ))
153            let r = one / alpha;
154            let rr = r + mu_i;
155            if rr > F::zero() && mu_i > F::zero() {
156                ll =
157                    ll + lgamma(y[i] + r) - lgamma(r) + r * (r / rr).ln() + y[i] * (mu_i / rr).ln();
158            }
159        }
160
161        for j in 0..k {
162            s[j] = s[j] + x[[i, j]] * resid_i;
163            for l in 0..k {
164                h[[j, l]] = h[[j, l]] - x[[i, j]] * x[[i, l]] * w_i;
165            }
166        }
167    }
168
169    // beta_new = beta - H^{-1} s
170    // Negate h to get positive-definite system: -H δ = s  =>  δ = (-H)^{-1} s
171    let neg_h: Array2<F> = h.mapv(|v| -v);
172    let delta = solve(&neg_h.view(), &s.view(), None)
173        .map_err(|e| StatsError::ComputationError(format!("IRLS solve: {e}")))?;
174    let beta_new: Array1<F> = beta
175        .iter()
176        .zip(delta.iter())
177        .map(|(&b, &d)| b + d)
178        .collect();
179
180    Ok((beta_new, ll))
181}
182
183/// Approximate log-gamma using Stirling's series.
184fn lgamma<F: Float + FromPrimitive>(x: F) -> F {
185    if x <= F::zero() {
186        return F::zero();
187    }
188    // Use Stirling's series: ln Γ(x) ≈ 0.5 ln(2π) + (x-0.5) ln(x) - x
189    let two = F::from_f64(2.0).unwrap_or(F::one());
190    let pi = F::from_f64(std::f64::consts::PI).unwrap_or(F::one());
191    let half = F::from_f64(0.5).unwrap_or(F::zero());
192    if x < F::one() {
193        // Use reflection: Γ(1+x) = x Γ(x)
194        return lgamma(x + F::one()) - x.ln();
195    }
196    half * (two * pi).ln() + (x - half) * x.ln() - x
197}
198
199/// Extract Hessian diagonal → standard errors.
200fn hessian_se<F>(x: &Array2<F>, mu: &Array1<F>, alpha: F) -> StatsResult<Array1<F>>
201where
202    F: Float
203        + std::iter::Sum
204        + std::fmt::Debug
205        + std::fmt::Display
206        + scirs2_core::numeric::NumAssign
207        + scirs2_core::numeric::One
208        + scirs2_core::ndarray::ScalarOperand
209        + FromPrimitive
210        + Send
211        + Sync
212        + 'static,
213{
214    let (n, k) = x.dim();
215    let mut h = Array2::<F>::zeros((k, k));
216    for i in 0..n {
217        let mu_i = mu[i];
218        let v_i = if alpha > F::zero() {
219            mu_i + alpha * mu_i * mu_i
220        } else {
221            mu_i
222        };
223        let w_i = if v_i > F::zero() {
224            mu_i * mu_i / v_i
225        } else {
226            F::zero()
227        };
228        for j in 0..k {
229            for l in 0..k {
230                h[[j, l]] = h[[j, l]] - x[[i, j]] * x[[i, l]] * w_i;
231            }
232        }
233    }
234    let neg_h: Array2<F> = h.mapv(|v| -v);
235    let mut se = Array1::zeros(k);
236    for j in 0..k {
237        let mut ej = Array1::zeros(k);
238        ej[j] = F::one();
239        let vj = solve(&neg_h.view(), &ej.view(), None)
240            .map_err(|e| StatsError::ComputationError(format!("hessian_se solve: {e}")))?;
241        let var_j = vj[j];
242        se[j] = if var_j >= F::zero() {
243            var_j.sqrt()
244        } else {
245            F::zero()
246        };
247    }
248    Ok(se)
249}
250
251/// Chi-squared upper-tail p-value.
252fn chi2_pvalue<F: Float + FromPrimitive>(chi2: F, df: usize) -> F {
253    if chi2 <= F::zero() {
254        return F::one();
255    }
256    let k = F::from_usize(df).unwrap_or(F::one());
257    let two = F::from_f64(2.0).unwrap_or(F::one());
258    let nine = F::from_f64(9.0).unwrap_or(F::one());
259    let factor = two / (nine * k);
260    let x_wh = (chi2 / k).cbrt();
261    let mu = F::one() - factor;
262    let sigma = factor.sqrt();
263    let z = (x_wh - mu) / sigma;
264    p_normal_upper(z)
265}
266
267fn p_normal_upper<F: Float + FromPrimitive>(z: F) -> F {
268    let p1 = F::from_f64(0.2316419).unwrap_or(F::zero());
269    let b1 = F::from_f64(0.319381530).unwrap_or(F::zero());
270    let b2 = F::from_f64(-0.356563782).unwrap_or(F::zero());
271    let b3 = F::from_f64(1.781477937).unwrap_or(F::zero());
272    let b4 = F::from_f64(-1.821255978).unwrap_or(F::zero());
273    let b5 = F::from_f64(1.330274429).unwrap_or(F::zero());
274    let sqrt2pi_inv = F::from_f64(0.39894228).unwrap_or(F::zero());
275    let two = F::from_f64(2.0).unwrap_or(F::one());
276
277    let abs_z = if z < F::zero() { -z } else { z };
278    let t = F::one() / (F::one() + p1 * abs_z);
279    let poly = t * (b1 + t * (b2 + t * (b3 + t * (b4 + t * b5))));
280    let phi = sqrt2pi_inv * (-(abs_z * abs_z) / two).exp();
281    let p_upper = (phi * poly).max(F::zero()).min(F::one());
282    if z >= F::zero() {
283        p_upper
284    } else {
285        F::one() - p_upper
286    }
287}
288
289// ──────────────────────────────────────────────────────────────────────────────
290// PoissonFE
291// ──────────────────────────────────────────────────────────────────────────────
292
293/// Poisson fixed-effects model (conditional ML / within-Poisson).
294///
295/// Hausman, Hall & Griliches (1984) show that the conditional ML estimator for
296/// the Poisson FE model is equivalent to the within-Poisson estimator, which
297/// is obtained by dividing each count by the entity-period total and maximising
298/// the Poisson log-likelihood on the normalised counts.  The entity fixed effects
299/// drop out of the score equations.
300pub struct PoissonFE;
301
302impl PoissonFE {
303    /// Fit a Poisson FE model via IRLS.
304    ///
305    /// # Arguments
306    /// * `x`      – (N × K) design matrix (without entity dummies or intercept)
307    /// * `y`      – count response (N), must be non-negative integers
308    /// * `entity` – entity IDs (0-indexed, length N)
309    /// * `max_iter` – maximum IRLS iterations
310    /// * `tol`      – convergence tolerance on log-likelihood
311    pub fn fit<F>(
312        x: &ArrayView2<F>,
313        y: &ArrayView1<F>,
314        entity: &[usize],
315        max_iter: usize,
316        tol: F,
317    ) -> StatsResult<CountPanelResult<F>>
318    where
319        F: Float
320            + std::iter::Sum
321            + std::fmt::Debug
322            + std::fmt::Display
323            + scirs2_core::numeric::NumAssign
324            + scirs2_core::numeric::One
325            + scirs2_core::ndarray::ScalarOperand
326            + FromPrimitive
327            + Send
328            + Sync
329            + 'static,
330    {
331        let n = y.len();
332        let (nx, k) = x.dim();
333        if nx != n || entity.len() != n {
334            return Err(StatsError::DimensionMismatch(
335                "x, y, entity lengths must match".to_string(),
336            ));
337        }
338        // Validate counts
339        for i in 0..n {
340            if y[i] < F::zero() {
341                return Err(StatsError::InvalidArgument(format!(
342                    "PoissonFE: y[{}] = {} is negative",
343                    i, y[i]
344                )));
345            }
346        }
347
348        let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
349
350        // ── Compute entity-period totals for conditional ML offset ────────────
351        // Log-offset = log(y_sum_i) per entity (condition on sum)
352        let mut y_sum = vec![F::zero(); n_entities];
353        for (i, &eid) in entity.iter().enumerate() {
354            y_sum[eid] = y_sum[eid] + y[i];
355        }
356        let offset: Array1<F> = entity
357            .iter()
358            .map(|&eid| {
359                if y_sum[eid] > F::zero() {
360                    y_sum[eid].ln()
361                } else {
362                    F::zero()
363                }
364            })
365            .collect();
366
367        // ── IRLS ──────────────────────────────────────────────────────────────
368        let x_owned = x.to_owned();
369        let y_owned = y.to_owned();
370        let mut beta = Array1::zeros(k);
371        let mut ll_prev = F::neg_infinity();
372
373        for _iter in 0..max_iter {
374            let (new_beta, ll) = irls_step(&x_owned, &y_owned, &beta, Some(&offset), F::zero())?;
375            let delta = new_beta
376                .iter()
377                .zip(beta.iter())
378                .map(|(&a, &b)| (a - b) * (a - b))
379                .sum::<F>()
380                .sqrt();
381            beta = new_beta;
382            if (ll - ll_prev).abs() < tol {
383                break;
384            }
385            ll_prev = ll;
386        }
387
388        // ── Final fitted values ───────────────────────────────────────────────
389        let mut eta: Array1<F> = Array1::zeros(n);
390        for i in 0..n {
391            for j in 0..k {
392                eta[i] = eta[i] + x[[i, j]] * beta[j];
393            }
394            eta[i] = eta[i] + offset[i];
395        }
396        let fitted: Array1<F> = eta.mapv(|e: F| e.exp());
397
398        // ── SE via observed information ───────────────────────────────────────
399        let std_errors = hessian_se(&x_owned, &fitted, F::zero())?;
400        let z_stats: Array1<F> = beta
401            .iter()
402            .zip(std_errors.iter())
403            .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
404            .collect();
405        let irr: Array1<F> = beta.mapv(|b| b.exp());
406
407        // ── Log-likelihood ────────────────────────────────────────────────────
408        let ll_full: F = (0..n)
409            .map(|i| {
410                if fitted[i] > F::zero() {
411                    y[i] * fitted[i].ln() - fitted[i]
412                } else {
413                    F::zero()
414                }
415            })
416            .sum();
417
418        // Null: λ = y̅  for each entity
419        let ll_null: F = {
420            let mut ll_n = F::zero();
421            for (i, &eid) in entity.iter().enumerate() {
422                let y_cnt =
423                    F::from_usize(entity.iter().filter(|&&e| e == eid).count()).unwrap_or(F::one());
424                let lambda = y_sum[eid] / y_cnt;
425                if lambda > F::zero() {
426                    ll_n = ll_n + y[i] * lambda.ln() - lambda;
427                }
428            }
429            ll_n
430        };
431        let two = F::from_f64(2.0).unwrap_or(F::one());
432        let lr_stat = two * (ll_full - ll_null);
433        let lr_pvalue = chi2_pvalue(lr_stat, k);
434
435        let pearson_resid: Array1<F> = (0..n)
436            .map(|i| {
437                let denom = fitted[i].sqrt();
438                if denom > F::zero() {
439                    (y[i] - fitted[i]) / denom
440                } else {
441                    F::zero()
442                }
443            })
444            .collect();
445
446        Ok(CountPanelResult {
447            coefficients: beta,
448            irr,
449            std_errors,
450            z_stats,
451            log_likelihood: ll_full,
452            null_log_likelihood: ll_null,
453            lr_stat,
454            lr_pvalue,
455            n_obs: n,
456            fitted,
457            pearson_resid,
458            alpha: None,
459        })
460    }
461}
462
463// ──────────────────────────────────────────────────────────────────────────────
464// NegBinomFE
465// ──────────────────────────────────────────────────────────────────────────────
466
467/// Negative binomial fixed-effects model.
468///
469/// Models over-dispersion via Var(y) = μ + α μ².
470/// The dispersion parameter α is estimated via the method of moments
471/// from Poisson residuals, then β is re-estimated via IRLS.
472pub struct NegBinomFE;
473
474impl NegBinomFE {
475    /// Fit a negative binomial FE model.
476    ///
477    /// # Arguments
478    /// * `x`        – (N × K) design matrix
479    /// * `y`        – count response (N)
480    /// * `entity`   – entity IDs
481    /// * `max_iter` – IRLS iterations
482    /// * `tol`      – convergence tolerance
483    pub fn fit<F>(
484        x: &ArrayView2<F>,
485        y: &ArrayView1<F>,
486        entity: &[usize],
487        max_iter: usize,
488        tol: F,
489    ) -> StatsResult<CountPanelResult<F>>
490    where
491        F: Float
492            + std::iter::Sum
493            + std::fmt::Debug
494            + std::fmt::Display
495            + scirs2_core::numeric::NumAssign
496            + scirs2_core::numeric::One
497            + scirs2_core::ndarray::ScalarOperand
498            + FromPrimitive
499            + Send
500            + Sync
501            + 'static,
502    {
503        let n = y.len();
504        let (nx, k) = x.dim();
505        if nx != n || entity.len() != n {
506            return Err(StatsError::DimensionMismatch(
507                "x, y, entity lengths must match".to_string(),
508            ));
509        }
510        for i in 0..n {
511            if y[i] < F::zero() {
512                return Err(StatsError::InvalidArgument(format!(
513                    "NegBinomFE: y[{}] = {} is negative",
514                    i, y[i]
515                )));
516            }
517        }
518
519        let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
520        let mut y_sum = vec![F::zero(); n_entities];
521        for (i, &eid) in entity.iter().enumerate() {
522            y_sum[eid] = y_sum[eid] + y[i];
523        }
524        let offset: Array1<F> = entity
525            .iter()
526            .map(|&eid| {
527                if y_sum[eid] > F::zero() {
528                    y_sum[eid].ln()
529                } else {
530                    F::zero()
531                }
532            })
533            .collect();
534
535        let x_owned = x.to_owned();
536        let y_owned = y.to_owned();
537
538        // ── Step 1: Poisson fit to initialise ────────────────────────────────
539        let mut beta = Array1::zeros(k);
540        let mut ll_prev = F::neg_infinity();
541        for _iter in 0..max_iter {
542            let (new_beta, ll) = irls_step(&x_owned, &y_owned, &beta, Some(&offset), F::zero())?;
543            let delta = new_beta
544                .iter()
545                .zip(beta.iter())
546                .map(|(&a, &b)| (a - b).abs())
547                .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
548            beta = new_beta;
549            if (ll - ll_prev).abs() < tol {
550                break;
551            }
552            ll_prev = ll;
553        }
554
555        // ── Estimate α from Pearson chi² of Poisson fit ───────────────────────
556        let mut eta: Array1<F> = Array1::zeros(n);
557        for i in 0..n {
558            for j in 0..k {
559                eta[i] = eta[i] + x[[i, j]] * beta[j];
560            }
561            eta[i] = eta[i] + offset[i];
562        }
563        let mu_pois: Array1<F> = eta.mapv(|e: F| e.exp());
564        let pearson_chi2: F = (0..n)
565            .map(|i| {
566                let diff = y[i] - mu_pois[i];
567                if mu_pois[i] > F::zero() {
568                    diff * diff / mu_pois[i]
569                } else {
570                    F::zero()
571                }
572            })
573            .sum();
574        let df = if n > k { n - k } else { 1 };
575        let df_f = F::from_usize(df).unwrap_or(F::one());
576        let n_f = F::from_usize(n).unwrap_or(F::one());
577        // α_hat = (Pearson χ² / (n - k) - 1) / mean(μ²/μ) = (χ²/(n-k) - 1) / mean(μ)
578        let mean_mu = mu_pois.iter().copied().sum::<F>() / n_f;
579        let disp = pearson_chi2 / df_f;
580        let alpha_init = if disp > F::one() && mean_mu > F::zero() {
581            (disp - F::one()) / mean_mu
582        } else {
583            F::from_f64(1e-4).unwrap_or(F::zero())
584        };
585
586        // ── Step 2: NB IRLS ───────────────────────────────────────────────────
587        let mut alpha = alpha_init;
588        ll_prev = F::neg_infinity();
589        for _iter in 0..max_iter {
590            let (new_beta, ll) = irls_step(&x_owned, &y_owned, &beta, Some(&offset), alpha)?;
591            let delta_b = new_beta
592                .iter()
593                .zip(beta.iter())
594                .map(|(&a, &b)| (a - b).abs())
595                .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
596            beta = new_beta;
597
598            // Update alpha: method of moments
599            let mut eta2: Array1<F> = Array1::zeros(n);
600            for i in 0..n {
601                for j in 0..k {
602                    eta2[i] = eta2[i] + x[[i, j]] * beta[j];
603                }
604                eta2[i] = eta2[i] + offset[i];
605            }
606            let mu2: Array1<F> = eta2.mapv(|e: F| e.exp());
607            let pc: F = (0..n)
608                .map(|i| {
609                    let diff = y[i] - mu2[i];
610                    if mu2[i] > F::zero() {
611                        diff * diff / mu2[i] - F::one()
612                    } else {
613                        F::zero()
614                    }
615                })
616                .sum();
617            let denom_a: F = mu2.iter().map(|&m| m * m).sum::<F>();
618            let new_alpha = if denom_a > F::zero() {
619                let a = pc / denom_a;
620                if a > F::zero() {
621                    a
622                } else {
623                    F::from_f64(1e-10).unwrap_or(F::zero())
624                }
625            } else {
626                alpha
627            };
628            let delta_a = (new_alpha - alpha).abs();
629            alpha = new_alpha;
630
631            if (ll - ll_prev).abs() < tol && delta_a < tol {
632                break;
633            }
634            ll_prev = ll;
635        }
636
637        // ── Final fit ─────────────────────────────────────────────────────────
638        let mut eta_f: Array1<F> = Array1::zeros(n);
639        for i in 0..n {
640            for j in 0..k {
641                eta_f[i] = eta_f[i] + x[[i, j]] * beta[j];
642            }
643            eta_f[i] = eta_f[i] + offset[i];
644        }
645        let fitted: Array1<F> = eta_f.mapv(|e: F| e.exp());
646        let std_errors = hessian_se(&x_owned, &fitted, alpha)?;
647        let z_stats: Array1<F> = beta
648            .iter()
649            .zip(std_errors.iter())
650            .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
651            .collect();
652        let irr: Array1<F> = beta.mapv(|b| b.exp());
653
654        let one = F::one();
655        let ll_full: F = (0..n)
656            .map(|i| {
657                let r = one / alpha;
658                let rr = r + fitted[i];
659                if rr > F::zero() && fitted[i] > F::zero() {
660                    lgamma(y[i] + r) - lgamma(r) + r * (r / rr).ln() + y[i] * (fitted[i] / rr).ln()
661                } else {
662                    F::zero()
663                }
664            })
665            .sum();
666        let ll_null: F = {
667            let mut ll_n = F::zero();
668            for (i, &eid) in entity.iter().enumerate() {
669                let y_cnt =
670                    F::from_usize(entity.iter().filter(|&&e| e == eid).count()).unwrap_or(F::one());
671                let lam = y_sum[eid] / y_cnt;
672                if lam > F::zero() {
673                    ll_n = ll_n + y[i] * lam.ln() - lam;
674                }
675            }
676            ll_n
677        };
678        let two = F::from_f64(2.0).unwrap_or(F::one());
679        let lr_stat = two * (ll_full - ll_null);
680        let lr_pvalue = chi2_pvalue(lr_stat, k);
681        let pearson_resid: Array1<F> = (0..n)
682            .map(|i| {
683                let v = if alpha > F::zero() {
684                    fitted[i] + alpha * fitted[i] * fitted[i]
685                } else {
686                    fitted[i]
687                };
688                if v > F::zero() {
689                    (y[i] - fitted[i]) / v.sqrt()
690                } else {
691                    F::zero()
692                }
693            })
694            .collect();
695
696        Ok(CountPanelResult {
697            coefficients: beta,
698            irr,
699            std_errors,
700            z_stats,
701            log_likelihood: ll_full,
702            null_log_likelihood: ll_null,
703            lr_stat,
704            lr_pvalue,
705            n_obs: n,
706            fitted,
707            pearson_resid,
708            alpha: Some(alpha),
709        })
710    }
711}
712
713// ──────────────────────────────────────────────────────────────────────────────
714// ZeroInflated
715// ──────────────────────────────────────────────────────────────────────────────
716
/// Which count model to use in the zero-inflated model.
///
/// Selects the distribution of the count component fitted by [`ZeroInflated::fit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CountDistribution {
    /// Poisson counts; the structural-zero probability is `P(y = 0) = exp(−μ)`.
    Poisson,
    /// NB2 counts with dispersion α; `P(y = 0) = (r / (r + μ))^r` with `r = 1/α`.
    NegativeBinomial,
}
723
/// Zero-inflated count model (ZIP or ZINB).
///
/// The model mixes a point mass at zero with a count distribution:
///
/// ```text
/// P(y = 0) = π + (1 − π) P_count(y = 0)
/// P(y = m) = (1 − π) P_count(y = m)    for counts m > 0
/// ```
///
/// where `π = logit⁻¹(z'γ)`, and the count part `P_count` is Poisson or NB2 with
/// mean `exp(x'β)` (see [`CountDistribution`]). Unlike the fixed-effects
/// estimators in this module, the fit takes no entity identifiers.
pub struct ZeroInflated;
730
731impl ZeroInflated {
732    /// Fit a zero-inflated Poisson or NegBinom model.
733    ///
734    /// # Arguments
735    /// * `x`        – count part design matrix (N × K_x)
736    /// * `z`        – inflation part design matrix (N × K_z); often just an intercept
737    /// * `y`        – response count (N)
738    /// * `dist`     – `CountDistribution::Poisson` or `CountDistribution::NegativeBinomial`
739    /// * `max_iter` – EM iterations
740    /// * `tol`      – convergence tolerance
741    pub fn fit<F>(
742        x: &ArrayView2<F>,
743        z: &ArrayView2<F>,
744        y: &ArrayView1<F>,
745        dist: CountDistribution,
746        max_iter: usize,
747        tol: F,
748    ) -> StatsResult<ZeroInflatedResult<F>>
749    where
750        F: Float
751            + std::iter::Sum
752            + std::fmt::Debug
753            + std::fmt::Display
754            + scirs2_core::numeric::NumAssign
755            + scirs2_core::numeric::One
756            + scirs2_core::ndarray::ScalarOperand
757            + FromPrimitive
758            + Send
759            + Sync
760            + 'static,
761    {
762        let n = y.len();
763        let (nx, kx) = x.dim();
764        let (nz, kz) = z.dim();
765        if nx != n || nz != n {
766            return Err(StatsError::DimensionMismatch(
767                "x, z, y lengths must match".to_string(),
768            ));
769        }
770        for i in 0..n {
771            if y[i] < F::zero() {
772                return Err(StatsError::InvalidArgument(format!(
773                    "ZeroInflated: y[{}] = {} is negative",
774                    i, y[i]
775                )));
776            }
777        }
778
779        let x_owned = x.to_owned();
780        let z_owned = z.to_owned();
781        let y_owned = y.to_owned();
782
783        // Initial estimates
784        let mut beta_count = Array1::zeros(kx); // count part
785        let mut gamma_inflate = Array1::zeros(kz); // inflation part
786        let mut alpha = F::from_f64(1e-4).unwrap_or(F::zero()); // NB dispersion
787
788        let mut ll_prev = F::neg_infinity();
789
790        for _iter in 0..max_iter {
791            // ── E-step: compute P(zero-inflate | y_i, params) ─────────────────
792            // For y_i > 0: w_i = 0 (cannot be inflated zero)
793            // For y_i = 0: w_i = π_i / (π_i + (1-π_i) * p_count(0|mu_i))
794            let mut eta_c: Array1<F> = Array1::zeros(n);
795            for i in 0..n {
796                for j in 0..kx {
797                    eta_c[i] = eta_c[i] + x[[i, j]] * beta_count[j];
798                }
799            }
800            let mu: Array1<F> = eta_c.mapv(|e: F| e.exp());
801
802            let mut eta_z: Array1<F> = Array1::zeros(n);
803            for i in 0..n {
804                for j in 0..kz {
805                    eta_z[i] = eta_z[i] + z[[i, j]] * gamma_inflate[j];
806                }
807            }
808            let pi: Array1<F> = eta_z.mapv(|e: F| {
809                let ex = e.exp();
810                ex / (F::one() + ex)
811            });
812
813            // P(y=0 | count model)
814            let p0_count: Array1<F> = (0..n)
815                .map(|i| {
816                    if dist == CountDistribution::Poisson {
817                        (-mu[i]).exp()
818                    } else {
819                        let r = F::one() / alpha;
820                        let rr = r + mu[i];
821                        if rr > F::zero() {
822                            (r / rr).powf(r)
823                        } else {
824                            F::zero()
825                        }
826                    }
827                })
828                .collect();
829
830            // Posterior weights
831            let w: Array1<F> = (0..n)
832                .map(|i| {
833                    if y[i] > F::zero() {
834                        F::zero()
835                    } else {
836                        let pi_i = pi[i];
837                        let denom = pi_i + (F::one() - pi_i) * p0_count[i];
838                        if denom > F::zero() {
839                            pi_i / denom
840                        } else {
841                            F::zero()
842                        }
843                    }
844                })
845                .collect();
846
847            // ── M-step: update gamma (logistic on w) via Newton ───────────────
848            let (new_gamma, _) = logistic_irls(&z_owned, &w, &gamma_inflate, 5)?;
849            gamma_inflate = new_gamma;
850
851            // ── M-step: update beta (Poisson/NB on (1-w) weighted) ────────────
852            // Effective y for count part: only use non-inflated obs
853            let yw: Array1<F> = (0..n).map(|i| (F::one() - w[i]) * y[i]).collect();
854            let (new_beta, ll_count) = irls_step(&x_owned, &yw, &beta_count, None, alpha)?;
855            beta_count = new_beta;
856
857            // Update alpha for NB
858            if dist == CountDistribution::NegativeBinomial {
859                let mut eta_new: Array1<F> = Array1::zeros(n);
860                for i in 0..n {
861                    for j in 0..kx {
862                        eta_new[i] = eta_new[i] + x[[i, j]] * beta_count[j];
863                    }
864                }
865                let mu_new: Array1<F> = eta_new.mapv(|e: F| e.exp());
866                let pc: F = (0..n)
867                    .map(|i| {
868                        let wt = F::one() - w[i];
869                        let diff = yw[i] - mu_new[i];
870                        if mu_new[i] > F::zero() {
871                            wt * (diff * diff / mu_new[i] - F::one())
872                        } else {
873                            F::zero()
874                        }
875                    })
876                    .sum();
877                let denom_a: F = (0..n)
878                    .map(|i| (F::one() - w[i]) * mu_new[i] * mu_new[i])
879                    .sum();
880                if denom_a > F::zero() {
881                    let a_new = pc / denom_a;
882                    if a_new > F::zero() {
883                        alpha = a_new;
884                    }
885                }
886            }
887
888            // ── Log-likelihood ──────────────────────────────────────────────────
889            let ll: F = (0..n)
890                .map(|i| {
891                    let pi_i = pi[i];
892                    let mu_i = mu[i];
893                    if y[i] > F::zero() {
894                        // log((1-π_i) * p_count(y_i))
895                        let log_p = if dist == CountDistribution::Poisson {
896                            y[i] * mu_i.ln() - mu_i
897                        } else {
898                            let r = F::one() / alpha;
899                            let rr = r + mu_i;
900                            lgamma(y[i] + r) - lgamma(r)
901                                + r * (r / rr).ln()
902                                + y[i] * (mu_i / rr).ln()
903                        };
904                        (F::one() - pi_i).ln() + log_p
905                    } else {
906                        // log(π_i + (1-π_i) * p_count(0))
907                        let val = pi_i + (F::one() - pi_i) * p0_count[i];
908                        if val > F::zero() {
909                            val.ln()
910                        } else {
911                            F::from_f64(-1e10).unwrap_or(F::zero())
912                        }
913                    }
914                })
915                .sum();
916
917            if (ll - ll_prev).abs() < tol {
918                break;
919            }
920            ll_prev = ll;
921        }
922
923        // ── Final fit ─────────────────────────────────────────────────────────
924        let mut eta_f: Array1<F> = Array1::zeros(n);
925        for i in 0..n {
926            for j in 0..kx {
927                eta_f[i] = eta_f[i] + x[[i, j]] * beta_count[j];
928            }
929        }
930        let mu_f: Array1<F> = eta_f.mapv(|e: F| e.exp());
931
932        let mut eta_zf: Array1<F> = Array1::zeros(n);
933        for i in 0..n {
934            for j in 0..kz {
935                eta_zf[i] = eta_zf[i] + z[[i, j]] * gamma_inflate[j];
936            }
937        }
938        let pi_f: Array1<F> = eta_zf.mapv(|e| {
939            let ex = e.exp();
940            ex / (F::one() + ex)
941        });
942        let fitted: Array1<F> = (0..n).map(|i| (F::one() - pi_f[i]) * mu_f[i]).collect();
943
944        let se_count = hessian_se(&x.to_owned(), &mu_f, alpha)?;
945        let z_stats_count: Array1<F> = beta_count
946            .iter()
947            .zip(se_count.iter())
948            .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
949            .collect();
950        let irr: Array1<F> = beta_count.mapv(|b| b.exp());
951
952        let ll_full = ll_prev;
953        let ll_null = {
954            let y_mean = y.iter().copied().sum::<F>() / F::from_usize(n).unwrap_or(F::one());
955            if y_mean > F::zero() {
956                let ln_lam = y_mean.ln();
957                (0..n).map(|i| y[i] * ln_lam - y_mean).sum::<F>()
958            } else {
959                F::zero()
960            }
961        };
962        let two = F::from_f64(2.0).unwrap_or(F::one());
963        let lr_stat = two * (ll_full - ll_null);
964        let lr_pvalue = chi2_pvalue(lr_stat, kx + kz);
965
966        let pearson_resid: Array1<F> = (0..n)
967            .map(|i| {
968                let denom = fitted[i].sqrt();
969                if denom > F::zero() {
970                    (y[i] - fitted[i]) / denom
971                } else {
972                    F::zero()
973                }
974            })
975            .collect();
976
977        Ok(ZeroInflatedResult {
978            count_coefficients: beta_count,
979            inflate_coefficients: gamma_inflate,
980            irr,
981            count_std_errors: se_count,
982            count_z_stats: z_stats_count,
983            log_likelihood: ll_full,
984            null_log_likelihood: ll_null,
985            lr_stat,
986            lr_pvalue,
987            n_obs: n,
988            fitted,
989            pearson_resid,
990            alpha: if dist == CountDistribution::NegativeBinomial {
991                Some(alpha)
992            } else {
993                None
994            },
995        })
996    }
997}
998
/// Result from a zero-inflated count model.
///
/// Returned by `ZeroInflated::fit`. The model mixes a count component with a log
/// link (`μ_i = exp(x_i · β)`) and a logit-link inflation component
/// (`π_i = logistic(z_i · γ)`, the probability of a structural zero), giving the
/// zero probability `π_i + (1 - π_i) · p_count(0)`.
#[derive(Debug, Clone)]
pub struct ZeroInflatedResult<F> {
    /// Count-part coefficients (log-scale): `β` in `μ_i = exp(x_i · β)`.
    pub count_coefficients: Array1<F>,
    /// Inflation-part coefficients (logit-scale): `γ` in `π_i = logistic(z_i · γ)`.
    pub inflate_coefficients: Array1<F>,
    /// IRR for the count part, i.e. `exp(count_coefficients)` element-wise.
    pub irr: Array1<F>,
    /// SE for count-part coefficients (log-scale).
    ///
    /// NOTE(review): computed via `hessian_se` from the count design, the final `μ`
    /// and `α` only; neither `z` nor `π` enters, so uncertainty in the inflation
    /// part is not propagated — confirm this approximation is intended.
    pub count_std_errors: Array1<F>,
    /// z-statistics for count coefficients: `coefficient / std_error`, or 0 where
    /// the standard error is not positive.
    pub count_z_stats: Array1<F>,
    /// Log-likelihood of the zero-inflated model as last recorded by the EM loop.
    ///
    /// The `-ln(y_i!)` constant is omitted (as it is in `null_log_likelihood`), so
    /// the LR statistic is unaffected. When the loop stops on its convergence test
    /// this is the previous iteration's value, which is within `tol` of the latest.
    pub log_likelihood: F,
    /// Null log-likelihood: an intercept-only *Poisson* model with rate `mean(y)`
    /// and no inflation component (0 if `mean(y) <= 0`), omitting `-ln(y_i!)`.
    pub null_log_likelihood: F,
    /// LR statistic: `2 * (log_likelihood - null_log_likelihood)`.
    pub lr_stat: F,
    /// LR p-value from `chi2_pvalue` with `kx + kz` degrees of freedom (number of
    /// count-part plus inflation-part regressors).
    ///
    /// NOTE(review): the null model is plain Poisson even for the negative-binomial
    /// variant, so the nesting / degrees of freedom may not match exactly — treat
    /// the p-value as approximate.
    pub lr_pvalue: F,
    /// Number of observations
    pub n_obs: usize,
    /// Fitted (expected) counts: `(1 - π_i) · μ_i`.
    pub fitted: Array1<F>,
    /// Pearson residuals: `(y_i - fitted_i) / sqrt(fitted_i)`, or 0 where
    /// `fitted_i <= 0`.
    ///
    /// NOTE(review): scaled by `sqrt(fitted)` (a Poisson-type variance), not by the
    /// zero-inflated model's variance.
    pub pearson_resid: Array1<F>,
    /// Over-dispersion α: `Some` for the negative-binomial count part, `None` for
    /// Poisson.
    pub alpha: Option<F>,
}
1029
1030// ──────────────────────────────────────────────────────────────────────────────
1031// Logistic IRLS (for inflation part)
1032// ──────────────────────────────────────────────────────────────────────────────
1033
1034fn logistic_irls<F>(
1035    z: &Array2<F>,
1036    w: &Array1<F>, // posterior P(inflate=1 | y_i=0)
1037    gamma: &Array1<F>,
1038    max_iter: usize,
1039) -> StatsResult<(Array1<F>, F)>
1040where
1041    F: Float
1042        + std::iter::Sum
1043        + std::fmt::Debug
1044        + std::fmt::Display
1045        + scirs2_core::numeric::NumAssign
1046        + scirs2_core::numeric::One
1047        + scirs2_core::ndarray::ScalarOperand
1048        + FromPrimitive
1049        + Send
1050        + Sync
1051        + 'static,
1052{
1053    let n = w.len();
1054    let (nz, kz) = z.dim();
1055    if nz != n || gamma.len() != kz {
1056        return Err(StatsError::DimensionMismatch(
1057            "logistic_irls dimension mismatch".to_string(),
1058        ));
1059    }
1060    let mut g = gamma.to_owned();
1061    let mut ll = F::zero();
1062    for _iter in 0..max_iter {
1063        // pi = logistic(z γ)
1064        let mut eta: Array1<F> = Array1::zeros(n);
1065        for i in 0..n {
1066            for j in 0..kz {
1067                eta[i] = eta[i] + z[[i, j]] * g[j];
1068            }
1069        }
1070        let pi: Array1<F> = eta.mapv(|e: F| {
1071            let ex = e.exp();
1072            ex / (F::one() + ex)
1073        });
1074        // Score: s = Z' (w - π)
1075        let mut s: Array1<F> = Array1::zeros(kz);
1076        let mut h = Array2::<F>::zeros((kz, kz));
1077        ll = F::zero();
1078        for i in 0..n {
1079            let pi_i = pi[i];
1080            let resid = w[i] - pi_i;
1081            let w_i = pi_i * (F::one() - pi_i);
1082            for j in 0..kz {
1083                s[j] = s[j] + z[[i, j]] * resid;
1084                for l in 0..kz {
1085                    h[[j, l]] = h[[j, l]] - z[[i, j]] * z[[i, l]] * w_i;
1086                }
1087            }
1088            let p_i = if pi_i > F::from_f64(1e-12).unwrap_or(F::zero()) {
1089                pi_i
1090            } else {
1091                F::from_f64(1e-12).unwrap_or(F::zero())
1092            };
1093            let one_p = F::one() - p_i;
1094            ll = ll
1095                + w[i] * p_i.ln()
1096                + (F::one() - w[i]) * one_p.max(F::from_f64(1e-12).unwrap_or(F::zero())).ln();
1097        }
1098        let neg_h: Array2<F> = h.mapv(|v| -v);
1099        let delta = solve(&neg_h.view(), &s.view(), None)
1100            .map_err(|e| StatsError::ComputationError(format!("logistic_irls solve: {e}")))?;
1101        g = g.iter().zip(delta.iter()).map(|(&b, &d)| b + d).collect();
1102    }
1103    Ok((g, ll))
1104}
1105
1106// ──────────────────────────────────────────────────────────────────────────────
1107// Tests
1108// ──────────────────────────────────────────────────────────────────────────────
1109
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a deterministic balanced panel: 10 entities × 5 periods, one regressor.
    ///
    /// The response is the conditional mean `(1 + 0.5 x) · effect_e` rounded to the
    /// nearest integer. This is a deterministic stand-in for a Poisson draw, not a
    /// random sample. With these effects and regressor values, every count is at
    /// least 1, so the panel contains no zeros.
    fn make_count_panel() -> (Array2<f64>, Array1<f64>, Vec<usize>) {
        let n_ent = 10;
        let t_per = 5;
        let n = n_ent * t_per;
        let entity: Vec<usize> = (0..n_ent)
            .flat_map(|e| std::iter::repeat(e).take(t_per))
            .collect();
        let eff = [0.5, 0.8, 1.0, 1.2, 1.5, 0.6, 0.9, 1.1, 1.3, 1.6_f64];
        let mut x_vals = Vec::with_capacity(n);
        let mut y_vals = Vec::with_capacity(n);
        for (i, &eid) in entity.iter().enumerate() {
            let x_v = (i % t_per) as f64 * 0.5 + 0.5;
            x_vals.push(x_v);
            let lambda = (1.0 + 0.5 * x_v) * eff[eid];
            // Deterministic "sample": round the conditional mean.
            y_vals.push(lambda.round());
        }
        let x = Array2::from_shape_vec((n, 1), x_vals).unwrap();
        let y = Array1::from(y_vals);
        (x, y, entity)
    }

    #[test]
    fn test_poisson_fe_fit() {
        let (x, y, entity) = make_count_panel();
        let result =
            PoissonFE::fit(&x.view(), &y.view(), &entity, 100, 1e-8).expect("PoissonFE fit failed");
        assert!(result.log_likelihood.is_finite());
        assert_eq!(result.irr.len(), 1);
        assert!(result.irr[0] > 0.0, "IRR should be positive");
    }

    #[test]
    fn test_negbinom_fe_fit() {
        let (x, y, entity) = make_count_panel();
        let result = NegBinomFE::fit(&x.view(), &y.view(), &entity, 50, 1e-6)
            .expect("NegBinomFE fit failed");
        assert!(result.log_likelihood.is_finite());
        let alpha = result
            .alpha
            .expect("NegBinomFE should report an over-dispersion parameter");
        assert!(alpha >= 0.0, "alpha should be non-negative");
    }

    #[test]
    fn test_zero_inflated_poisson() {
        // The zero-inflated model ignores the entity index.
        let (x_count, y, _entity) = make_count_panel();
        // Inflate: intercept only
        let z = Array2::<f64>::ones((y.len(), 1));
        let result = ZeroInflated::fit(
            &x_count.view(),
            &z.view(),
            &y.view(),
            CountDistribution::Poisson,
            50,
            1e-6,
        )
        .expect("ZIP fit failed");
        // The panel has no zero counts, so this only checks that the fit runs and
        // returns well-formed output.
        assert!(result.log_likelihood.is_finite());
        assert_eq!(result.irr.len(), 1);
        assert_eq!(result.inflate_coefficients.len(), 1);
        assert!(result.alpha.is_none(), "Poisson ZI model has no alpha");
    }

    #[test]
    fn test_irr_positive() {
        let (x, y, entity) = make_count_panel();
        let result =
            PoissonFE::fit(&x.view(), &y.view(), &entity, 100, 1e-8).expect("PoissonFE fit failed");
        for &irr in result.irr.iter() {
            assert!(irr > 0.0, "All IRRs must be positive");
        }
    }
}