Skip to main content

scirs2_stats/panel/
random_effects.rs

1//! Random Effects and Linear Mixed Models
2//!
3//! Implements:
4//! - `RandomEffectsModel`: GLS random effects with Swamy-Arora variance components
5//! - `HausmanTest`: fixed vs. random effects specification test
6//! - `LinearMixedModel`: LMM with random intercepts and slopes
7//! - `REML`: restricted maximum likelihood for variance components
8//! - `REResult`: fixed effects coefficients, BLUPs, variance components
9
10use crate::error::{StatsError, StatsResult};
11use crate::panel::fixed_effects::{FEResult, FixedEffectsModel};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
13use scirs2_core::numeric::{Float, FromPrimitive};
14use scirs2_linalg::{lstsq, solve};
15
16// ──────────────────────────────────────────────────────────────────────────────
17// Helpers
18// ──────────────────────────────────────────────────────────────────────────────
19
20/// Simple mat-mul A(m×k) × B(k×n) → (m×n).
21fn matmul<F: Float + std::iter::Sum>(a: &Array2<F>, b: &Array2<F>) -> StatsResult<Array2<F>> {
22    let (m, k) = a.dim();
23    let (kb, n) = b.dim();
24    if k != kb {
25        return Err(StatsError::DimensionMismatch(format!(
26            "matmul dim: {} vs {}",
27            k, kb
28        )));
29    }
30    let mut c = Array2::zeros((m, n));
31    for i in 0..m {
32        for j in 0..n {
33            let mut s = F::zero();
34            for l in 0..k {
35                s = s + a[[i, l]] * b[[l, j]];
36            }
37            c[[i, j]] = s;
38        }
39    }
40    Ok(c)
41}
42
43/// OLS helper: returns (coeff, resid).
44fn ols<F>(x: &Array2<F>, y: &Array1<F>) -> StatsResult<(Array1<F>, Array1<F>)>
45where
46    F: Float
47        + std::iter::Sum
48        + std::fmt::Debug
49        + std::fmt::Display
50        + scirs2_core::numeric::NumAssign
51        + scirs2_core::numeric::One
52        + scirs2_core::ndarray::ScalarOperand
53        + FromPrimitive
54        + Send
55        + Sync
56        + 'static,
57{
58    let n = y.len();
59    let result = lstsq(&x.view(), &y.view(), None)
60        .map_err(|e| StatsError::ComputationError(format!("lstsq: {e}")))?;
61    let c = result.x;
62    let mut fitted = Array1::zeros(n);
63    for i in 0..n {
64        for j in 0..c.len() {
65            fitted[i] = fitted[i] + x[[i, j]] * c[j];
66        }
67    }
68    let resid: Array1<F> = y.iter().zip(fitted.iter()).map(|(&y, &f)| y - f).collect();
69    Ok((c, resid))
70}
71
72// ──────────────────────────────────────────────────────────────────────────────
73// REResult
74// ──────────────────────────────────────────────────────────────────────────────
75
/// Results from a random-effects estimation (see `RandomEffectsModel::fit`).
#[derive(Debug, Clone)]
pub struct REResult<F> {
    /// GLS-estimated slope coefficients (K). An intercept is estimated
    /// internally but not returned here.
    pub coefficients: Array1<F>,
    /// Standard errors of the slope coefficients (K), from the OLS fit on the
    /// quasi-demeaned data
    pub std_errors: Array1<F>,
    /// t-statistics: coefficient / standard error (zero where the SE is zero)
    pub t_stats: Array1<F>,
    /// Within-entity variance σ²_ε, estimated from fixed-effects residuals
    pub sigma2_epsilon: F,
    /// Between-entity variance (random effect) σ²_u, truncated at zero
    pub sigma2_u: F,
    /// Quasi-demeaning factor θ = 1 − σ_ε / sqrt(T̄·σ²_u + σ²_ε), evaluated at
    /// the average group size T̄. Per-entity θ_i values are used during
    /// estimation but are not stored.
    pub theta: F,
    /// Overall R² on the original (not quasi-demeaned) scale; zero when the
    /// total sum of squares is zero
    pub r2_overall: F,
    /// Fitted values: intercept + x·β (random effects not included)
    pub fitted: Array1<F>,
    /// Residuals y − fitted
    pub residuals: Array1<F>,
    /// BLUPs of random intercepts, indexed by entity ID
    pub blups: Array1<F>,
    /// Number of observations
    pub n_obs: usize,
    /// Number of entities (max entity ID + 1)
    pub n_entities: usize,
}
104
105// ──────────────────────────────────────────────────────────────────────────────
106// RandomEffectsModel  (Swamy-Arora variance components)
107// ──────────────────────────────────────────────────────────────────────────────
108
109/// GLS random-effects estimator (Swamy-Arora variance components).
110///
111/// Assumes:  y_{it} = x_{it}' β + u_i + ε_{it}
112/// where u_i ~ N(0, σ²_u)  and  ε_{it} ~ N(0, σ²_ε).
113///
114/// Steps:
115/// 1. Estimate σ²_ε from within (FE) residuals.
116/// 2. Estimate σ²_u from between residuals.
117/// 3. Compute θ = 1 - σ_ε / sqrt(T σ²_u + σ²_ε).
118/// 4. Quasi-demean all variables.
119/// 5. OLS on quasi-demeaned data.
120pub struct RandomEffectsModel;
121
122impl RandomEffectsModel {
123    /// Fit random effects model.
124    ///
125    /// # Arguments
126    /// * `x`       – (N × K) design matrix **without intercept**
127    /// * `y`       – response (N)
128    /// * `entity`  – entity IDs (0-indexed, length N)
129    /// * `time`    – time IDs (0-indexed, length N)
130    pub fn fit<F>(
131        x: &ArrayView2<F>,
132        y: &ArrayView1<F>,
133        entity: &[usize],
134        time: &[usize],
135    ) -> StatsResult<REResult<F>>
136    where
137        F: Float
138            + std::iter::Sum
139            + std::fmt::Debug
140            + std::fmt::Display
141            + scirs2_core::numeric::NumAssign
142            + scirs2_core::numeric::One
143            + scirs2_core::ndarray::ScalarOperand
144            + FromPrimitive
145            + Send
146            + Sync
147            + 'static,
148    {
149        let n = y.len();
150        let (nx, k) = x.dim();
151        if nx != n || entity.len() != n || time.len() != n {
152            return Err(StatsError::DimensionMismatch(
153                "x, y, entity, time lengths must match".to_string(),
154            ));
155        }
156        let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
157        if n_entities < 2 {
158            return Err(StatsError::InsufficientData(
159                "Need at least 2 entities for RE estimation".to_string(),
160            ));
161        }
162
163        // ── Step 1: within variance σ²_ε from FE residuals ───────────────────
164        let fe = FixedEffectsModel::fit(x, y, entity, time, false)?;
165        let resid_within = &fe.residuals;
166        let df_within = if n > n_entities + k {
167            n - n_entities - k
168        } else {
169            1
170        };
171        let ss_within: F = resid_within.iter().map(|&r| r * r).sum();
172        let sigma2_eps = ss_within
173            / F::from_usize(df_within)
174                .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
175
176        // ── Step 2: entity counts ─────────────────────────────────────────────
177        let mut e_counts = vec![0usize; n_entities];
178        for &eid in entity.iter() {
179            e_counts[eid] += 1;
180        }
181        // average T per entity
182        let t_bar =
183            F::from_usize(n).unwrap_or(F::one()) / F::from_usize(n_entities).unwrap_or(F::one());
184
185        // ── Step 3: between variance σ²_u ────────────────────────────────────
186        // Between estimator: OLS on entity means
187        // ȳ_i = X̄_i β + u_i + ε̄_i
188        let mut y_mean_e = vec![F::zero(); n_entities];
189        let mut x_mean_e = vec![vec![F::zero(); k]; n_entities];
190        for (i, &eid) in entity.iter().enumerate() {
191            y_mean_e[eid] = y_mean_e[eid] + y[i];
192            for j in 0..k {
193                x_mean_e[eid][j] = x_mean_e[eid][j] + x[[i, j]];
194            }
195        }
196        for eid in 0..n_entities {
197            let cnt = F::from_usize(e_counts[eid]).unwrap_or(F::one());
198            y_mean_e[eid] = y_mean_e[eid] / cnt;
199            for j in 0..k {
200                x_mean_e[eid][j] = x_mean_e[eid][j] / cnt;
201            }
202        }
203        // Build between design matrix (with intercept)
204        let xb_flat: Vec<F> = x_mean_e
205            .iter()
206            .flat_map(|row| std::iter::once(F::one()).chain(row.iter().copied()))
207            .collect();
208        let xb = Array2::from_shape_vec((n_entities, k + 1), xb_flat)
209            .map_err(|e| StatsError::ComputationError(format!("reshape: {e}")))?;
210        let yb = Array1::from(y_mean_e.clone());
211        let (_coeffs_b, resid_b) = ols(&xb, &yb)?;
212        let ss_between: F = resid_b.iter().map(|&r| r * r).sum();
213        let df_between = if n_entities > k + 1 {
214            n_entities - k - 1
215        } else {
216            1
217        };
218        let sigma2_b = ss_between
219            / F::from_usize(df_between)
220                .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
221        let sigma2_u_raw = sigma2_b - sigma2_eps / t_bar;
222        let sigma2_u = if sigma2_u_raw > F::zero() {
223            sigma2_u_raw
224        } else {
225            F::zero()
226        };
227
228        // ── Step 4: quasi-demeaning factor θ ─────────────────────────────────
229        // θ_i = 1 - σ_ε / sqrt( T_i * σ²_u + σ²_ε )
230        // We compute per-entity theta but store the overall value.
231        let theta_vec: Vec<F> = e_counts
232            .iter()
233            .map(|&ti| {
234                let ti_f = F::from_usize(ti).unwrap_or(F::one());
235                let denom_sq = ti_f * sigma2_u + sigma2_eps;
236                if denom_sq > F::zero() {
237                    F::one() - sigma2_eps.sqrt() / denom_sq.sqrt()
238                } else {
239                    F::zero()
240                }
241            })
242            .collect();
243        // Global theta (use balanced approximation)
244        let theta_glob = {
245            let denom_sq = t_bar * sigma2_u + sigma2_eps;
246            if denom_sq > F::zero() {
247                F::one() - sigma2_eps.sqrt() / denom_sq.sqrt()
248            } else {
249                F::zero()
250            }
251        };
252
253        // ── Step 5: quasi-demean ──────────────────────────────────────────────
254        // ỹ_it = y_it - θ_i * ȳ_i
255        let mut yq: Vec<F> = Vec::with_capacity(n);
256        let mut xq_rows: Vec<Vec<F>> = Vec::with_capacity(n);
257        for (i, &eid) in entity.iter().enumerate() {
258            let th = theta_vec[eid];
259            yq.push(y[i] - th * y_mean_e[eid]);
260            let mut row = vec![F::one() - th]; // quasi-demeaned intercept
261            for j in 0..k {
262                row.push(x[[i, j]] - th * x_mean_e[eid][j]);
263            }
264            xq_rows.push(row);
265        }
266        let yq_arr = Array1::from(yq);
267        let xq_flat: Vec<F> = xq_rows.iter().flat_map(|r| r.iter().copied()).collect();
268        let xq = Array2::from_shape_vec((n, k + 1), xq_flat)
269            .map_err(|e| StatsError::ComputationError(format!("reshape: {e}")))?;
270
271        let (coeffs_full, resid) = ols(&xq, &yq_arr)?;
272
273        // ── Build fitted values ───────────────────────────────────────────────
274        // coeffs_full[0] = intercept, [1..] = slope
275        let intercept = coeffs_full[0];
276        let slopes: Array1<F> = coeffs_full.slice(scirs2_core::ndarray::s![1..]).to_owned();
277
278        let mut fitted = Array1::zeros(n);
279        for i in 0..n {
280            let mut fi = intercept;
281            for j in 0..k {
282                fi = fi + x[[i, j]] * slopes[j];
283            }
284            fitted[i] = fi;
285        }
286        let orig_resid: Array1<F> = (0..n).map(|i| y[i] - fitted[i]).collect();
287
288        // ── R² overall ────────────────────────────────────────────────────────
289        let y_bar = y.iter().copied().sum::<F>() / F::from_usize(n).unwrap_or(F::one());
290        let ss_tot: F = y.iter().map(|&v| (v - y_bar) * (v - y_bar)).sum();
291        let ss_res: F = orig_resid.iter().map(|&r| r * r).sum();
292        let r2 = if ss_tot > F::zero() {
293            F::one() - ss_res / ss_tot
294        } else {
295            F::zero()
296        };
297
298        // ── Standard errors (OLS on quasi-demeaned) ───────────────────────────
299        let nf = F::from_usize(n).unwrap_or(F::one());
300        let df_res_f = F::from_usize(if n > k + 1 { n - k - 1 } else { 1 }).unwrap_or(F::one());
301        let sigma2_resid = resid.iter().map(|&r| r * r).sum::<F>() / df_res_f;
302
303        // (X'X)^{-1} σ² for SE
304        let xtx = matmul(&xq.t().to_owned(), &xq)?;
305        let std_errors = xtx_inv_diag_se(&xtx, sigma2_resid)?;
306        // drop the intercept SE row (return only slope SEs)
307        let se_slopes: Array1<F> = std_errors.slice(scirs2_core::ndarray::s![1..]).to_owned();
308        let t_stats: Array1<F> = slopes
309            .iter()
310            .zip(se_slopes.iter())
311            .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
312            .collect();
313
314        // ── BLUPs: û_i = σ²_u / (σ²_u + σ²_ε/T_i) * ē_i ───────────────────
315        let mut blup_sum = vec![F::zero(); n_entities];
316        for (i, &eid) in entity.iter().enumerate() {
317            blup_sum[eid] = blup_sum[eid] + orig_resid[i];
318        }
319        let blups: Array1<F> = (0..n_entities)
320            .map(|eid| {
321                if e_counts[eid] == 0 {
322                    return F::zero();
323                }
324                let ti = F::from_usize(e_counts[eid]).unwrap_or(F::one());
325                let e_mean = blup_sum[eid] / ti;
326                let denom = sigma2_u + sigma2_eps / ti;
327                if denom > F::zero() {
328                    sigma2_u / denom * e_mean
329                } else {
330                    F::zero()
331                }
332            })
333            .collect();
334
335        // Only return slope coefficients (not intercept) to match fixed-effects API
336        Ok(REResult {
337            coefficients: slopes,
338            std_errors: se_slopes,
339            t_stats,
340            sigma2_epsilon: sigma2_eps,
341            sigma2_u,
342            theta: theta_glob,
343            r2_overall: r2,
344            fitted,
345            residuals: orig_resid,
346            blups,
347            n_obs: n,
348            n_entities,
349        })
350    }
351}
352
353// ──────────────────────────────────────────────────────────────────────────────
354// HausmanTest
355// ──────────────────────────────────────────────────────────────────────────────
356
/// Result of the Hausman specification test (see `HausmanTest::test`).
#[derive(Debug, Clone)]
pub struct HausmanTestResult<F> {
    /// Hausman H statistic (χ²(df) distributed under H₀), computed from a
    /// diagonal approximation of Var(β̂_FE) − Var(β̂_RE)
    pub h_stat: F,
    /// Degrees of freedom (= number of slope coefficients K)
    pub df: usize,
    /// Approximate upper-tail p-value of `h_stat` under χ²(df)
    /// (Wilson-Hilferty approximation)
    pub p_value: F,
    /// Difference in coefficient vectors (β̂_FE − β̂_RE), length K
    pub coeff_diff: Array1<F>,
}
369
370/// Hausman (1978) test: H₀: RE is consistent (no correlation between u_i and x_{it}).
371///
372/// H = (β̂_FE - β̂_RE)' [Var(β̂_FE) - Var(β̂_RE)]⁻¹ (β̂_FE - β̂_RE)  ~ χ²(K)
373pub struct HausmanTest;
374
375impl HausmanTest {
376    /// Compute the Hausman test statistic given FE and RE results.
377    ///
378    /// Both results must have the same number of slope coefficients K.
379    pub fn test<F>(fe: &FEResult<F>, re: &REResult<F>) -> StatsResult<HausmanTestResult<F>>
380    where
381        F: Float
382            + std::iter::Sum
383            + std::fmt::Debug
384            + std::fmt::Display
385            + scirs2_core::numeric::NumAssign
386            + scirs2_core::numeric::One
387            + scirs2_core::ndarray::ScalarOperand
388            + FromPrimitive
389            + Send
390            + Sync
391            + 'static,
392    {
393        let kfe = fe.coefficients.len();
394        let kre = re.coefficients.len();
395        if kfe != kre {
396            return Err(StatsError::DimensionMismatch(format!(
397                "FE has {} coefficients but RE has {}",
398                kfe, kre
399            )));
400        }
401        let k = kfe;
402
403        // Coefficient difference q = β_FE - β_RE
404        let q: Array1<F> = fe
405            .coefficients
406            .iter()
407            .zip(re.coefficients.iter())
408            .map(|(&bfe, &bre)| bfe - bre)
409            .collect();
410
411        // Variance of q: Var(q) = Var(β_FE) - Var(β_RE)
412        // Diagonal approximation: var_q_j = se_FE_j² - se_RE_j²
413        let mut var_q = Array2::<F>::zeros((k, k));
414        for j in 0..k {
415            let v = fe.std_errors[j] * fe.std_errors[j] - re.std_errors[j] * re.std_errors[j];
416            // Ensure positive definite by clamping to small positive value
417            var_q[[j, j]] = if v > F::zero() {
418                v
419            } else {
420                F::from_f64(1e-10).unwrap_or(F::zero())
421            };
422        }
423
424        // H = q' Var(q)^{-1} q
425        // Since var_q is diagonal, this simplifies:
426        let h_stat: F = (0..k)
427            .map(|j| {
428                let vj = var_q[[j, j]];
429                if vj > F::zero() {
430                    q[j] * q[j] / vj
431                } else {
432                    F::zero()
433                }
434            })
435            .sum();
436
437        let p_value = chi2_upper_tail_pvalue(h_stat, k);
438
439        Ok(HausmanTestResult {
440            h_stat,
441            df: k,
442            p_value,
443            coeff_diff: q,
444        })
445    }
446}
447
448// ──────────────────────────────────────────────────────────────────────────────
449// LinearMixedModel
450// ──────────────────────────────────────────────────────────────────────────────
451
/// Configuration for the Linear Mixed Model.
///
/// By default only random intercepts are fitted, with at most 200 EM
/// iterations and a convergence tolerance of `1e-8`.
#[derive(Debug, Clone)]
pub struct LmmConfig {
    /// When `true`, per-entity random slopes are estimated in addition to the
    /// random intercepts.
    pub random_slopes: bool,
    /// Upper bound on the number of EM iterations.
    pub max_iter: usize,
    /// Convergence tolerance applied to both the coefficient change and the
    /// variance-component change between iterations.
    pub tol: f64,
}

