Skip to main content

scirs2_stats/panel/
fixed_effects.rs

//! Fixed Effects Panel Data Models
//!
//! Implements:
//! - `FixedEffectsModel`: within estimator with entity and (optionally) time fixed effects
//! - `WithinTransform`: demean-by-entity (within transformation)
//! - `TwoWayFE`: two-way fixed effects (entity + time)
//! - `FEResult`: coefficients, std errors, F-stat, R² within/between/overall
//! - `FirstDiffEstimator`: first-difference estimator for T=2 or balanced panels

10use crate::error::{StatsError, StatsResult};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
12use scirs2_core::numeric::{Float, FromPrimitive};
13use scirs2_linalg::{lstsq, solve};

// ──────────────────────────────────────────────────────────────────────────────
// Helpers: dense matrix multiply A(m×k) × B(k×n) -> (m×n) and OLS
// ──────────────────────────────────────────────────────────────────────────────

19fn matmul<F: Float + std::iter::Sum>(a: &Array2<F>, b: &Array2<F>) -> StatsResult<Array2<F>> {
20    let (m, k) = a.dim();
21    let (kb, n) = b.dim();
22    if k != kb {
23        return Err(StatsError::DimensionMismatch(format!(
24            "matmul: inner dims mismatch {} vs {}",
25            k, kb
26        )));
27    }
28    let mut c = Array2::zeros((m, n));
29    for i in 0..m {
30        for j in 0..n {
31            let mut s = F::zero();
32            for l in 0..k {
33                s = s + a[[i, l]] * b[[l, j]];
34            }
35            c[[i, j]] = s;
36        }
37    }
38    Ok(c)
39}
41/// Ordinary-least-squares via QR / normal equations using scirs2-linalg lstsq.
42/// Returns (coefficients, residuals)
43fn ols<F>(x: &Array2<F>, y: &Array1<F>) -> StatsResult<(Array1<F>, Array1<F>)>
44where
45    F: Float
46        + std::iter::Sum
47        + std::fmt::Debug
48        + std::fmt::Display
49        + scirs2_core::numeric::NumAssign
50        + scirs2_core::numeric::One
51        + scirs2_core::ndarray::ScalarOperand
52        + FromPrimitive
53        + Send
54        + Sync
55        + 'static,
56{
57    let n = y.len();
58    let (n2, _k) = x.dim();
59    if n != n2 {
60        return Err(StatsError::DimensionMismatch(format!(
61            "ols: x has {} rows, y has {} elements",
62            n2, n
63        )));
64    }
65    let result = lstsq(&x.view(), &y.view(), None)
66        .map_err(|e| StatsError::ComputationError(format!("lstsq failed: {e}")))?;
67    let coeffs = result.x;
68    // residuals = y - X β
69    let mut fitted = Array1::zeros(n);
70    for i in 0..n {
71        let mut s = F::zero();
72        for j in 0..coeffs.len() {
73            s = s + x[[i, j]] * coeffs[j];
74        }
75        fitted[i] = s;
76    }
77    let resid: Array1<F> = y
78        .iter()
79        .zip(fitted.iter())
80        .map(|(&yi, &fi)| yi - fi)
81        .collect();
82    Ok((coeffs, resid))
83}

// ──────────────────────────────────────────────────────────────────────────────
// Result type
// ──────────────────────────────────────────────────────────────────────────────

89/// Results from a fixed-effects estimation.
90#[derive(Debug, Clone)]
91pub struct FEResult<F> {
92    /// Estimated slope coefficients (excludes entity/time dummies)
93    pub coefficients: Array1<F>,
94    /// Heteroskedasticity-consistent (HC0) standard errors
95    pub std_errors: Array1<F>,
96    /// t-statistics (coeff / se)
97    pub t_stats: Array1<F>,
98    /// Overall F-statistic for joint significance
99    pub f_stat: F,
100    /// p-value for the F-test (approximated via F(k, N-n-k) distribution)
101    pub f_pvalue: F,
102    /// R² within (variation explained after demeaning)
103    pub r2_within: F,
104    /// R² between (explained variation of entity means)
105    pub r2_between: F,
106    /// R² overall
107    pub r2_overall: F,
108    /// Number of observations
109    pub n_obs: usize,
110    /// Number of entities (panels)
111    pub n_entities: usize,
112    /// Residuals (length n_obs)
113    pub residuals: Array1<F>,
114    /// Fitted values (length n_obs)
115    pub fitted: Array1<F>,
116    /// Estimated entity fixed effects (length n_entities)
117    pub entity_effects: Option<Array1<F>>,
118    /// Estimated time fixed effects (length n_periods), if two-way
119    pub time_effects: Option<Array1<F>>,
120}

