1use crate::error::{StatsError, StatsResult};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
12use scirs2_core::numeric::{Float, FromPrimitive};
13use scirs2_linalg::{lstsq, solve};
14
15fn matmul<F: Float + std::iter::Sum>(a: &Array2<F>, b: &Array2<F>) -> StatsResult<Array2<F>> {
20 let (m, k) = a.dim();
21 let (kb, n) = b.dim();
22 if k != kb {
23 return Err(StatsError::DimensionMismatch(format!(
24 "matmul: inner dims mismatch {} vs {}",
25 k, kb
26 )));
27 }
28 let mut c = Array2::zeros((m, n));
29 for i in 0..m {
30 for j in 0..n {
31 let mut s = F::zero();
32 for l in 0..k {
33 s = s + a[[i, l]] * b[[l, j]];
34 }
35 c[[i, j]] = s;
36 }
37 }
38 Ok(c)
39}
40
41fn ols<F>(x: &Array2<F>, y: &Array1<F>) -> StatsResult<(Array1<F>, Array1<F>)>
44where
45 F: Float
46 + std::iter::Sum
47 + std::fmt::Debug
48 + std::fmt::Display
49 + scirs2_core::numeric::NumAssign
50 + scirs2_core::numeric::One
51 + scirs2_core::ndarray::ScalarOperand
52 + FromPrimitive
53 + Send
54 + Sync
55 + 'static,
56{
57 let n = y.len();
58 let (n2, _k) = x.dim();
59 if n != n2 {
60 return Err(StatsError::DimensionMismatch(format!(
61 "ols: x has {} rows, y has {} elements",
62 n2, n
63 )));
64 }
65 let result = lstsq(&x.view(), &y.view(), None)
66 .map_err(|e| StatsError::ComputationError(format!("lstsq failed: {e}")))?;
67 let coeffs = result.x;
68 let mut fitted = Array1::zeros(n);
70 for i in 0..n {
71 let mut s = F::zero();
72 for j in 0..coeffs.len() {
73 s = s + x[[i, j]] * coeffs[j];
74 }
75 fitted[i] = s;
76 }
77 let resid: Array1<F> = y
78 .iter()
79 .zip(fitted.iter())
80 .map(|(&yi, &fi)| yi - fi)
81 .collect();
82 Ok((coeffs, resid))
83}
84
/// Estimation output shared by the panel estimators in this module
/// (`FixedEffectsModel`, `TwoWayFE` and `FirstDiffEstimator`).
#[derive(Debug, Clone)]
pub struct FEResult<F> {
    /// Estimated slope coefficients, one per column of the regressor matrix.
    pub coefficients: Array1<F>,
    /// Standard errors of `coefficients`, as returned by `hc0_se` in the estimators here.
    pub std_errors: Array1<F>,
    /// `coefficients / std_errors`; set to zero where the standard error is not positive.
    pub t_stats: Array1<F>,
    /// F statistic for the slopes, computed from `r2_within` and the degrees of freedom.
    pub f_stat: F,
    /// Approximate upper-tail p-value of `f_stat` (from `approximate_f_pvalue`).
    pub f_pvalue: F,
    /// R² of the regression on the transformed (demeaned or differenced) data.
    pub r2_within: F,
    /// Between-entity R²; `FirstDiffEstimator` reports zero.
    pub r2_between: F,
    /// Overall R²; `FirstDiffEstimator` reports the same value as `r2_within`.
    pub r2_overall: F,
    /// Number of observations used (differenced observations for `FirstDiffEstimator`).
    pub n_obs: usize,
    /// Number of entity slots: the largest observed entity id + 1.
    pub n_entities: usize,
    /// Residuals of the fit (on the differenced scale for `FirstDiffEstimator`).
    pub residuals: Array1<F>,
    /// Fitted values; `fitted + residuals` reproduces the dependent variable on the same scale.
    pub fitted: Array1<F>,
    /// Entity intercepts indexed by entity id, when the estimator produces them.
    pub entity_effects: Option<Array1<F>>,
    /// Period effects indexed by time id, when the estimator produces them (`TwoWayFE`).
    pub time_effects: Option<Array1<F>>,
}
121
122pub struct WithinTransform;
130
131impl WithinTransform {
132 pub fn transform<F: Float + FromPrimitive>(
140 data: &ArrayView2<F>,
141 entity: &[usize],
142 ) -> StatsResult<Array2<F>> {
143 let (n, k) = data.dim();
144 if entity.len() != n {
145 return Err(StatsError::DimensionMismatch(format!(
146 "WithinTransform: data has {} rows but entity has {} elements",
147 n,
148 entity.len()
149 )));
150 }
151 let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
153 let mut sums = Array2::<F>::zeros((n_entities, k));
155 let mut counts = vec![0usize; n_entities];
156 for (row, &eid) in entity.iter().enumerate() {
157 counts[eid] += 1;
158 for col in 0..k {
159 sums[[eid, col]] = sums[[eid, col]] + data[[row, col]];
160 }
161 }
162 let mut means = Array2::<F>::zeros((n_entities, k));
163 for eid in 0..n_entities {
164 let cnt = F::from_usize(counts[eid])
165 .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
166 for col in 0..k {
167 means[[eid, col]] = if cnt > F::zero() {
168 sums[[eid, col]] / cnt
169 } else {
170 F::zero()
171 };
172 }
173 }
174 let mut demeaned = data.to_owned();
176 for (row, &eid) in entity.iter().enumerate() {
177 for col in 0..k {
178 demeaned[[row, col]] = demeaned[[row, col]] - means[[eid, col]];
179 }
180 }
181 Ok(demeaned)
182 }
183}
184
185pub struct FixedEffectsModel;
211
212impl FixedEffectsModel {
213 pub fn fit<F>(
222 x: &ArrayView2<F>,
223 y: &ArrayView1<F>,
224 entity: &[usize],
225 time: &[usize],
226 two_way: bool,
227 ) -> StatsResult<FEResult<F>>
228 where
229 F: Float
230 + std::iter::Sum
231 + std::fmt::Debug
232 + std::fmt::Display
233 + scirs2_core::numeric::NumAssign
234 + scirs2_core::numeric::One
235 + scirs2_core::ndarray::ScalarOperand
236 + FromPrimitive
237 + Send
238 + Sync
239 + 'static,
240 {
241 let n = y.len();
242 let (nx, k) = x.dim();
243 if nx != n || entity.len() != n || time.len() != n {
244 return Err(StatsError::DimensionMismatch(
245 "x, y, entity, time must all have the same length N".to_string(),
246 ));
247 }
248 if n == 0 {
249 return Err(StatsError::InsufficientData("Empty dataset".to_string()));
250 }
251
252 let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
253 let n_periods = time.iter().copied().max().map(|m| m + 1).unwrap_or(0);
254
255 let x_owned = x.to_owned();
257 let mut xd = WithinTransform::transform(&x_owned.view(), entity)?;
258 let mut yd_vec: Vec<F> = y.iter().copied().collect();
259
260 let mut y_sums = vec![F::zero(); n_entities];
262 let mut y_counts = vec![0usize; n_entities];
263 for (i, &eid) in entity.iter().enumerate() {
264 y_sums[eid] = y_sums[eid] + y[i];
265 y_counts[eid] += 1;
266 }
267 let y_entity_means: Vec<F> = y_sums
268 .iter()
269 .zip(y_counts.iter())
270 .map(|(&s, &c)| {
271 if c > 0 {
272 s / F::from_usize(c).unwrap_or(F::one())
273 } else {
274 F::zero()
275 }
276 })
277 .collect();
278 for (i, &eid) in entity.iter().enumerate() {
279 yd_vec[i] = yd_vec[i] - y_entity_means[eid];
280 }
281
282 if two_way {
283 let mut yd2 = yd_vec.clone();
286 let mut t_sums = vec![F::zero(); n_periods];
287 let mut t_counts = vec![0usize; n_periods];
288 for (i, &tid) in time.iter().enumerate() {
289 t_sums[tid] = t_sums[tid] + yd2[i];
290 t_counts[tid] += 1;
291 }
292 let y_time_means: Vec<F> = t_sums
293 .iter()
294 .zip(t_counts.iter())
295 .map(|(&s, &c)| {
296 if c > 0 {
297 s / F::from_usize(c).unwrap_or(F::one())
298 } else {
299 F::zero()
300 }
301 })
302 .collect();
303 for (i, &tid) in time.iter().enumerate() {
304 yd2[i] = yd2[i] - y_time_means[tid];
305 }
306 yd_vec = yd2;
307
308 let xd2 = WithinTransform::transform(&xd.view(), time)?;
310 xd = xd2;
311 }
312
313 let yd = Array1::from(yd_vec);
314
315 let (coeffs, resid) = ols(&xd, &yd)?;
317
318 let mut fitted = Array1::zeros(n);
320 for i in 0..n {
321 let mut s = y_entity_means[entity[i]]; for j in 0..k {
323 s = s + x[[i, j]] * coeffs[j];
324 }
325 fitted[i] = s;
326 }
327 let orig_resid: Array1<F> = (0..n).map(|i| y[i] - fitted[i]).collect();
328
329 let ss_res_within: F = resid.iter().map(|&r| r * r).sum();
331 let yd_mean = yd.iter().copied().sum::<F>()
332 / F::from_usize(n)
333 .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
334 let ss_tot_within: F = yd.iter().map(|&v| (v - yd_mean) * (v - yd_mean)).sum();
335 let r2_within = if ss_tot_within > F::zero() {
336 F::one() - ss_res_within / ss_tot_within
337 } else {
338 F::zero()
339 };
340
341 let mut fy_sums = vec![F::zero(); n_entities];
344 for (i, &eid) in entity.iter().enumerate() {
345 fy_sums[eid] = fy_sums[eid] + fitted[i];
346 }
347 let y_bar_bar = y.iter().copied().sum::<F>()
348 / F::from_usize(n)
349 .ok_or_else(|| StatsError::ComputationError("FromPrimitive failed".to_string()))?;
350 let mut ss_between_tot = F::zero();
351 let mut ss_between_res = F::zero();
352 for eid in 0..n_entities {
353 if y_counts[eid] == 0 {
354 continue;
355 }
356 let cnt = F::from_usize(y_counts[eid]).unwrap_or(F::one());
357 let y_em = y_entity_means[eid];
358 let f_em = fy_sums[eid] / cnt;
359 ss_between_tot = ss_between_tot + cnt * (y_em - y_bar_bar) * (y_em - y_bar_bar);
360 ss_between_res = ss_between_res + cnt * (y_em - f_em) * (y_em - f_em);
361 }
362 let r2_between = if ss_between_tot > F::zero() {
363 F::one() - ss_between_res / ss_between_tot
364 } else {
365 F::zero()
366 };
367
368 let ss_tot: F = y
370 .iter()
371 .map(|&yi| (yi - y_bar_bar) * (yi - y_bar_bar))
372 .sum();
373 let ss_res_overall: F = orig_resid.iter().map(|&r| r * r).sum();
374 let r2_overall = if ss_tot > F::zero() {
375 F::one() - ss_res_overall / ss_tot
376 } else {
377 F::zero()
378 };
379
380 let xtx = matmul(&xd.t().to_owned(), &xd)?;
383 let std_errors = hc0_se(&xd, &resid, &xtx)?;
385
386 let t_stats: Array1<F> = coeffs
387 .iter()
388 .zip(std_errors.iter())
389 .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
390 .collect();
391
392 let df1 = F::from_usize(k).unwrap_or(F::one());
395 let df2_int = if n > n_entities + k {
396 n - n_entities - k
397 } else {
398 1
399 };
400 let df2 = F::from_usize(df2_int).unwrap_or(F::one());
401 let f_stat = if (F::one() - r2_within) > F::zero() {
402 (r2_within / df1) / ((F::one() - r2_within) / df2)
403 } else {
404 F::zero()
405 };
406 let f_pvalue = approximate_f_pvalue(f_stat, k, df2_int);
407
408 let mut entity_effects = Array1::zeros(n_entities);
411 for eid in 0..n_entities {
412 if y_counts[eid] == 0 {
413 continue;
414 }
415 let cnt = F::from_usize(y_counts[eid]).unwrap_or(F::one());
416 let mut x_row_mean = vec![F::zero(); k];
418 for (i, &e2) in entity.iter().enumerate() {
419 if e2 == eid {
420 for j in 0..k {
421 x_row_mean[j] = x_row_mean[j] + x[[i, j]];
422 }
423 }
424 }
425 let mut alpha = y_entity_means[eid];
426 for j in 0..k {
427 alpha = alpha - (x_row_mean[j] / cnt) * coeffs[j];
428 }
429 entity_effects[eid] = alpha;
430 }
431
432 Ok(FEResult {
433 coefficients: coeffs,
434 std_errors,
435 t_stats,
436 f_stat,
437 f_pvalue,
438 r2_within,
439 r2_between,
440 r2_overall,
441 n_obs: n,
442 n_entities,
443 residuals: orig_resid,
444 fitted,
445 entity_effects: Some(entity_effects),
446 time_effects: None,
447 })
448 }
449}
450
451pub struct TwoWayFE;
459
460impl TwoWayFE {
461 pub fn fit<F>(
463 x: &ArrayView2<F>,
464 y: &ArrayView1<F>,
465 entity: &[usize],
466 time: &[usize],
467 ) -> StatsResult<FEResult<F>>
468 where
469 F: Float
470 + std::iter::Sum
471 + std::fmt::Debug
472 + std::fmt::Display
473 + scirs2_core::numeric::NumAssign
474 + scirs2_core::numeric::One
475 + scirs2_core::ndarray::ScalarOperand
476 + FromPrimitive
477 + Send
478 + Sync
479 + 'static,
480 {
481 let n = y.len();
482 let n_entities = entity.iter().copied().max().map(|m| m + 1).unwrap_or(0);
483 let n_periods = time.iter().copied().max().map(|m| m + 1).unwrap_or(0);
484
485 let mut result = FixedEffectsModel::fit(x, y, entity, time, true)?;
486
487 let k = result.coefficients.len();
491 let mut time_effects = Array1::zeros(n_periods);
492 let mut t_sums = vec![F::zero(); n_periods];
493 let mut t_x_sums = vec![vec![F::zero(); k]; n_periods];
494 let mut t_counts = vec![0usize; n_periods];
495 for (i, &tid) in time.iter().enumerate() {
496 t_sums[tid] = t_sums[tid] + y[i];
497 t_counts[tid] += 1;
498 for j in 0..k {
499 t_x_sums[tid][j] = t_x_sums[tid][j] + x[[i, j]];
500 }
501 }
502 let y_bar = y.iter().copied().sum::<F>() / F::from_usize(n).unwrap_or(F::one());
503 let mut x_bar = vec![F::zero(); k];
504 for j in 0..k {
505 let s: F = (0..n).map(|i| x[[i, j]]).sum();
506 x_bar[j] = s / F::from_usize(n).unwrap_or(F::one());
507 }
508
509 for tid in 0..n_periods {
510 if t_counts[tid] == 0 {
511 continue;
512 }
513 let cnt = F::from_usize(t_counts[tid]).unwrap_or(F::one());
514 let y_t_bar = t_sums[tid] / cnt;
515 let mut tau = y_t_bar - y_bar;
516 for j in 0..k {
517 let x_t_bar_j = t_x_sums[tid][j] / cnt;
518 tau = tau - (x_t_bar_j - x_bar[j]) * result.coefficients[j];
519 }
520 time_effects[tid] = tau;
521 }
522 result.time_effects = Some(time_effects);
523 Ok(result)
524 }
525}
526
527pub struct FirstDiffEstimator;
536
537impl FirstDiffEstimator {
538 pub fn fit<F>(
546 x: &ArrayView2<F>,
547 y: &ArrayView1<F>,
548 entity: &[usize],
549 time: &[usize],
550 ) -> StatsResult<FEResult<F>>
551 where
552 F: Float
553 + std::iter::Sum
554 + std::fmt::Debug
555 + std::fmt::Display
556 + scirs2_core::numeric::NumAssign
557 + scirs2_core::numeric::One
558 + scirs2_core::ndarray::ScalarOperand
559 + FromPrimitive
560 + Send
561 + Sync
562 + 'static,
563 {
564 let n = y.len();
565 let (nx, k) = x.dim();
566 if nx != n || entity.len() != n || time.len() != n {
567 return Err(StatsError::DimensionMismatch(
568 "x, y, entity, time must have the same length".to_string(),
569 ));
570 }
571 let mut idx: Vec<usize> = (0..n).collect();
573 idx.sort_by_key(|&i| (entity[i], time[i]));
574
575 let mut dy_vec: Vec<F> = Vec::new();
577 let mut dx_rows: Vec<Vec<F>> = Vec::new();
578 let mut diff_entity: Vec<usize> = Vec::new();
579
580 for w in idx.windows(2) {
581 let i_prev = w[0];
582 let i_curr = w[1];
583 if entity[i_curr] != entity[i_prev] {
584 continue; }
586 let dy = y[i_curr] - y[i_prev];
588 dy_vec.push(dy);
589 let row: Vec<F> = (0..k).map(|j| x[[i_curr, j]] - x[[i_prev, j]]).collect();
590 dx_rows.push(row);
591 diff_entity.push(entity[i_curr]);
592 }
593
594 let nd = dy_vec.len();
595 if nd < k + 1 {
596 return Err(StatsError::InsufficientData(format!(
597 "First-difference estimator: only {} difference observations for {} regressors",
598 nd, k
599 )));
600 }
601 let yd = Array1::from(dy_vec);
602 let xd_flat: Vec<F> = dx_rows.iter().flat_map(|r| r.iter().copied()).collect();
603 let xd = Array2::from_shape_vec((nd, k), xd_flat)
604 .map_err(|e| StatsError::ComputationError(format!("Array reshape: {e}")))?;
605
606 let (coeffs, resid) = ols(&xd, &yd)?;
607
608 let xtx = matmul(&xd.t().to_owned(), &xd)?;
610 let std_errors = hc0_se(&xd, &resid, &xtx)?;
611 let t_stats: Array1<F> = coeffs
612 .iter()
613 .zip(std_errors.iter())
614 .map(|(&c, &se)| if se > F::zero() { c / se } else { F::zero() })
615 .collect();
616
617 let ss_res: F = resid.iter().map(|&r| r * r).sum();
618 let yd_mean = yd.iter().copied().sum::<F>() / F::from_usize(nd).unwrap_or(F::one());
619 let ss_tot: F = yd.iter().map(|&v| (v - yd_mean) * (v - yd_mean)).sum();
620 let r2 = if ss_tot > F::zero() {
621 F::one() - ss_res / ss_tot
622 } else {
623 F::zero()
624 };
625
626 let df1 = F::from_usize(k).unwrap_or(F::one());
627 let df2_int = if nd > k { nd - k } else { 1 };
628 let df2 = F::from_usize(df2_int).unwrap_or(F::one());
629 let f_stat = if (F::one() - r2) > F::zero() {
630 (r2 / df1) / ((F::one() - r2) / df2)
631 } else {
632 F::zero()
633 };
634 let f_pvalue = approximate_f_pvalue(f_stat, k, df2_int);
635 let n_entities = diff_entity
636 .iter()
637 .copied()
638 .max()
639 .map(|m| m + 1)
640 .unwrap_or(0);
641
642 Ok(FEResult {
643 coefficients: coeffs,
644 std_errors,
645 t_stats,
646 f_stat,
647 f_pvalue,
648 r2_within: r2,
649 r2_between: F::zero(),
650 r2_overall: r2,
651 n_obs: nd,
652 n_entities,
653 fitted: {
654 let fitted_vals: Array1<F> = yd
656 .iter()
657 .zip(resid.iter())
658 .map(|(&y_val, &r)| y_val - r)
659 .collect();
660 fitted_vals
661 },
662 residuals: resid,
663 entity_effects: None,
664 time_effects: None,
665 })
666 }
667}
668
669fn hc0_se<F>(x: &Array2<F>, e: &Array1<F>, xtx: &Array2<F>) -> StatsResult<Array1<F>>
676where
677 F: Float
678 + std::iter::Sum
679 + std::fmt::Debug
680 + std::fmt::Display
681 + scirs2_core::numeric::NumAssign
682 + scirs2_core::numeric::One
683 + scirs2_core::ndarray::ScalarOperand
684 + FromPrimitive
685 + Send
686 + Sync
687 + 'static,
688{
689 let (n, k) = x.dim();
690 if e.len() != n {
691 return Err(StatsError::DimensionMismatch(
692 "hc0_se: e length mismatch".to_string(),
693 ));
694 }
695 let mut meat = Array2::<F>::zeros((k, k));
697 for i in 0..n {
698 let ei2 = e[i] * e[i];
699 for j in 0..k {
700 for l in 0..k {
701 meat[[j, l]] = meat[[j, l]] + x[[i, j]] * x[[i, l]] * ei2;
702 }
703 }
704 }
705 let mut var_beta = Array2::<F>::zeros((k, k));
708 for col in 0..k {
709 let rhs: Array1<F> = (0..k).map(|r| meat[[r, col]]).collect();
710 let v = solve(&xtx.view(), &rhs.view(), None)
711 .map_err(|e2| StatsError::ComputationError(format!("solve failed: {e2}")))?;
712 let rhs2 = v;
713 let w = solve(&xtx.view(), &rhs2.view(), None)
714 .map_err(|e2| StatsError::ComputationError(format!("solve failed: {e2}")))?;
715 for r in 0..k {
716 var_beta[[r, col]] = w[r];
717 }
718 }
719 let se: Array1<F> = (0..k)
720 .map(|j| {
721 let v = var_beta[[j, j]];
722 if v >= F::zero() {
723 v.sqrt()
724 } else {
725 F::zero()
726 }
727 })
728 .collect();
729 Ok(se)
730}
731
732fn approximate_f_pvalue<F: Float + FromPrimitive>(f_stat: F, df1: usize, df2: usize) -> F {
738 if f_stat <= F::zero() {
739 return F::one();
740 }
741 let chi2 = F::from_usize(df1).unwrap_or(F::one()) * f_stat;
743 let k = F::from_usize(df1).unwrap_or(F::one());
745 let two = F::from_f64(2.0).unwrap_or(F::one());
746 let nine = F::from_f64(9.0).unwrap_or(F::one());
747 let mu = k;
748 let sigma = (two * k).sqrt();
749 let z = (chi2 - mu) / sigma;
750 p_value_normal_upper(z)
752}
753
754fn p_value_normal_upper<F: Float + FromPrimitive>(z: F) -> F {
756 let p1 = F::from_f64(0.2316419).unwrap_or(F::zero());
758 let b1 = F::from_f64(0.319381530).unwrap_or(F::zero());
759 let b2 = F::from_f64(-0.356563782).unwrap_or(F::zero());
760 let b3 = F::from_f64(1.781477937).unwrap_or(F::zero());
761 let b4 = F::from_f64(-1.821255978).unwrap_or(F::zero());
762 let b5 = F::from_f64(1.330274429).unwrap_or(F::zero());
763 let sqrt2pi_inv = F::from_f64(0.39894228).unwrap_or(F::zero());
764
765 let abs_z = if z < F::zero() { -z } else { z };
766 let t = F::one() / (F::one() + p1 * abs_z);
767 let poly = t * (b1 + t * (b2 + t * (b3 + t * (b4 + t * b5))));
768 let phi = sqrt2pi_inv * (-(abs_z * abs_z) / (F::from_f64(2.0).unwrap_or(F::one()))).exp();
769 let p_upper = phi * poly;
770 let p_upper = if p_upper < F::zero() {
771 F::zero()
772 } else if p_upper > F::one() {
773 F::one()
774 } else {
775 p_upper
776 };
777 if z >= F::zero() {
778 p_upper
779 } else {
780 F::one() - p_upper
781 }
782}
783
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Balanced 3-entity × 4-period panel with `x_{e,t} = e + t + 1` and
    /// `y = 1.5·x + α_e`, where `α = (0, 10, 20)`.
    fn make_balanced_panel() -> (Array2<f64>, Array1<f64>, Vec<usize>, Vec<usize>) {
        const ENTITIES: usize = 3;
        const PERIODS: usize = 4;
        let n = ENTITIES * PERIODS;
        let entity: Vec<usize> = (0..n).map(|row| row / PERIODS).collect();
        let time: Vec<usize> = (0..n).map(|row| row % PERIODS).collect();
        let x_vals: Vec<f64> = entity
            .iter()
            .zip(&time)
            .map(|(&e, &t)| (e + t + 1) as f64)
            .collect();
        let alpha = [0.0_f64, 10.0, 20.0];
        let y: Array1<f64> = x_vals
            .iter()
            .zip(&entity)
            .map(|(&xv, &e)| 1.5 * xv + alpha[e])
            .collect();
        let x = Array2::from_shape_vec((n, 1), x_vals).unwrap();
        (x, y, entity, time)
    }

    #[test]
    fn test_within_transform_demeaning() {
        let data = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let entity = vec![0, 0, 1, 1];
        let demeaned = WithinTransform::transform(&data.view(), &entity).unwrap();
        // Column 0 has entity means 2 and 6, so every entry sits ±1 from its mean.
        for (row, &expected) in [-1.0_f64, 1.0, -1.0, 1.0].iter().enumerate() {
            assert!((demeaned[[row, 0]] - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fe_model_recovers_slope() {
        let (x, y, entity, time) = make_balanced_panel();
        let fit = FixedEffectsModel::fit(&x.view(), &y.view(), &entity, &time, false)
            .expect("FE fit failed");
        let slope = fit.coefficients[0];
        assert!(
            (slope - 1.5).abs() < 1e-6,
            "Expected slope ≈ 1.5, got {}",
            slope
        );
        assert!(fit.r2_within > 0.99, "R² within should be near 1");
    }

    #[test]
    fn test_first_diff_estimator() {
        let (x, y, entity, time) = make_balanced_panel();
        let fit =
            FirstDiffEstimator::fit(&x.view(), &y.view(), &entity, &time).expect("FD fit failed");
        let slope = fit.coefficients[0];
        assert!(
            (slope - 1.5).abs() < 1e-6,
            "FD slope: expected 1.5, got {}",
            slope
        );
    }

    #[test]
    fn test_two_way_fe() {
        let n_ent = 4usize;
        let t_per = 5usize;
        let n = n_ent * t_per;
        let entity: Vec<usize> = (0..n).map(|row| row / t_per).collect();
        let time: Vec<usize> = (0..n).map(|row| row % t_per).collect();
        // Entity-specific scale keeps x from being collinear with the two-way effects.
        let scale = [1.0_f64, 2.0, 3.0, 5.0];
        let alpha = [0.0_f64, 5.0, -3.0, 8.0];
        let gamma = [0.0_f64, 1.0, -1.0, 2.0, -2.0];
        let (x_vals, y_vals): (Vec<f64>, Vec<f64>) = entity
            .iter()
            .zip(&time)
            .map(|(&e, &t)| {
                let xv = scale[e] * (1.0 + t as f64 * 0.37);
                (xv, 1.5 * xv + alpha[e] + gamma[t])
            })
            .unzip();
        let x = Array2::from_shape_vec((n, 1), x_vals).unwrap();
        let y = Array1::from(y_vals);

        let fit =
            TwoWayFE::fit(&x.view(), &y.view(), &entity, &time).expect("Two-way FE fit failed");
        assert!(fit.time_effects.is_some());
        let slope = fit.coefficients[0];
        assert!(
            (slope - 1.5).abs() < 0.2,
            "Two-way FE slope: expected ~1.5, got {}",
            slope
        );
    }
}