impl Default for LmmConfig {
    fn default() -> Self {
        Self {
            max_iter: 200,
            tol: 1e-8,
            random_slopes: false,
        }
    }
}
472
/// Result from a Linear Mixed Model fit (`LinearMixedModel::fit`).
#[derive(Debug, Clone)]
pub struct LmmResult<F> {
    /// Fixed-effects slope coefficients (K); the estimated intercept is not included
    pub fixed_effects: Array1<F>,
    /// Approximate standard errors for the slopes, sqrt(diag((X'X)⁻¹ σ²)) with
    /// X = [1 | x]; the random-effect correlation structure is ignored
    pub fixed_se: Array1<F>,
    /// BLUPs for random intercepts, indexed by entity ID (length max ID + 1)
    pub random_intercepts: Array1<F>,
    /// Per-entity random slope deviations (n_entities × K), from per-entity OLS
    /// on fixed-effect residuals; `Some` only when `random_slopes` is enabled and K > 0
    pub random_slopes: Option<Array2<F>>,
    /// Residual variance σ²_ε from the EM iterations
    pub sigma2_resid: F,
    /// Random-intercept variance σ²_u
    pub sigma2_u: F,
    /// Approximate log-likelihood, computed as −n/2·ln σ²_ε − n/2 (zero when σ²_ε ≤ 0)
    pub reml_loglik: F,
    /// Number of observations
    pub n_obs: usize,
    /// Number of entities (max entity ID + 1)
    pub n_entities: usize,
}
495
496/// Linear Mixed Model with random intercepts (and optionally random slopes).
497///
498/// Estimation via a two-step EM / REML procedure.
499pub struct LinearMixedModel {
500    pub config: LmmConfig,
501}
502
503impl LinearMixedModel {
504    /// Create a new LMM with default configuration.
505    pub fn new() -> Self {
506        LinearMixedModel {
507            config: LmmConfig::default(),
508        }
509    }
510
511    /// Create an LMM with custom configuration.
512    pub fn with_config(config: LmmConfig) -> Self {
513        LinearMixedModel { config }
514    }
515
516    /// Fit an LMM via iterated GLS (equivalent to REML EM for balanced data).
517    ///
518    /// # Arguments
519    /// * `x`      – (N × K) design matrix (fixed effects, **without** intercept)
520    /// * `y`      – response (N)
521    /// * `entity` – entity IDs (0-indexed, length N)
522    pub fn fit<F>(
523        &self,
524        x: &ArrayView2<F>,
525        y: &ArrayView1<F>,
526        entity: &[usize],
527    ) -> StatsResult<LmmResult<F>>
528    where
529        F: Float
530            + std::iter::Sum
531            + std::fmt::Debug
532            + std::fmt::Display
533            + scirs2_core::numeric::NumAssign
534            + scirs2_core::numeric::One
535            + scirs2_core::ndarray::ScalarOperand
536            + FromPrimitive
537            + Send
538            + Sync
539            + 'static,
540    {
541        let n = y.len();
542        let (nx, k) = x.dim();
543        if nx != n || entity.len() != n {
544            return Err(StatsError::DimensionMismatch(
545                "x, y, entity lengths must match".to_string(),
546            ));
547        }
548        let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
549        // entity counts
550        let mut e_counts = vec![0usize; n_entities];
551        for &eid in entity.iter() {
552            e_counts[eid] += 1;
553        }
554
555        // ── EM iterations ──────────────────────────────────────────────────────
556        // Initialize variance components
557        let mut sigma2_eps = F::one();
558        let mut sigma2_u = F::from_f64(0.5).unwrap_or(F::one());
559        let tol = F::from_f64(self.config.tol).unwrap_or(F::from_f64(1e-8).unwrap_or(F::zero()));
560
561        let mut coeffs = Array1::zeros(k + 1); // [intercept, slopes]
562        let mut blups = Array1::zeros(n_entities);
563
564        for _iter in 0..self.config.max_iter {
565            // ── E-step: compute BLUPs ──────────────────────────────────────────
566            // û_i = σ²_u / (σ²_u + σ²_ε / T_i) * ē_i
567            // first compute residuals with current β
568            let mut resid_cur = y.to_owned();
569            for i in 0..n {
570                let mut fi = coeffs[0]; // intercept
571                for j in 0..k {
572                    fi = fi + x[[i, j]] * coeffs[j + 1];
573                }
574                resid_cur[i] = resid_cur[i] - fi;
575            }
576            // entity mean residuals
577            let mut e_res_sum = vec![F::zero(); n_entities];
578            for (i, &eid) in entity.iter().enumerate() {
579                e_res_sum[eid] = e_res_sum[eid] + resid_cur[i];
580            }
581            let mut new_blups = Array1::zeros(n_entities);
582            for eid in 0..n_entities {
583                if e_counts[eid] == 0 {
584                    continue;
585                }
586                let ti = F::from_usize(e_counts[eid]).unwrap_or(F::one());
587                let e_mean = e_res_sum[eid] / ti;
588                let denom = sigma2_u + sigma2_eps / ti;
589                new_blups[eid] = if denom > F::zero() {
590                    sigma2_u / denom * e_mean
591                } else {
592                    F::zero()
593                };
594            }
595
596            // ── M-step: update β via WLS (GLS with known Σ) ────────────────────
597            // Quasi-demean by θ_i * û_i contribution (empirical Bayes shrinkage)
598            // Effective model: ỹ = Xβ + ε̃
599            let mut yq_vec: Vec<F> = Vec::with_capacity(n);
600            let mut xq_rows: Vec<Vec<F>> = Vec::with_capacity(n);
601            for (i, &eid) in entity.iter().enumerate() {
602                let ti = F::from_usize(e_counts[eid]).unwrap_or(F::one());
603                let denom = sigma2_u + sigma2_eps / ti;
604                let theta_i = if denom > F::zero() {
605                    sigma2_u / denom
606                } else {
607                    F::zero()
608                };
609                // subtract BLUP contribution
610                yq_vec.push(y[i] - new_blups[eid]);
611                let mut row = vec![F::one()]; // intercept
612                for j in 0..k {
613                    row.push(x[[i, j]]);
614                }
615                xq_rows.push(row);
616            }
617            let yq = Array1::from(yq_vec);
618            let xq_flat: Vec<F> = xq_rows.iter().flat_map(|r| r.iter().copied()).collect();
619            let xq = Array2::from_shape_vec((n, k + 1), xq_flat)
620                .map_err(|e| StatsError::ComputationError(format!("reshape: {e}")))?;
621            let (new_coeffs, resid_m) = ols(&xq, &yq)?;
622
623            // ── Update variance components ─────────────────────────────────────
624            let ss_eps: F = resid_m.iter().map(|&r| r * r).sum();
625            let df_eps = if n > k + 1 { n - k - 1 } else { 1 };
626            let new_sigma2_eps = ss_eps / F::from_usize(df_eps).unwrap_or(F::one());
627
628            let ss_u: F = new_blups.iter().map(|&u| u * u).sum();
629            let df_u = if n_entities > 0 { n_entities } else { 1 };
630            let new_sigma2_u = ss_u / F::from_usize(df_u).unwrap_or(F::one());
631
632            // ── Convergence check ──────────────────────────────────────────────
633            let delta_coeffs: F = new_coeffs
634                .iter()
635                .zip(coeffs.iter())
636                .map(|(&a, &b)| (a - b) * (a - b))
637                .sum::<F>()
638                .sqrt();
639            let delta_sig = (new_sigma2_eps - sigma2_eps).abs() + (new_sigma2_u - sigma2_u).abs();
640
641            coeffs = new_coeffs;
642            blups = new_blups;
643            sigma2_eps = if new_sigma2_eps > F::zero() {
644                new_sigma2_eps
645            } else {
646                F::zero()
647            };
648            sigma2_u = if new_sigma2_u > F::zero() {
649                new_sigma2_u
650            } else {
651                F::zero()
652            };
653
654            if delta_coeffs < tol && delta_sig < tol {
655                break;
656            }
657        }
658
659        // ── Final residuals / standard errors ─────────────────────────────────
660        let mut fitted = Array1::zeros(n);
661        for i in 0..n {
662            let mut fi = coeffs[0];
663            for j in 0..k {
664                fi = fi + x[[i, j]] * coeffs[j + 1];
665            }
666            fi = fi + blups[entity[i]];
667            fitted[i] = fi;
668        }
669        let residuals: Array1<F> = (0..n).map(|i| y[i] - fitted[i]).collect();
670        let nf = F::from_usize(n).unwrap_or(F::one());
671        let df_f = F::from_usize(if n > k + 1 { n - k - 1 } else { 1 }).unwrap_or(F::one());
672        let sigma2_final = residuals.iter().map(|&r| r * r).sum::<F>() / df_f;
673
674        // Approximate SE: sqrt(diag((X'X)^{-1} σ²))
675        // Build xq with intercept for SE computation
676        let xq_for_se_flat: Vec<F> = (0..n)
677            .flat_map(|i| std::iter::once(F::one()).chain((0..k).map(move |j| x[[i, j]])))
678            .collect();
679        let xq_for_se = Array2::from_shape_vec((n, k + 1), xq_for_se_flat)
680            .map_err(|e| StatsError::ComputationError(format!("reshape: {e}")))?;
681        let xtx = matmul(&xq_for_se.t().to_owned(), &xq_for_se)?;
682        let se_full = xtx_inv_diag_se(&xtx, sigma2_final)?;
683        let fixed_se: Array1<F> = se_full.slice(scirs2_core::ndarray::s![1..]).to_owned();
684        let fixed_coef: Array1<F> = coeffs.slice(scirs2_core::ndarray::s![1..]).to_owned();
685
686        // ── REML log-likelihood (approximate) ──────────────────────────────────
687        // log L_REML ≈ -n/2 log(σ²_ε) - n/2
688        let reml_loglik = if sigma2_eps > F::zero() {
689            let two = F::from_f64(2.0).unwrap_or(F::one());
690            -nf / two * sigma2_eps.ln() - nf / two
691        } else {
692            F::zero()
693        };
694
695        // ── Random slopes (optional) ────────────────────────────────────────────
696        // Compute per-entity random slopes when config.random_slopes is true.
697        //
698        // After convergence of the EM loop we have fixed-effects β = `coeffs`.
699        // For each entity i, within-entity residuals are:
700        //   r_i = y_i - X_i · β_fixed
701        // Build Z_i = [1 | X_i] (T_i × (k+1)) and solve OLS(Z_i, r_i).
702        // b_i[0] is a per-entity intercept correction (already captured in `blups`),
703        // b_i[1..] are the per-entity slope deviations from the fixed slopes.
704        let random_slopes_opt: Option<Array2<F>> = if self.config.random_slopes && k > 0 {
705            // Group row indices by entity
706            let mut entity_rows: Vec<Vec<usize>> = vec![Vec::new(); n_entities];
707            for (row_idx, &eid) in entity.iter().enumerate() {
708                entity_rows[eid].push(row_idx);
709            }
710
711            let mut slope_matrix = Array2::zeros((n_entities, k));
712
713            for eid in 0..n_entities {
714                let rows = &entity_rows[eid];
715                let ti = rows.len();
716                if ti == 0 {
717                    continue;
718                }
719
720                // Build per-entity residuals: r_i = y_i - X_i · β_fixed
721                // coeffs[0] = intercept, coeffs[1..] = slopes
722                let r_i: Array1<F> = rows
723                    .iter()
724                    .map(|&row_idx| {
725                        let mut fi = coeffs[0];
726                        for j in 0..k {
727                            fi = fi + x[[row_idx, j]] * coeffs[j + 1];
728                        }
729                        y[row_idx] - fi
730                    })
731                    .collect();
732
733                // Build Z_i = [1 | X_rows_for_entity_i], shape (T_i × (k+1))
734                let zi_flat: Vec<F> = rows
735                    .iter()
736                    .flat_map(|&row_idx| {
737                        std::iter::once(F::one()).chain((0..k).map(move |j| x[[row_idx, j]]))
738                    })
739                    .collect();
740                let zi = Array2::from_shape_vec((ti, k + 1), zi_flat).map_err(|e| {
741                    StatsError::ComputationError(format!("random slopes reshape: {e}"))
742                })?;
743
744                // OLS on within-entity system; b_i[0] = intercept correction (discarded),
745                // b_i[1..] = per-entity slope deviations.
746                // If ti < k+1 the system is under-determined; lstsq handles it gracefully
747                // via the minimum-norm solution.
748                let (b_i, _) = ols(&zi, &r_i)?;
749
750                // Store slope deviations (skip b_i[0] which is the intercept correction)
751                for j in 0..k {
752                    slope_matrix[[eid, j]] = b_i[j + 1];
753                }
754            }
755
756            Some(slope_matrix)
757        } else {
758            None
759        };
760
761        Ok(LmmResult {
762            fixed_effects: fixed_coef,
763            fixed_se,
764            random_intercepts: blups,
765            random_slopes: random_slopes_opt,
766            sigma2_resid: sigma2_eps,
767            sigma2_u,
768            reml_loglik,
769            n_obs: n,
770            n_entities,
771        })
772    }
773}
774
775impl Default for LinearMixedModel {
776    fn default() -> Self {
777        Self::new()
778    }
779}
780
781// ──────────────────────────────────────────────────────────────────────────────
782// REML
783// ──────────────────────────────────────────────────────────────────────────────
784
785/// Restricted Maximum Likelihood (REML) estimator for variance components.
786///
787/// Provides a thin wrapper that calls `LinearMixedModel::fit` and exposes
788/// the REML-specific interface.
789pub struct REML;
790
791impl REML {
792    /// Estimate variance components via REML.
793    ///
794    /// Returns (σ²_u, σ²_ε, REML log-likelihood).
795    pub fn estimate<F>(
796        x: &ArrayView2<F>,
797        y: &ArrayView1<F>,
798        entity: &[usize],
799    ) -> StatsResult<(F, F, F)>
800    where
801        F: Float
802            + std::iter::Sum
803            + std::fmt::Debug
804            + std::fmt::Display
805            + scirs2_core::numeric::NumAssign
806            + scirs2_core::numeric::One
807            + scirs2_core::ndarray::ScalarOperand
808            + FromPrimitive
809            + Send
810            + Sync
811            + 'static,
812    {
813        let lmm = LinearMixedModel::new();
814        let result = lmm.fit(x, y, entity)?;
815        Ok((result.sigma2_u, result.sigma2_resid, result.reml_loglik))
816    }
817}