// ──────────────────────────────────────────────────────────────────────────────
// WithinTransform
// ──────────────────────────────────────────────────────────────────────────────

126/// Performs the within (entity-demeaning) transformation on panel data.
127///
128/// For entity `i`, the demeaned value is `x_{it} - ȳ_i`.
129pub struct WithinTransform;
130
131impl WithinTransform {
132    /// Demean a matrix by entity means.
133    ///
134    /// # Arguments
135    /// * `data`    – shape (N, K) stacked observations (row-major: entity 0 all T periods, entity 1, …)
136    /// * `entity`  – entity index vector of length N
137    ///
138    /// Returns the demeaned matrix (same shape as `data`).
139    pub fn transform<F: Float + FromPrimitive>(
140        data: &ArrayView2<F>,
141        entity: &[usize],
142    ) -> StatsResult<Array2<F>> {
143        let (n, k) = data.dim();
144        if entity.len() != n {
145            return Err(StatsError::DimensionMismatch(format!(
146                "WithinTransform: data has {} rows but entity has {} elements",
147                n,
148                entity.len()
149            )));
150        }
151        // Find number of unique entities
152        let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
153        // Compute entity means for each column
154        let mut sums = Array2::<F>::zeros((n_entities, k));
155        let mut counts = vec![0usize; n_entities];
156        for (row, &eid) in entity.iter().enumerate() {
157            counts[eid] += 1;
158            for col in 0..k {
159                sums[[eid, col]] = sums[[eid, col]] + data[[row, col]];
160            }
161        }
162        let mut means = Array2::<F>::zeros((n_entities, k));
163        for eid in 0..n_entities {
164            let cnt = F::from_usize(counts[eid])
165                .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
166            for col in 0..k {
167                means[[eid, col]] = if cnt > F::zero() {
168                    sums[[eid, col]] / cnt
169                } else {
170                    F::zero()
171                };
172            }
173        }
174        // Subtract entity mean
175        let mut demeaned = data.to_owned();
176        for (row, &eid) in entity.iter().enumerate() {
177            for col in 0..k {
178                demeaned[[row, col]] = demeaned[[row, col]] - means[[eid, col]];
179            }
180        }
181        Ok(demeaned)
182    }
183}

// ──────────────────────────────────────────────────────────────────────────────
// FixedEffectsModel
// ──────────────────────────────────────────────────────────────────────────────