// ──────────────────────────────────────────────────────────────────────────────
// Helpers
// ──────────────────────────────────────────────────────────────────────────────

823/// Compute sqrt(diag((X'X)^{-1} σ²)) for standard errors.
824fn xtx_inv_diag_se<F>(xtx: &Array2<F>, sigma2: F) -> StatsResult<Array1<F>>
825where
826    F: Float
827        + std::iter::Sum
828        + std::fmt::Debug
829        + std::fmt::Display
830        + scirs2_core::numeric::NumAssign
831        + scirs2_core::numeric::One
832        + scirs2_core::ndarray::ScalarOperand
833        + FromPrimitive
834        + Send
835        + Sync
836        + 'static,
837{
838    let k = xtx.nrows();
839    // Solve (X'X) v_j = e_j for each basis vector
840    let mut se = Array1::zeros(k);
841    for j in 0..k {
842        let mut ej = Array1::zeros(k);
843        ej[j] = F::one();
844        let vj = solve(&xtx.view(), &ej.view(), None)
845            .map_err(|e| StatsError::ComputationError(format!("solve: {e}")))?;
846        let var_j = vj[j] * sigma2;
847        se[j] = if var_j >= F::zero() {
848            var_j.sqrt()
849        } else {
850            F::zero()
851        };
852    }
853    Ok(se)
854}
855
856/// Upper-tail chi² p-value using Wilson-Hilferty approximation.
857fn chi2_upper_tail_pvalue<F: Float + FromPrimitive>(chi2: F, df: usize) -> F {
858    if chi2 <= F::zero() {
859        return F::one();
860    }
861    let k = F::from_usize(df).unwrap_or(F::one());
862    let two = F::from_f64(2.0).unwrap_or(F::one());
863    let nine = F::from_f64(9.0).unwrap_or(F::one());
864    // Wilson-Hilferty: z ≈ (χ²/k)^{1/3} - (1 - 2/(9k)) / sqrt(2/(9k))
865    let factor = two / (nine * k);
866    let x = (chi2 / k).cbrt();
867    let mu = F::one() - factor;
868    let sigma = factor.sqrt();
869    let z = (x - mu) / sigma;
870    // P(Z > z)
871    p_value_normal_upper(z)
872}
873
874/// Upper-tail N(0,1) probability.
875fn p_value_normal_upper<F: Float + FromPrimitive>(z: F) -> F {
876    let p1 = F::from_f64(0.2316419).unwrap_or(F::zero());
877    let b1 = F::from_f64(0.319381530).unwrap_or(F::zero());
878    let b2 = F::from_f64(-0.356563782).unwrap_or(F::zero());
879    let b3 = F::from_f64(1.781477937).unwrap_or(F::zero());
880    let b4 = F::from_f64(-1.821255978).unwrap_or(F::zero());
881    let b5 = F::from_f64(1.330274429).unwrap_or(F::zero());
882    let sqrt2pi_inv = F::from_f64(0.39894228).unwrap_or(F::zero());
883    let two = F::from_f64(2.0).unwrap_or(F::one());
884
885    let abs_z = if z < F::zero() { -z } else { z };
886    let t = F::one() / (F::one() + p1 * abs_z);
887    let poly = t * (b1 + t * (b2 + t * (b3 + t * (b4 + t * b5))));
888    let phi = sqrt2pi_inv * (-(abs_z * abs_z) / two).exp();
889    let p_upper = (phi * poly).max(F::zero()).min(F::one());
890    if z >= F::zero() {
891        p_upper
892    } else {
893        F::one() - p_upper
894    }
895}