189/// Entity fixed-effects (within) estimator.
190///
191/// Removes entity-specific heterogeneity by demeaning each variable by the
192/// corresponding entity mean, then applies OLS.
193///
194/// # Example (illustrative)
195/// ```rust,no_run
196/// use scirs2_stats::panel::fixed_effects::FixedEffectsModel;
197/// use scirs2_core::ndarray::{Array1, Array2};
198///
199/// // 3 entities × 4 periods = 12 observations, 2 regressors
200/// let n = 12;
201/// let x = Array2::<f64>::ones((n, 2));
202/// let y = Array1::<f64>::ones(n);
203/// let entity: Vec<usize> = (0..3).flat_map(|e| std::iter::repeat(e).take(4)).collect();
204/// let time: Vec<usize>   = (0..4).cycle().take(n).collect();
205///
206/// let result = FixedEffectsModel::fit(&x.view(), &y.view(), &entity, &time, false)
207///     .expect("fit failed");
208/// println!("Coefficients: {:?}", result.coefficients);
209/// ```
210pub struct FixedEffectsModel;
211
212impl FixedEffectsModel {
213    /// Fit a one-way (entity) or two-way (entity + time) fixed-effects model.
214    ///
215    /// # Arguments
216    /// * `x`      – design matrix (N × K), **without** intercept or dummies
217    /// * `y`      – response vector (N)
218    /// * `entity` – entity IDs, 0-indexed, length N
219    /// * `time`   – time period IDs, 0-indexed, length N
220    /// * `two_way` – if `true`, also absorb time fixed effects
221    pub fn fit<F>(
222        x: &ArrayView2<F>,
223        y: &ArrayView1<F>,
224        entity: &[usize],
225        time: &[usize],
226        two_way: bool,
227    ) -> StatsResult<FEResult<F>>
228    where
229        F: Float
230            + std::iter::Sum
231            + std::fmt::Debug
232            + std::fmt::Display
233            + scirs2_core::numeric::NumAssign
234            + scirs2_core::numeric::One
235            + scirs2_core::ndarray::ScalarOperand
236            + FromPrimitive
237            + Send
238            + Sync
239            + 'static,
240    {
241        let n = y.len();
242        let (nx, k) = x.dim();
243        if nx != n || entity.len() != n || time.len() != n {
244            return Err(StatsError::DimensionMismatch(
245                "x, y, entity, time must all have the same length N".to_string(),
246            ));
247        }
248        if n == 0 {
249            return Err(StatsError::InsufficientData("Empty dataset".to_string()));
250        }
251
252        let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
253        let n_periods = time.iter().copied().max().map(|m| m + 1).unwrap_or(0);
254
255        // ── within-demean X and y ────────────────────────────────────────────
256        let x_owned = x.to_owned();
257        let mut xd = WithinTransform::transform(&x_owned.view(), entity)?;
258        let mut yd_vec: Vec<F> = y.iter().copied().collect();
259
260        // entity means of y
261        let mut y_sums = vec![F::zero(); n_entities];
262        let mut y_counts = vec![0usize; n_entities];
263        for (i, &eid) in entity.iter().enumerate() {
264            y_sums[eid] = y_sums[eid] + y[i];
265            y_counts[eid] += 1;
266        }
267        let y_entity_means: Vec<F> = y_sums
268            .iter()
269            .zip(y_counts.iter())
270            .map(|(&s, &c)| {
271                if c > 0 {
272                    s / F::from_usize(c).unwrap_or(F::one())
273                } else {
274                    F::zero()
275                }
276            })
277            .collect();
278        for (i, &eid) in entity.iter().enumerate() {
279            yd_vec[i] = yd_vec[i] - y_entity_means[eid];
280        }
281
282        if two_way {
283            // Additional demeaning by time period
284            // Iterative demeaning (Frisch-Waugh): demean by time given entity-demeaned
285            let mut yd2 = yd_vec.clone();
286            let mut t_sums = vec![F::zero(); n_periods];
287            let mut t_counts = vec![0usize; n_periods];
288            for (i, &tid) in time.iter().enumerate() {
289                t_sums[tid] = t_sums[tid] + yd2[i];
290                t_counts[tid] += 1;
291            }
292            let y_time_means: Vec<F> = t_sums
293                .iter()
294                .zip(t_counts.iter())
295                .map(|(&s, &c)| {
296                    if c > 0 {
297                        s / F::from_usize(c).unwrap_or(F::one())
298                    } else {
299                        F::zero()
300                    }
301                })
302                .collect();
303            for (i, &tid) in time.iter().enumerate() {
304                yd2[i] = yd2[i] - y_time_means[tid];
305            }
306            yd_vec = yd2;
307
308            // Also demean X by time
309            let xd2 = WithinTransform::transform(&xd.view(), time)?;
310            xd = xd2;
311        }
312
313        let yd = Array1::from(yd_vec);
314
315        // ── OLS on demeaned data ─────────────────────────────────────────────
316        let (coeffs, resid) = ols(&xd, &yd)?;
317
318        // ── compute fitted values (in original space) ────────────────────────
319        let mut fitted = Array1::zeros(n);
320        for i in 0..n {
321            let mut s = y_entity_means[entity[i]]; // entity FE
322            for j in 0..k {
323                s = s + x[[i, j]] * coeffs[j];
324            }
325            fitted[i] = s;
326        }
327        let orig_resid: Array1<F> = (0..n).map(|i| y[i] - fitted[i]).collect();
328
329        // ── R² within ──────────────────────────────────────────────────────────
330        let ss_res_within: F = resid.iter().map(|&r| r * r).sum();
331        let yd_mean = yd.iter().copied().sum::<F>()
332            / F::from_usize(n)
333                .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
334        let ss_tot_within: F = yd.iter().map(|&v| (v - yd_mean) * (v - yd_mean)).sum();
335        let r2_within = if ss_tot_within > F::zero() {
336            F::one() - ss_res_within / ss_tot_within
337        } else {
338            F::zero()
339        };
340
341        // ── R² between (entity means) ─────────────────────────────────────────
342        // Entity mean of y vs entity mean of ŷ
343        let mut fy_sums = vec![F::zero(); n_entities];
344        for (i, &eid) in entity.iter().enumerate() {
345            fy_sums[eid] = fy_sums[eid] + fitted[i];
346        }
347        let y_bar_bar = y.iter().copied().sum::<F>()
348            / F::from_usize(n)
349                .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
350        let mut ss_between_tot = F::zero();
351        let mut ss_between_res = F::zero();
352        for eid in 0..n_entities {
353            if y_counts[eid] == 0 {
354                continue;
355            }
356            let cnt = F::from_usize(y_counts[eid]).unwrap_or(F::one());
357            let y_em = y_entity_means[eid];
358            let f_em = fy_sums[eid] / cnt;
359            ss_between_tot = ss_between_tot + cnt * (y_em - y_bar_bar) * (y_em - y_bar_bar);
360            ss_between_res = ss_between_res + cnt * (y_em - f_em) * (y_em - f_em);
361        }
362        let r2_between = if ss_between_tot > F::zero() {
363            F::one() - ss_between_res / ss_between_tot
364        } else {
365            F::zero()
366        };
367
368        // ── R² overall ────────────────────────────────────────────────────────
369        let ss_tot: F = y
370            .iter()
371            .map(|&yi| (yi - y_bar_bar) * (yi - y_bar_bar))
372            .sum();
373        let ss_res_overall: F = orig_resid.iter().map(|&r| r * r).sum();
374        let r2_overall = if ss_tot > F::zero() {
375            F::one() - ss_res_overall / ss_tot
376        } else {
377            F::zero()
378        };
379
380        // ── HC0 standard errors ───────────────────────────────────────────────
381        // Var(β̂) ≈ (X'X)⁻¹ X'ee'X (X'X)⁻¹
382        let xtx = matmul(&xd.t().to_owned(), &xd)?;
383        // We use the sandwich estimator via forming X'diag(e²)X
384        let std_errors = hc0_se(&xd, &resid, &xtx)?;
385
386        let t_stats: Array1<F> = coeffs
387            .iter()
388            .zip(std_errors.iter())
389            .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
390            .collect();
391
392        // ── F-statistic ───────────────────────────────────────────────────────
393        // F = (R² / k) / ((1 - R²) / (N - n_ent - k))
394        let df1 = F::from_usize(k).unwrap_or(F::one());
395        let df2_int = if n > n_entities + k {
396            n - n_entities - k
397        } else {
398            1
399        };
400        let df2 = F::from_usize(df2_int).unwrap_or(F::one());
401        let f_stat = if (F::one() - r2_within) > F::zero() {
402            (r2_within / df1) / ((F::one() - r2_within) / df2)
403        } else {
404            F::zero()
405        };
406        let f_pvalue = approximate_f_pvalue(f_stat, k, df2_int);
407
408        // ── entity effects ────────────────────────────────────────────────────
409        // α_i = ȳ_i - X̄_i β̂
410        let mut entity_effects = Array1::zeros(n_entities);
411        for eid in 0..n_entities {
412            if y_counts[eid] == 0 {
413                continue;
414            }
415            let cnt = F::from_usize(y_counts[eid]).unwrap_or(F::one());
416            // compute mean of X for this entity
417            let mut x_row_mean = vec![F::zero(); k];
418            for (i, &e2) in entity.iter().enumerate() {
419                if e2 == eid {
420                    for j in 0..k {
421                        x_row_mean[j] = x_row_mean[j] + x[[i, j]];
422                    }
423                }
424            }
425            let mut alpha = y_entity_means[eid];
426            for j in 0..k {
427                alpha = alpha - (x_row_mean[j] / cnt) * coeffs[j];
428            }
429            entity_effects[eid] = alpha;
430        }
431
432        Ok(FEResult {
433            coefficients: coeffs,
434            std_errors,
435            t_stats,
436            f_stat,
437            f_pvalue,
438            r2_within,
439            r2_between,
440            r2_overall,
441            n_obs: n,
442            n_entities,
443            residuals: orig_resid,
444            fitted,
445            entity_effects: Some(entity_effects),
446            time_effects: None,
447        })
448    }
449}

// ──────────────────────────────────────────────────────────────────────────────
// TwoWayFE
// ──────────────────────────────────────────────────────────────────────────────

455/// Two-way fixed-effects estimator (entity + time).
456///
457/// Convenience wrapper over `FixedEffectsModel::fit(..., two_way: true)`.
458pub struct TwoWayFE;
459
460impl TwoWayFE {
461    /// Fit a two-way fixed-effects model.
462    pub fn fit<F>(
463        x: &ArrayView2<F>,
464        y: &ArrayView1<F>,
465        entity: &[usize],
466        time: &[usize],
467    ) -> StatsResult<FEResult<F>>
468    where
469        F: Float
470            + std::iter::Sum
471            + std::fmt::Debug
472            + std::fmt::Display
473            + scirs2_core::numeric::NumAssign
474            + scirs2_core::numeric::One
475            + scirs2_core::ndarray::ScalarOperand
476            + FromPrimitive
477            + Send
478            + Sync
479            + 'static,
480    {
481        let n = y.len();
482        let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
483        let n_periods = time.iter().copied().max().map(|m| m + 1).unwrap_or(0);
484
485        let mut result = FixedEffectsModel::fit(x, y, entity, time, true)?;
486
487        // Compute time effects from residuals of entity-demeaned regression
488        // τ_t = mean residual for time t (after entity FE absorbed)
489        // Here we approximate: τ_t = ȳ_t - ȳ - X̄_t β̂
490        let k = result.coefficients.len();
491        let mut time_effects = Array1::zeros(n_periods);
492        let mut t_sums = vec![F::zero(); n_periods];
493        let mut t_x_sums = vec![vec![F::zero(); k]; n_periods];
494        let mut t_counts = vec![0usize; n_periods];
495        for (i, &tid) in time.iter().enumerate() {
496            t_sums[tid] = t_sums[tid] + y[i];
497            t_counts[tid] += 1;
498            for j in 0..k {
499                t_x_sums[tid][j] = t_x_sums[tid][j] + x[[i, j]];
500            }
501        }
502        let y_bar = y.iter().copied().sum::<F>() / F::from_usize(n).unwrap_or(F::one());
503        let mut x_bar = vec![F::zero(); k];
504        for j in 0..k {
505            let s: F = (0..n).map(|i| x[[i, j]]).sum();
506            x_bar[j] = s / F::from_usize(n).unwrap_or(F::one());
507        }
508
509        for tid in 0..n_periods {
510            if t_counts[tid] == 0 {
511                continue;
512            }
513            let cnt = F::from_usize(t_counts[tid]).unwrap_or(F::one());
514            let y_t_bar = t_sums[tid] / cnt;
515            let mut tau = y_t_bar - y_bar;
516            for j in 0..k {
517                let x_t_bar_j = t_x_sums[tid][j] / cnt;
518                tau = tau - (x_t_bar_j - x_bar[j]) * result.coefficients[j];
519            }
520            time_effects[tid] = tau;
521        }
522        result.time_effects = Some(time_effects);
523        Ok(result)
524    }
525}