// ──────────────────────────────────────────────────────────────────────────────
// Tests
// ──────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a small deterministic balanced panel of 10 entities × 5 periods.
    ///
    /// Returns `(x, y, entity, time)`, where `x` is a single-column regressor.
    fn make_re_panel() -> (Array2<f64>, Array1<f64>, Vec<usize>, Vec<usize>) {
        // y_it = 2.0 * x_it + u_i + 0.01 * i, where i is the global row index.
        // The entity effects u_i are fixed, evenly spaced values (not random
        // draws), and the 0.01 * i trend stands in for a small disturbance.
        // x = 0.3 * i + 1.0 also grows with i, so the effects are correlated
        // with x. This biases the RE slope (see test_re_model_slope).
        let n_ent = 10;
        let t_per = 5;
        let n = n_ent * t_per;
        let mut x_vals = Vec::with_capacity(n);
        let mut y_vals = Vec::with_capacity(n);
        let entity: Vec<usize> = (0..n_ent)
            .flat_map(|e| std::iter::repeat(e).take(t_per))
            .collect();
        let time: Vec<usize> = (0..t_per).cycle().take(n).collect();
        // Entity effects
        let effects = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5_f64];
        for (i, &eid) in entity.iter().enumerate() {
            let x_v = (i as f64) * 0.3 + 1.0;
            let y_v = 2.0 * x_v + effects[eid] + (i as f64) * 0.01;
            x_vals.push(x_v);
            y_vals.push(y_v);
        }
        let x = Array2::from_shape_vec((n, 1), x_vals).unwrap();
        let y = Array1::from(y_vals);
        (x, y, entity, time)
    }

    #[test]
    fn test_re_model_slope() {
        let (x, y, entity, time) = make_re_panel();
        let result =
            RandomEffectsModel::fit(&x.view(), &y.view(), &entity, &time).expect("RE fit failed");
        let slope = result.coefficients[0];
        // RE slope may be biased when entity effects correlate with x; allow wider tolerance
        assert!(
            (slope - 2.0).abs() < 0.5,
            "RE slope: expected ~2.0 (within 0.5), got {}",
            slope
        );
        assert!(result.sigma2_u >= 0.0, "sigma2_u should be non-negative");
        assert!(
            result.sigma2_epsilon >= 0.0,
            "sigma2_eps should be non-negative"
        );
    }

    #[test]
    fn test_hausman_test() {
        let (x, y, entity, time) = make_re_panel();
        let fe =
            FixedEffectsModel::fit(&x.view(), &y.view(), &entity, &time, false).expect("FE fit");
        let re = RandomEffectsModel::fit(&x.view(), &y.view(), &entity, &time).expect("RE fit");
        let ht = HausmanTest::test(&fe, &re).expect("Hausman test failed");
        assert!(ht.h_stat >= 0.0, "H-stat should be non-negative");
        assert!(ht.p_value >= 0.0 && ht.p_value <= 1.0, "p-value in [0,1]");
    }

    #[test]
    fn test_lmm_fit() {
        let (x, y, entity, time) = make_re_panel();
        let lmm = LinearMixedModel::new();
        let result = lmm
            .fit(&x.view(), &y.view(), &entity)
            .expect("LMM fit failed");
        let slope = result.fixed_effects[0];
        assert!(
            (slope - 2.0).abs() < 0.3,
            "LMM slope: expected ~2.0, got {}",
            slope
        );
        assert_eq!(result.random_intercepts.len(), 10);
        // Bookkeeping fields must reflect the 10 × 5 panel.
        assert_eq!(result.n_obs, 50);
        assert_eq!(result.n_entities, 10);
    }

    #[test]
    fn test_reml_estimate() {
        let (x, y, entity, _time) = make_re_panel();
        let (sigma2_u, sigma2_eps, _loglik) =
            REML::estimate(&x.view(), &y.view(), &entity).expect("REML failed");
        assert!(sigma2_u >= 0.0, "REML sigma2_u must be non-negative");
        assert!(sigma2_eps >= 0.0, "REML sigma2_eps must be non-negative");
    }

    #[test]
    fn test_lmm_random_slopes_shape() {
        // Deterministic panel: y_it = (3.0 + v_i) * x_it + u_i + 0.005 * i,
        // with fixed (non-random) slope deviations v_i and intercepts u_i,
        // 5 entities and 6 time periods each.
        let n_ent = 5usize;
        let t_per = 6usize;
        let n = n_ent * t_per;

        let entity: Vec<usize> = (0..n_ent)
            .flat_map(|e| std::iter::repeat(e).take(t_per))
            .collect();

        // Entity-specific slope deviations around the common slope 3.0
        let slope_devs = [0.2_f64, -0.3, 0.1, -0.1, 0.15];
        // Entity-specific intercept effects
        let intercept_devs = [1.0_f64, -1.0, 0.5, -0.5, 0.0];

        let mut x_vals = Vec::with_capacity(n);
        let mut y_vals = Vec::with_capacity(n);
        for (i, &eid) in entity.iter().enumerate() {
            let x_v = (i as f64) * 0.25 + 0.5;
            let slope = 3.0 + slope_devs[eid];
            let y_v = slope * x_v + intercept_devs[eid] + (i as f64) * 0.005;
            x_vals.push(x_v);
            y_vals.push(y_v);
        }
        let x = Array2::from_shape_vec((n, 1), x_vals).unwrap();
        let y = Array1::from(y_vals);

        // Fit with random_slopes = true
        let config = LmmConfig {
            random_slopes: true,
            max_iter: 200,
            tol: 1e-8,
        };
        let lmm = LinearMixedModel::with_config(config);
        let result = lmm
            .fit(&x.view(), &y.view(), &entity)
            .expect("LMM fit failed");

        // random_slopes must be Some
        assert!(
            result.random_slopes.is_some(),
            "random_slopes should be Some when config.random_slopes=true"
        );
        let rs = result
            .random_slopes
            .as_ref()
            .expect("random_slopes is Some");
        // shape must be (n_entities, n_slope_params) = (5, 1)
        assert_eq!(
            rs.dim(),
            (n_ent, 1),
            "random slopes shape: expected ({n_ent}, 1), got {:?}",
            rs.dim()
        );

        // With random_slopes=false the result must still be None
        let lmm_default = LinearMixedModel::new();
        let result_no_slopes = lmm_default
            .fit(&x.view(), &y.view(), &entity)
            .expect("LMM fit failed (no slopes)");
        assert!(
            result_no_slopes.random_slopes.is_none(),
            "random_slopes should be None when config.random_slopes=false"
        );
    }

    #[test]
    fn test_normal_upper_tail_reference_values() {
        // Known standard-normal tail probabilities; the A&S 26.2.17
        // approximation is accurate to well below this tolerance.
        assert!((p_value_normal_upper(0.0_f64) - 0.5).abs() < 1e-6);
        assert!((p_value_normal_upper(1.96_f64) - 0.024_997_895).abs() < 1e-6);
        assert!((p_value_normal_upper(-1.96_f64) - 0.975_002_105).abs() < 1e-6);
    }

    #[test]
    fn test_chi2_upper_tail_pvalue() {
        // Non-positive statistics give p = 1.
        assert_eq!(chi2_upper_tail_pvalue(0.0_f64, 3), 1.0);
        assert_eq!(chi2_upper_tail_pvalue(-1.0_f64, 3), 1.0);
        // 18.307 is the 5% critical value of chi2(10).
        let p = chi2_upper_tail_pvalue(18.307_f64, 10);
        assert!((p - 0.05).abs() < 2e-3, "chi2(10) p-value at 18.307: {p}");
        // Larger statistics must give smaller p-values.
        assert!(chi2_upper_tail_pvalue(10.0_f64, 3) < chi2_upper_tail_pvalue(5.0_f64, 3));
    }
}