// ──────────────────────────────────────────────────────────────────────────────
// FirstDiffEstimator
// ──────────────────────────────────────────────────────────────────────────────

531/// First-difference (FD) estimator for balanced panels.
532///
533/// For each entity, computes Δy_{it} = y_{it} - y_{i,t-1} and similarly for X,
534/// then applies OLS.  The FD estimator eliminates all time-invariant unobservables.
535pub struct FirstDiffEstimator;
536
537impl FirstDiffEstimator {
538    /// Fit the first-difference estimator.
539    ///
540    /// # Arguments
541    /// * `x`       – (N × K) design matrix, rows ordered by entity then time
542    /// * `y`       – response (N)
543    /// * `entity`  – entity IDs (length N)
544    /// * `time`    – time period IDs (length N, must be monotone within entity)
545    pub fn fit<F>(
546        x: &ArrayView2<F>,
547        y: &ArrayView1<F>,
548        entity: &[usize],
549        time: &[usize],
550    ) -> StatsResult<FEResult<F>>
551    where
552        F: Float
553            + std::iter::Sum
554            + std::fmt::Debug
555            + std::fmt::Display
556            + scirs2_core::numeric::NumAssign
557            + scirs2_core::numeric::One
558            + scirs2_core::ndarray::ScalarOperand
559            + FromPrimitive
560            + Send
561            + Sync
562            + 'static,
563    {
564        let n = y.len();
565        let (nx, k) = x.dim();
566        if nx != n || entity.len() != n || time.len() != n {
567            return Err(StatsError::DimensionMismatch(
568                "x, y, entity, time must have the same length".to_string(),
569            ));
570        }
571        // Build sorted indices: sort by (entity, time)
572        let mut idx: Vec<usize> = (0..n).collect();
573        idx.sort_by_key(|&i| (entity[i], time[i]));
574
575        // Compute first differences
576        let mut dy_vec: Vec<F> = Vec::new();
577        let mut dx_rows: Vec<Vec<F>> = Vec::new();
578        let mut diff_entity: Vec<usize> = Vec::new();
579
580        for w in idx.windows(2) {
581            let i_prev = w[0];
582            let i_curr = w[1];
583            if entity[i_curr] != entity[i_prev] {
584                continue; // different entity → no diff
585            }
586            // consecutive periods within same entity
587            let dy = y[i_curr] - y[i_prev];
588            dy_vec.push(dy);
589            let row: Vec<F> = (0..k).map(|j| x[[i_curr, j]] - x[[i_prev, j]]).collect();
590            dx_rows.push(row);
591            diff_entity.push(entity[i_curr]);
592        }
593
594        let nd = dy_vec.len();
595        if nd < k + 1 {
596            return Err(StatsError::InsufficientData(format!(
597                "First-difference estimator: only {} difference observations for {} regressors",
598                nd, k
599            )));
600        }
601        let yd = Array1::from(dy_vec);
602        let xd_flat: Vec<F> = dx_rows.iter().flat_map(|r| r.iter().copied()).collect();
603        let xd = Array2::from_shape_vec((nd, k), xd_flat)
604            .map_err(|e| StatsError::ComputationError(format!("Array reshape: {e}")))?;
605
606        let (coeffs, resid) = ols(&xd, &yd)?;
607
608        // HC0 standard errors
609        let xtx = matmul(&xd.t().to_owned(), &xd)?;
610        let std_errors = hc0_se(&xd, &resid, &xtx)?;
611        let t_stats: Array1<F> = coeffs
612            .iter()
613            .zip(std_errors.iter())
614            .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
615            .collect();
616
617        let ss_res: F = resid.iter().map(|&r| r * r).sum();
618        let yd_mean = yd.iter().copied().sum::<F>() / F::from_usize(nd).unwrap_or(F::one());
619        let ss_tot: F = yd.iter().map(|&v| (v - yd_mean) * (v - yd_mean)).sum();
620        let r2 = if ss_tot > F::zero() {
621            F::one() - ss_res / ss_tot
622        } else {
623            F::zero()
624        };
625
626        let df1 = F::from_usize(k).unwrap_or(F::one());
627        let df2_int = if nd > k { nd - k } else { 1 };
628        let df2 = F::from_usize(df2_int).unwrap_or(F::one());
629        let f_stat = if (F::one() - r2) > F::zero() {
630            (r2 / df1) / ((F::one() - r2) / df2)
631        } else {
632            F::zero()
633        };
634        let f_pvalue = approximate_f_pvalue(f_stat, k, df2_int);
635        let n_entities = diff_entity
636            .iter()
637            .copied()
638            .max()
639            .map(|m| m + 1)
640            .unwrap_or(0);
641
642        Ok(FEResult {
643            coefficients: coeffs,
644            std_errors,
645            t_stats,
646            f_stat,
647            f_pvalue,
648            r2_within: r2,
649            r2_between: F::zero(),
650            r2_overall: r2,
651            n_obs: nd,
652            n_entities,
653            fitted: {
654                // fitted = yd - resid (i.e., yd itself minus first-differences residuals)
655                let fitted_vals: Array1<F> = yd
656                    .iter()
657                    .zip(resid.iter())
658                    .map(|(&y_val, &r)| y_val - r)
659                    .collect();
660                fitted_vals
661            },
662            residuals: resid,
663            entity_effects: None,
664            time_effects: None,
665        })
666    }
667}

// ──────────────────────────────────────────────────────────────────────────────
// HC0 standard errors helper
// ──────────────────────────────────────────────────────────────────────────────

673/// Compute HC0 sandwich standard errors.
674/// se_j = sqrt( [(X'X)^{-1} X'diag(e²)X (X'X)^{-1}]_jj )
675fn hc0_se<F>(x: &Array2<F>, e: &Array1<F>, xtx: &Array2<F>) -> StatsResult<Array1<F>>
676where
677    F: Float
678        + std::iter::Sum
679        + std::fmt::Debug
680        + std::fmt::Display
681        + scirs2_core::numeric::NumAssign
682        + scirs2_core::numeric::One
683        + scirs2_core::ndarray::ScalarOperand
684        + FromPrimitive
685        + Send
686        + Sync
687        + 'static,
688{
689    let (n, k) = x.dim();
690    if e.len() != n {
691        return Err(StatsError::DimensionMismatch(
692            "hc0_se: e length mismatch".to_string(),
693        ));
694    }
695    // Meat = X' diag(e²) X  (k×k)
696    let mut meat = Array2::<F>::zeros((k, k));
697    for i in 0..n {
698        let ei2 = e[i] * e[i];
699        for j in 0..k {
700            for l in 0..k {
701                meat[[j, l]] = meat[[j, l]] + x[[i, j]] * x[[i, l]] * ei2;
702            }
703        }
704    }
705    // Solve (X'X) V = meat  for V, then solve (X'X) W = V' for W.
706    // Var(β̂) = (X'X)^{-1} meat (X'X)^{-1}  →  solve column by column.
707    let mut var_beta = Array2::<F>::zeros((k, k));
708    for col in 0..k {
709        let rhs: Array1<F> = (0..k).map(|r| meat[[r, col]]).collect();
710        let v = solve(&xtx.view(), &rhs.view(), None)
711            .map_err(|e2| StatsError::ComputationError(format!("solve failed: {e2}")))?;
712        let rhs2 = v;
713        let w = solve(&xtx.view(), &rhs2.view(), None)
714            .map_err(|e2| StatsError::ComputationError(format!("solve failed: {e2}")))?;
715        for r in 0..k {
716            var_beta[[r, col]] = w[r];
717        }
718    }
719    let se: Array1<F> = (0..k)
720        .map(|j| {
721            let v = var_beta[[j, j]];
722            if v >= F::zero() {
723                v.sqrt()
724            } else {
725                F::zero()
726            }
727        })
728        .collect();
729    Ok(se)
730}

// ──────────────────────────────────────────────────────────────────────────────
// F-test p-value approximation (closed-form normal approximation)
// ──────────────────────────────────────────────────────────────────────────────

736/// Very rough p-value for F(df1, df2) using a chi-squared upper-tail approximation.
737fn approximate_f_pvalue<F: Float + FromPrimitive>(f_stat: F, df1: usize, df2: usize) -> F {
738    if f_stat <= F::zero() {
739        return F::one();
740    }
741    // Approximate: chi² = df1 * F_stat; p ≈ 1 - chi²_cdf(chi², df1)
742    let chi2 = F::from_usize(df1).unwrap_or(F::one()) * f_stat;
743    // Wilson-Hilferty approximation for chi²(df1) upper tail
744    let k = F::from_usize(df1).unwrap_or(F::one());
745    let two = F::from_f64(2.0).unwrap_or(F::one());
746    let nine = F::from_f64(9.0).unwrap_or(F::one());
747    let mu = k;
748    let sigma = (two * k).sqrt();
749    let z = (chi2 - mu) / sigma;
750    // P(Z > z) using standard normal approximation
751    p_value_normal_upper(z)
752}
754/// Upper-tail probability for N(0,1): P(Z > z), using the rational approximation.
755fn p_value_normal_upper<F: Float + FromPrimitive>(z: F) -> F {
756    // Abramowitz & Stegun 26.2.17 approximation
757    let p1 = F::from_f64(0.2316419).unwrap_or(F::zero());
758    let b1 = F::from_f64(0.319381530).unwrap_or(F::zero());
759    let b2 = F::from_f64(-0.356563782).unwrap_or(F::zero());
760    let b3 = F::from_f64(1.781477937).unwrap_or(F::zero());
761    let b4 = F::from_f64(-1.821255978).unwrap_or(F::zero());
762    let b5 = F::from_f64(1.330274429).unwrap_or(F::zero());
763    let sqrt2pi_inv = F::from_f64(0.39894228).unwrap_or(F::zero());
764
765    let abs_z = if z < F::zero() { -z } else { z };
766    let t = F::one() / (F::one() + p1 * abs_z);
767    let poly = t * (b1 + t * (b2 + t * (b3 + t * (b4 + t * b5))));
768    let phi = sqrt2pi_inv * (-(abs_z * abs_z) / (F::from_f64(2.0).unwrap_or(F::one()))).exp();
769    let p_upper = phi * poly;
770    let p_upper = if p_upper < F::zero() {
771        F::zero()
772    } else if p_upper > F::one() {
773        F::one()
774    } else {
775        p_upper
776    };
777    if z >= F::zero() {
778        p_upper
779    } else {
780        F::one() - p_upper
781    }
782}

// ──────────────────────────────────────────────────────────────────────────────
// Tests
// ──────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Noise-free balanced panel: 3 entities × 4 periods, one regressor.
    ///
    /// The data are `y_it = 1.5·x_it + α_i` with `x_it = i + t + 1` and entity effects
    /// `α = (0, 10, 20)`.
    fn balanced_panel() -> (Array2<f64>, Array1<f64>, Vec<usize>, Vec<usize>) {
        let periods = 4;
        let n = 3 * periods;
        let entity: Vec<usize> = (0..n).map(|row| row / periods).collect();
        let time: Vec<usize> = (0..n).map(|row| row % periods).collect();
        let alpha = [0.0_f64, 10.0, 20.0];
        let x_col: Vec<f64> = entity
            .iter()
            .zip(&time)
            .map(|(&e, &t)| (e + t + 1) as f64)
            .collect();
        let y: Array1<f64> = x_col
            .iter()
            .zip(&entity)
            .map(|(&xv, &e)| 1.5 * xv + alpha[e])
            .collect();
        let x = Array2::from_shape_vec((n, 1), x_col).expect("n × 1 shape matches data length");
        (x, y, entity, time)
    }

    #[test]
    fn test_within_transform_demeaning() {
        let data = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let entity = vec![0, 0, 1, 1];
        let demeaned = WithinTransform::transform(&data.view(), &entity).unwrap();
        // Column-0 means are 2 (entity 0) and 6 (entity 1), so column 0 becomes −1, 1, −1, 1.
        for (row, &expected) in [-1.0, 1.0, -1.0, 1.0].iter().enumerate() {
            assert!((demeaned[[row, 0]] - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fe_model_recovers_slope() {
        let (x, y, entity, time) = balanced_panel();
        let fit = FixedEffectsModel::fit(&x.view(), &y.view(), &entity, &time, false)
            .expect("FE fit failed");
        let slope = fit.coefficients[0];
        assert!(
            (slope - 1.5).abs() < 1e-6,
            "Expected slope ≈ 1.5, got {}",
            slope
        );
        assert!(fit.r2_within > 0.99, "R² within should be near 1");
    }

    #[test]
    fn test_first_diff_estimator() {
        let (x, y, entity, time) = balanced_panel();
        let fit =
            FirstDiffEstimator::fit(&x.view(), &y.view(), &entity, &time).expect("FD fit failed");
        let slope = fit.coefficients[0];
        assert!(
            (slope - 1.5).abs() < 1e-6,
            "FD slope: expected 1.5, got {}",
            slope
        );
    }

    #[test]
    fn test_two_way_fe() {
        // 4 entities × 5 periods with y_it = 1.5·x_it + α_i + τ_t. The regressor scales a
        // common time trend by an entity-specific step, so it is collinear with neither the
        // entity nor the time dummies and the slope stays identified.
        let (n_ent, t_per) = (4usize, 5usize);
        let n = n_ent * t_per;
        let entity: Vec<usize> = (0..n).map(|row| row / t_per).collect();
        let time: Vec<usize> = (0..n).map(|row| row % t_per).collect();
        let prime_steps = [1.0_f64, 2.0, 3.0, 5.0];
        let entity_effects = [0.0_f64, 5.0, -3.0, 8.0];
        let time_effects = [0.0_f64, 1.0, -1.0, 2.0, -2.0];
        let (x_vals, y_vals): (Vec<f64>, Vec<f64>) = entity
            .iter()
            .zip(&time)
            .map(|(&e, &t)| {
                let x_v = prime_steps[e] * (1.0 + t as f64 * 0.37);
                (x_v, 1.5 * x_v + entity_effects[e] + time_effects[t])
            })
            .unzip();
        let x = Array2::from_shape_vec((n, 1), x_vals).unwrap();
        let y = Array1::from(y_vals);

        let result =
            TwoWayFE::fit(&x.view(), &y.view(), &entity, &time).expect("Two-way FE fit failed");
        assert!(result.time_effects.is_some());
        let slope = result.coefficients[0];
        assert!(
            (slope - 1.5).abs() < 0.2,
            "Two-way FE slope: expected ~1.5, got {}",
            slope
        );
    }
}