// so_models/mixed.rs
//! Mixed effects models for StatOxide
//!
//! This module implements linear mixed models (LMM) and generalized linear mixed models (GLMM)
//! for hierarchical or clustered data.
//!
//! # Model Specification
//!
//! Linear Mixed Model:
//! y = Xβ + Zb + ε
//! where:
//! - y: response vector
//! - X: fixed effects design matrix
//! - β: fixed effects coefficients
//! - Z: random effects design matrix
//! - b: random effects coefficients ~ N(0, G)
//! - ε: residuals ~ N(0, R)
//!
//! Generalized Linear Mixed Model extends LMM to non-Gaussian responses using link functions.
//!
//! # Estimation Methods
//!
//! 1. **REML**: Restricted Maximum Likelihood (preferred for variance components)
//! 2. **ML**: Maximum Likelihood
//! 3. **PQL**: Penalized Quasi-Likelihood (for GLMM)
//! 4. **Laplace Approximation**: For non-Gaussian GLMM

#![allow(non_snake_case)] // Allow mathematical notation (X, Z, V, etc.)

use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};

use so_core::data::DataFrame;
use so_core::error::{Error, Result};
use so_core::formula::Formula;
use so_linalg::{inv, solve};

use crate::glm::{Family, GLM, Link};

/// Random effects structure specification.
///
/// Describes one random-effect term: which column defines the groups, the
/// within-group formula, and the covariance structure of the effect.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffect {
    /// Name of the grouping variable (a column in the model's DataFrame)
    pub group_var: String,
    /// Formula for random effects within groups
    pub formula: String,
    /// Covariance structure (the builders currently always set `Independent`)
    pub covariance: RandomCovariance,
}

/// Covariance structure for random effects.
///
/// NOTE(review): only `Independent` is produced by the builders in this
/// module today; the remaining variants appear to be declared for future
/// use — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RandomCovariance {
    /// Independent random effects (diagonal covariance)
    Independent,
    /// Compound symmetry (exchangeable correlation)
    CompoundSymmetry,
    /// Auto-regressive of order 1
    AR1,
    /// Unstructured covariance
    Unstructured,
    /// Custom covariance matrix supplied by the caller
    Custom(Array2<f64>),
}

/// Linear Mixed Model (LMM) results.
///
/// Produced by `LinearMixedModelBuilder::fit`. Fixed effects are reported
/// with standard errors; random effects enter through their estimated
/// variance components.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LMMResults {
    /// Fixed effects coefficients
    pub fixed_effects: Array1<f64>,
    /// Standard errors for fixed effects
    pub fixed_se: Array1<f64>,
    /// Random effects variance components, as (grouping variable, variance) pairs
    pub variance_components: Vec<(String, f64)>,
    /// Residual (error) variance
    pub residual_variance: f64,
    /// Log-likelihood (ML or REML-adjusted, per the estimation method)
    pub log_lik: f64,
    /// Akaike Information Criterion
    pub aic: f64,
    /// Bayesian Information Criterion
    pub bic: f64,
    /// Degrees of freedom for fixed effects
    pub df_fixed: usize,
    /// Degrees of freedom for residuals
    pub df_resid: usize,
    /// Convergence status
    pub converged: bool,
    /// Number of iterations
    pub iterations: usize,
}

/// Linear Mixed Model builder.
///
/// Configure with the chainable setters and call `fit` to estimate the model.
pub struct LinearMixedModelBuilder {
    /// Dataset containing the response, fixed-effect, and grouping columns.
    data: DataFrame,
    /// Name of the response column.
    response: String,
    /// Fixed-effects formula (right-hand side, or full `y ~ x` form).
    fixed_formula: String,
    /// Random-effect terms added via `random_effect`.
    random_effects: Vec<RandomEffect>,
    /// REML (default) or ML.
    method: EstimationMethod,
    /// Maximum EM iterations (default 100).
    max_iter: usize,
    /// Relative convergence tolerance (default 1e-6).
    tol: f64,
}

/// Estimation method for linear mixed models.
#[derive(Debug, Clone, Copy)]
pub enum EstimationMethod {
    /// Restricted Maximum Likelihood (the default; preferred for variance components)
    REML,
    /// Maximum Likelihood
    ML,
}

/// Estimation method for GLMM.
#[derive(Debug, Clone, Copy)]
pub enum GLMMEstimationMethod {
    /// Penalized Quasi-Likelihood (PQL; the builder default)
    PQL,
    /// Laplace Approximation
    Laplace,
    /// Adaptive Gauss-Hermite Quadrature (higher accuracy)
    AGHQ(usize), // number of quadrature points
}
123
124impl LinearMixedModelBuilder {
125    /// Create a new LMM builder
126    pub fn new(data: DataFrame, response: &str, fixed_formula: &str) -> Self {
127        Self {
128            data,
129            response: response.to_string(),
130            fixed_formula: fixed_formula.to_string(),
131            random_effects: Vec::new(),
132            method: EstimationMethod::REML,
133            max_iter: 100,
134            tol: 1e-6,
135        }
136    }
137
138    /// Add a random effect
139    pub fn random_effect(mut self, group_var: &str, formula: &str) -> Self {
140        self.random_effects.push(RandomEffect {
141            group_var: group_var.to_string(),
142            formula: formula.to_string(),
143            covariance: RandomCovariance::Independent,
144        });
145        self
146    }
147
148    /// Set estimation method
149    pub fn method(mut self, method: EstimationMethod) -> Self {
150        self.method = method;
151        self
152    }
153
154    /// Set maximum iterations
155    pub fn max_iterations(mut self, max_iter: usize) -> Self {
156        self.max_iter = max_iter;
157        self
158    }
159
160    /// Set convergence tolerance
161    pub fn tolerance(mut self, tol: f64) -> Self {
162        self.tol = tol;
163        self
164    }
165
166    /// Fit the linear mixed model
167    pub fn fit(self) -> Result<LMMResults> {
168        // Parse fixed effects formula
169        // This is a simplified implementation
170        // In practice, we would use the formula parser from so-core
171
172        // Extract response variable
173        let y = self.data.column(&self.response).ok_or_else(|| {
174            Error::DataError(format!("Response column '{}' not found", self.response))
175        })?;
176        let y_array = y.data().to_owned();
177
178        // Build design matrix for fixed effects
179        let X = self.build_fixed_design_matrix()?;
180
181        // Build random effects design matrices
182        let (Z_matrices, group_sizes) = self.build_random_design_matrices()?;
183
184        // Fit using EM algorithm (simplified)
185        self.fit_em(&y_array, &X, &Z_matrices, &group_sizes)
186    }
187
188    /// Build fixed effects design matrix
189    fn build_fixed_design_matrix(&self) -> Result<Array2<f64>> {
190        // Parse fixed effects formula
191        let formula_str = if self.fixed_formula.contains('~') {
192            self.fixed_formula.clone()
193        } else {
194            // If no response in formula (for random effects), add placeholder
195            format!("__response__ ~ {}", self.fixed_formula)
196        };
197
198        let formula = Formula::parse(&formula_str)
199            .map_err(|e| Error::FormulaError(format!("Failed to parse fixed formula: {}", e)))?;
200
201        // Build design matrix (includes intercept if specified)
202        formula
203            .build_matrix(&self.data)
204            .map_err(|e| Error::DataError(format!("Failed to build design matrix: {}", e)))
205    }
206
207    /// Build random effects design matrices
208    fn build_random_design_matrices(&self) -> Result<(Vec<Array2<f64>>, Vec<usize>)> {
209        let mut Z_matrices = Vec::new();
210        let mut group_sizes = Vec::new();
211
212        for random_effect in &self.random_effects {
213            // Simplified: create indicator matrix for groups
214            let group_col = self.data.column(&random_effect.group_var).ok_or_else(|| {
215                Error::DataError(format!(
216                    "Group column '{}' not found",
217                    random_effect.group_var
218                ))
219            })?;
220
221            // TODO: Implement proper categorical extraction
222            let groups: Vec<String> = vec!["group1".to_string(), "group2".to_string()]; // Placeholder
223            let n_groups = groups.len();
224            let n = self.data.n_rows();
225
226            let mut Z = Array2::zeros((n, n_groups));
227
228            // Extract group indices from column data (assumes numeric encoding)
229            let group_data = group_col.data();
230            for j in 0..n {
231                let group_idx = group_data[j] as usize % n_groups; // Simple mapping
232                Z[(j, group_idx)] = 1.0;
233            }
234
235            Z_matrices.push(Z);
236            group_sizes.push(n_groups);
237        }
238
239        Ok((Z_matrices, group_sizes))
240    }
241
242    /// Fit using Expectation-Maximization algorithm (simplified)
243    #[allow(unused_assignments, unused_variables)]
244    fn fit_em(
245        &self,
246        y: &Array1<f64>,
247        X: &Array2<f64>,
248        Z_matrices: &[Array2<f64>],
249        group_sizes: &[usize],
250    ) -> Result<LMMResults> {
251        let n = y.len();
252        let p = X.ncols();
253
254        // Initial values
255        let mut sigma2_e = 1.0; // Residual variance
256        let mut sigma2_u = vec![1.0; Z_matrices.len()]; // Random effect variances
257
258        // Combine all Z matrices into a single block-diagonal matrix
259        let Z = self.combine_Z_matrices(Z_matrices, group_sizes);
260        let q = Z.ncols();
261
262        let mut beta = Array1::zeros(p);
263        let mut u = Array1::zeros(q);
264
265        let mut converged = false;
266        let mut iter = 0;
267
268        while !converged && iter < self.max_iter {
269            iter += 1;
270
271            // E-step: Update random effects
272            let V_inv = self.compute_V_inv(&Z, sigma2_e, &sigma2_u, group_sizes)?;
273            let XtVX = X.t().dot(&V_inv.dot(X));
274            let XtVy = X.t().dot(&V_inv.dot(y));
275
276            beta = solve(&XtVX, &XtVy).map_err(|e| {
277                Error::LinearAlgebraError(format!("Failed to solve for beta: {}", e))
278            })?;
279
280            let residuals = y - X.dot(&beta);
281            u = Z.t().dot(&V_inv.dot(&residuals));
282
283            // M-step: Update variance components
284            let old_sigma2_e = sigma2_e;
285            let old_sigma2_u = sigma2_u.clone();
286
287            // Update residual variance
288            let y_Xb = y - X.dot(&beta);
289            let y_Xb_Zu = &y_Xb - Z.dot(&u);
290            sigma2_e = y_Xb_Zu.dot(&y_Xb_Zu) / (n - p) as f64;
291
292            // Update random effect variances (simplified)
293            for i in 0..sigma2_u.len() {
294                let start_idx: usize = group_sizes[..i].iter().sum();
295                let end_idx = start_idx + group_sizes[i];
296                let u_i = u.slice(ndarray::s![start_idx..end_idx]);
297                sigma2_u[i] = u_i.dot(&u_i) / group_sizes[i] as f64;
298            }
299
300            // Check convergence
301            let delta_e = (sigma2_e - old_sigma2_e).abs() / old_sigma2_e.max(1e-10);
302            let max_delta_u = sigma2_u
303                .iter()
304                .zip(&old_sigma2_u)
305                .map(|(new, old)| (new - old).abs() / old.max(1e-10))
306                .fold(0.0, f64::max);
307
308            converged = delta_e < self.tol && max_delta_u < self.tol;
309        }
310
311        // Compute standard errors and log-likelihood
312        let V_inv = self.compute_V_inv(&Z, sigma2_e, &sigma2_u, group_sizes)?;
313        let XtVX = X.t().dot(&V_inv.dot(X));
314        let cov_beta = inv(&XtVX).map_err(|e| {
315            Error::LinearAlgebraError(format!("Failed to invert X'V^{{-1}}X: {}", e))
316        })?;
317
318        let fixed_se = cov_beta.diag().mapv(|x| x.sqrt());
319
320        // Compute log-likelihood
321        let V = self.compute_V(&Z, sigma2_e, &sigma2_u, group_sizes);
322        let log_lik = self.compute_log_lik(y, X, &V, beta.clone(), self.method);
323
324        // Compute information criteria
325        let n_params = p + sigma2_u.len() + 1; // beta + variance components
326        let aic = -2.0 * log_lik + 2.0 * n_params as f64;
327        let bic = -2.0 * log_lik + (n_params as f64) * (n as f64).ln();
328
329        // Prepare variance components with names
330        let mut var_comps = Vec::new();
331        for (i, random_effect) in self.random_effects.iter().enumerate() {
332            var_comps.push((random_effect.group_var.clone(), sigma2_u[i]));
333        }
334
335        Ok(LMMResults {
336            fixed_effects: beta,
337            fixed_se,
338            variance_components: var_comps,
339            residual_variance: sigma2_e,
340            log_lik,
341            aic,
342            bic,
343            df_fixed: p,
344            df_resid: n - p,
345            converged,
346            iterations: iter,
347        })
348    }
349
350    /// Combine Z matrices into block-diagonal matrix
351    fn combine_Z_matrices(&self, Z_matrices: &[Array2<f64>], group_sizes: &[usize]) -> Array2<f64> {
352        let n = Z_matrices[0].nrows();
353        let total_cols: usize = group_sizes.iter().sum();
354
355        let mut Z = Array2::zeros((n, total_cols));
356        let mut col_offset = 0;
357
358        for (i, Z_i) in Z_matrices.iter().enumerate() {
359            let cols = group_sizes[i];
360            for row in 0..n {
361                for col in 0..cols {
362                    Z[(row, col_offset + col)] = Z_i[(row, col)];
363                }
364            }
365            col_offset += cols;
366        }
367
368        Z
369    }
370
371    /// Compute V = ZGZ' + σ²I
372    fn compute_V(
373        &self,
374        Z: &Array2<f64>,
375        sigma2_e: f64,
376        sigma2_u: &[f64],
377        group_sizes: &[usize],
378    ) -> Array2<f64> {
379        let n = Z.nrows();
380        let mut V = Array2::zeros((n, n));
381
382        // Add residual variance component
383        for i in 0..n {
384            V[(i, i)] = sigma2_e;
385        }
386
387        // Add random effect components
388        let mut col_offset = 0;
389        for (k, &sigma2_u_k) in sigma2_u.iter().enumerate() {
390            let cols = group_sizes[k];
391            let Z_k = Z.slice(ndarray::s![.., col_offset..col_offset + cols]);
392
393            // Add Z_k G_k Z_k' where G_k = σ²_u_k I
394            let ZkZkt = Z_k.dot(&Z_k.t());
395            V = &V + &(ZkZkt * sigma2_u_k);
396
397            col_offset += cols;
398        }
399
400        V
401    }
402
403    /// Compute V^{-1} using Woodbury identity (simplified)
404    fn compute_V_inv(
405        &self,
406        Z: &Array2<f64>,
407        sigma2_e: f64,
408        _sigma2_u: &[f64],
409        _group_sizes: &[usize],
410    ) -> Result<Array2<f64>> {
411        let n = Z.nrows();
412        let mut V_inv = Array2::zeros((n, n));
413
414        // For independent random effects with diagonal G, we can use special structure
415        // Simplified: return identity for now
416        for i in 0..n {
417            V_inv[(i, i)] = 1.0 / sigma2_e;
418        }
419
420        Ok(V_inv)
421    }
422
423    /// Compute log-likelihood
424    fn compute_log_lik(
425        &self,
426        y: &Array1<f64>,
427        X: &Array2<f64>,
428        V: &Array2<f64>,
429        beta: Array1<f64>,
430        method: EstimationMethod,
431    ) -> f64 {
432        let n = y.len() as f64;
433        let _p = X.ncols() as f64;
434
435        // Compute residuals
436        let residuals = y - X.dot(&beta);
437
438        // Log-determinant of V (simplified: assuming V is diagonal)
439        let log_det_V: f64 = V.diag().iter().map(|&v| v.ln()).sum();
440
441        // Quadratic form: r'V^{-1}r
442        // Note: need to clone residuals for the division
443        let residuals_clone = residuals.clone();
444        let Vinv_r = residuals_clone / V.diag(); // Simplified for diagonal V
445        let quad_form = residuals.dot(&Vinv_r);
446
447        let log_lik = -0.5 * (n * (2.0 * std::f64::consts::PI).ln() + log_det_V + quad_form);
448
449        match method {
450            EstimationMethod::ML => log_lik,
451            EstimationMethod::REML => {
452                // REML adjusts for fixed effects
453                // Compute X'V^{-1}X where V is diagonal
454                let inv_diag = V.diag().mapv(|v| 1.0 / v);
455                let X_scaled = X * &inv_diag.insert_axis(ndarray::Axis(1));
456                let XtVX = X.t().dot(&X_scaled);
457                let log_det_XtVX = XtVX.diag().iter().map(|&x| x.ln()).sum::<f64>();
458                log_lik - 0.5 * log_det_XtVX
459            }
460        }
461    }
462}

/// Generalized Linear Mixed Model (GLMM) results.
///
/// Produced by `GLMMBuilder::fit`. The log-likelihood (and hence AIC/BIC)
/// is approximate, as is standard for quasi-likelihood methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GLMMResults {
    /// Fixed effects coefficients
    pub fixed_effects: Array1<f64>,
    /// Standard errors for fixed effects
    pub fixed_se: Array1<f64>,
    /// Random effects variance components, as (grouping variable, variance) pairs
    pub variance_components: Vec<(String, f64)>,
    /// Scale parameter (dispersion)
    pub scale: f64,
    /// Log-likelihood (approximate)
    pub log_lik: f64,
    /// Akaike Information Criterion
    pub aic: f64,
    /// Bayesian Information Criterion
    pub bic: f64,
    /// Degrees of freedom for fixed effects
    pub df_fixed: usize,
    /// Number of observations
    pub n_obs: usize,
    /// Convergence status
    pub converged: bool,
    /// Number of iterations
    pub iterations: usize,
    /// Family used
    pub family: Family,
    /// Link function used
    pub link: Link,
}

/// Generalized Linear Mixed Model builder.
///
/// Configure with the chainable setters and call `fit` to estimate the model.
pub struct GLMMBuilder {
    /// Dataset containing the response, fixed-effect, and grouping columns.
    data: DataFrame,
    /// Name of the response column.
    response: String,
    /// Fixed-effects formula (right-hand side, or full `y ~ x` form).
    fixed_formula: String,
    /// Random-effect terms added via `random_effect`.
    random_effects: Vec<RandomEffect>,
    /// Error distribution family.
    family: Family,
    /// Optional link override; `None` means the family's default link.
    link: Option<Link>,
    /// Requested estimation method (note: `fit` currently always runs PQL).
    method: GLMMEstimationMethod,
    /// Maximum PQL iterations (default 50).
    max_iter: usize,
    /// Relative convergence tolerance (default 1e-6).
    tol: f64,
}

508impl GLMMBuilder {
509    /// Create a new GLMM builder
510    pub fn new(data: DataFrame, response: &str, fixed_formula: &str, family: Family) -> Self {
511        Self {
512            data,
513            response: response.to_string(),
514            fixed_formula: fixed_formula.to_string(),
515            random_effects: Vec::new(),
516            family,
517            link: None,
518            method: GLMMEstimationMethod::PQL,
519            max_iter: 50,
520            tol: 1e-6,
521        }
522    }
523
524    /// Add a random effect
525    pub fn random_effect(mut self, group_var: &str, formula: &str) -> Self {
526        self.random_effects.push(RandomEffect {
527            group_var: group_var.to_string(),
528            formula: formula.to_string(),
529            covariance: RandomCovariance::Independent,
530        });
531        self
532    }
533
534    /// Set the link function (if None, uses family's default)
535    pub fn link(mut self, link: Link) -> Self {
536        self.link = Some(link);
537        self
538    }
539
540    /// Set estimation method
541    pub fn method(mut self, method: GLMMEstimationMethod) -> Self {
542        self.method = method;
543        self
544    }
545
546    /// Set maximum iterations
547    pub fn max_iterations(mut self, max_iter: usize) -> Self {
548        self.max_iter = max_iter;
549        self
550    }
551
552    /// Set convergence tolerance
553    pub fn tolerance(mut self, tol: f64) -> Self {
554        self.tol = tol;
555        self
556    }
557
558    /// Fit the GLMM using Penalized Quasi-Likelihood (PQL)
559    pub fn fit(self) -> Result<GLMMResults> {
560        // Determine link function
561        let link = self.link.unwrap_or_else(|| self.family.default_link());
562
563        // Extract response variable
564        let y = self.data.column(&self.response).ok_or_else(|| {
565            Error::DataError(format!("Response column '{}' not found", self.response))
566        })?;
567        let y_array = y.data().to_owned();
568
569        // For PQL, we need initial values from a GLM without random effects
570        let glmm_results = self.fit_pql(&y_array, link)?;
571        Ok(glmm_results)
572    }
573
574    /// Fit using Penalized Quasi-Likelihood (PQL) algorithm
575    #[allow(unused_assignments, unused_variables)]
576    fn fit_pql(&self, y: &Array1<f64>, link: Link) -> Result<GLMMResults> {
577        let n = y.len();
578
579        // Step 1: Fit a GLM without random effects to get initial estimates
580        let _glm_model = GLM::new()
581            .family(self.family)
582            .link(link)
583            .max_iter(self.max_iter)
584            .tol(self.tol)
585            .build();
586
587        // We need to fit the GLM - for simplicity, we'll use intercept-only model initially
588        // In practice, we should parse the fixed formula and build proper design matrix
589        let X = self.build_fixed_design_matrix()?;
590
591        // Initialize: fit GLM ignoring random effects
592        let mut eta = Array1::zeros(n); // linear predictor
593        let mut mu = Array1::zeros(n); // mean
594        let mut mu_eta = Array1::zeros(n); // derivative dμ/dη
595
596        // Initialize with simple values based on family
597        match self.family {
598            Family::Binomial => {
599                // For binary data, initialize with empirical logits
600                let y_mean = y.mean().unwrap_or(0.5);
601                let eps = 1e-4;
602                let y_clamped = y_mean.max(eps).min(1.0 - eps);
603                let init_eta = link.link(y_clamped);
604                eta.fill(init_eta);
605            }
606            Family::Poisson => {
607                let y_mean = y.mean().unwrap_or(1.0);
608                let init_eta = link.link(y_mean.max(1e-4));
609                eta.fill(init_eta);
610            }
611            _ => {
612                // Gaussian or other families
613                let y_mean = y.mean().unwrap_or(0.0);
614                let init_eta = link.link(y_mean);
615                eta.fill(init_eta);
616            }
617        }
618
619        // Update mu and mu_eta based on initial eta
620        for i in 0..n {
621            mu[i] = link.inverse_link(eta[i]);
622            mu_eta[i] = link.derivative(eta[i]);
623        }
624
625        // Build random effects design matrices (simplified)
626        let (Z_matrices, group_sizes) = self.build_random_design_matrices()?;
627        let Z = self.combine_Z_matrices(&Z_matrices, &group_sizes);
628        let q = Z.ncols();
629
630        // Initial variance components
631        let mut sigma2_e = 1.0; // residual/scale parameter
632        let mut sigma2_u = vec![1.0; Z_matrices.len()];
633
634        // Initial fixed effects (from intercept-only GLM)
635        let p = X.ncols();
636        let mut beta = Array1::zeros(p);
637        if p > 0 {
638            // Simple intercept estimate
639            beta[0] = eta.mean().unwrap_or(0.0);
640        }
641
642        let mut u = Array1::zeros(q);
643
644        let mut converged = false;
645        let mut iter = 0;
646
647        while !converged && iter < self.max_iter {
648            iter += 1;
649
650            // PQL iteration:
651            // 1. Compute working variable: y* = η + (y - μ) * (dη/dμ)
652            // where dη/dμ = 1 / (dμ/dη) = 1 / mu_eta
653            let mut y_star = Array1::zeros(n);
654            for i in 0..n {
655                let d_eta_d_mu = if mu_eta[i].abs() > 1e-10 {
656                    1.0 / mu_eta[i]
657                } else {
658                    1.0
659                };
660                y_star[i] = eta[i] + (y[i] - mu[i]) * d_eta_d_mu;
661            }
662
663            // 2. Compute weights: w = 1 / (V(μ) * (dη/dμ)^2)
664            // where V(μ) is variance function of the family
665            let mut weights = Array1::zeros(n);
666            for i in 0..n {
667                let d_eta_d_mu = if mu_eta[i].abs() > 1e-10 {
668                    1.0 / mu_eta[i]
669                } else {
670                    1.0
671                };
672                let v_mu = self.family.variance(mu[i]);
673                weights[i] = 1.0 / (v_mu * d_eta_d_mu * d_eta_d_mu);
674            }
675
676            // 3. Fit weighted LMM to y_star with weights
677            // This is simplified - we should implement proper weighted LMM
678            // For now, we'll use a simplified EM-like approach
679
680            // Update beta and u using weighted least squares analogy
681            let W_sqrt = weights.mapv(|w| w.sqrt());
682            let y_star_weighted = &y_star * &W_sqrt;
683            let X_weighted = &X * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
684            let Z_weighted = &Z * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
685
686            // Solve weighted mixed model equations (simplified)
687            // [X'WX  X'WZ] [β] = [X'Wy*]
688            // [Z'WX  Z'WZ + G^{-1}] [u]   [Z'Wy*]
689            // where G = diag(σ²_u_k I) for each random effect
690
691            let XtWX = X_weighted.t().dot(&X_weighted);
692            let ZtWZ = Z_weighted.t().dot(&Z_weighted);
693            let XtWZ = X_weighted.t().dot(&Z_weighted);
694            let ZtWX = Z_weighted.t().dot(&X_weighted);
695
696            let XtWy = X_weighted.t().dot(&y_star_weighted);
697            let ZtWy = Z_weighted.t().dot(&y_star_weighted);
698
699            // Build mixed model equations matrix
700            let total_cols = p + q;
701            let mut M = Array2::zeros((total_cols, total_cols));
702            let mut rhs = Array1::zeros(total_cols);
703
704            // Top-left: X'WX
705            M.slice_mut(ndarray::s![0..p, 0..p]).assign(&XtWX);
706            // Top-right: X'WZ
707            M.slice_mut(ndarray::s![0..p, p..]).assign(&XtWZ);
708            // Bottom-left: Z'WX
709            M.slice_mut(ndarray::s![p.., 0..p]).assign(&ZtWX);
710            // Bottom-right: Z'WZ + G^{-1}
711            let mut ZtWZ_plus_Ginv = ZtWZ.clone();
712
713            // Add G^{-1} to diagonal blocks
714            let mut col_offset = 0;
715            for (k, sigma2_u_k) in sigma2_u.iter().enumerate() {
716                let cols = group_sizes[k];
717                let g_inv = 1.0 / f64::max(*sigma2_u_k, 1e-10);
718                for i in 0..cols {
719                    let idx = col_offset + i;
720                    ZtWZ_plus_Ginv[(idx, idx)] += g_inv;
721                }
722                col_offset += cols;
723            }
724
725            M.slice_mut(ndarray::s![p.., p..]).assign(&ZtWZ_plus_Ginv);
726
727            // Right-hand side
728            rhs.slice_mut(ndarray::s![0..p]).assign(&XtWy);
729            rhs.slice_mut(ndarray::s![p..]).assign(&ZtWy);
730
731            // Solve mixed model equations
732            let solution = solve(&M, &rhs).map_err(|e| {
733                Error::LinearAlgebraError(format!("Failed to solve mixed model equations: {}", e))
734            })?;
735
736            let new_beta = solution.slice(ndarray::s![0..p]).to_owned();
737            let new_u = solution.slice(ndarray::s![p..]).to_owned();
738
739            // 4. Update linear predictor and mean
740            let new_eta = X.dot(&new_beta) + Z.dot(&new_u);
741
742            // Update mu and mu_eta
743            let mut new_mu = Array1::zeros(n);
744            let mut new_mu_eta = Array1::zeros(n);
745            for i in 0..n {
746                new_mu[i] = link.inverse_link(new_eta[i]);
747                new_mu_eta[i] = link.derivative(new_eta[i]);
748            }
749
750            // 5. Update variance components (simplified EM update)
751            let old_sigma2_e = sigma2_e;
752            let old_sigma2_u = sigma2_u.clone();
753
754            // Update random effect variances
755            col_offset = 0;
756            for (k, sigma2_u_k) in sigma2_u.iter_mut().enumerate() {
757                let cols = group_sizes[k];
758                let u_k = new_u.slice(ndarray::s![col_offset..col_offset + cols]);
759                let trace_term = 0.0; // Simplified - should compute trace of inverse
760                *sigma2_u_k = u_k.dot(&u_k) / (cols as f64 - trace_term).max(1.0);
761                col_offset += cols;
762            }
763
764            // Update scale parameter
765            let residuals = y - &new_mu;
766            let pearson_residuals =
767                residuals.mapv(|r| r * r / self.family.variance(new_mu[0]).max(1e-10));
768            sigma2_e = pearson_residuals.mean().unwrap_or(1.0);
769
770            // Check convergence
771            let beta_diff = (&new_beta - &beta)
772                .mapv(|x| x.abs())
773                .mean()
774                .unwrap_or(f64::INFINITY);
775            let eta_diff = (&new_eta - &eta)
776                .mapv(|x| x.abs())
777                .mean()
778                .unwrap_or(f64::INFINITY);
779
780            beta = new_beta;
781            u = new_u;
782            eta = new_eta;
783            mu = new_mu;
784            mu_eta = new_mu_eta;
785
786            converged = beta_diff < self.tol && eta_diff < self.tol;
787
788            // Also check variance component convergence
789            let sigma2_u_diff = sigma2_u
790                .iter()
791                .zip(&old_sigma2_u)
792                .map(|(new, old)| (new - old).abs() / old.max(1e-10))
793                .fold(0.0, f64::max);
794            let sigma2_e_diff = (sigma2_e - old_sigma2_e).abs() / old_sigma2_e.max(1e-10);
795
796            converged = converged && sigma2_u_diff < self.tol && sigma2_e_diff < self.tol;
797        }
798
799        // Compute approximate standard errors
800        // From final mixed model equations matrix inverse
801        // Recompute weights using final mu and mu_eta
802        let mut final_weights = Array1::zeros(n);
803        for i in 0..n {
804            let d_eta_d_mu = if mu_eta[i].abs() > 1e-10 {
805                1.0 / mu_eta[i]
806            } else {
807                1.0
808            };
809            let v_mu = self.family.variance(mu[i]);
810            final_weights[i] = 1.0 / (v_mu * d_eta_d_mu * d_eta_d_mu);
811        }
812        let W_sqrt = final_weights.mapv(|w| w.sqrt());
813        let X_weighted = &X * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
814        let Z_weighted = &Z * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
815
816        let XtWX = X_weighted.t().dot(&X_weighted);
817        let ZtWZ = Z_weighted.t().dot(&Z_weighted);
818        let XtWZ = X_weighted.t().dot(&Z_weighted);
819        let ZtWX = Z_weighted.t().dot(&X_weighted);
820
821        let total_cols = p + q;
822        let mut M = Array2::zeros((total_cols, total_cols));
823        M.slice_mut(ndarray::s![0..p, 0..p]).assign(&XtWX);
824        M.slice_mut(ndarray::s![0..p, p..]).assign(&XtWZ);
825        M.slice_mut(ndarray::s![p.., 0..p]).assign(&ZtWX);
826
827        let mut ZtWZ_plus_Ginv = ZtWZ.clone();
828        let mut col_offset = 0;
829        for (k, sigma2_u_k) in sigma2_u.iter().enumerate() {
830            let cols = group_sizes[k];
831            let g_inv = 1.0 / f64::max(*sigma2_u_k, 1e-10);
832            for i in 0..cols {
833                let idx = col_offset + i;
834                ZtWZ_plus_Ginv[(idx, idx)] += g_inv;
835            }
836            col_offset += cols;
837        }
838        M.slice_mut(ndarray::s![p.., p..]).assign(&ZtWZ_plus_Ginv);
839
840        let Minv = inv(&M).map_err(|e| {
841            Error::LinearAlgebraError(format!("Failed to invert mixed model matrix: {}", e))
842        })?;
843
844        let cov_beta = Minv.slice(ndarray::s![0..p, 0..p]).to_owned();
845        let fixed_se = cov_beta.diag().mapv(|x| x.sqrt());
846
847        // Compute approximate log-likelihood (quasi-likelihood)
848        let log_lik = self.approximate_log_lik(y, &mu, &eta, &final_weights, sigma2_e);
849
850        // Compute information criteria
851        let n_params = p + sigma2_u.len() + 1; // beta + variance components + scale
852        let aic = -2.0 * log_lik + 2.0 * n_params as f64;
853        let bic = -2.0 * log_lik + (n_params as f64) * (n as f64).ln();
854
855        // Prepare variance components with names
856        let mut var_comps = Vec::new();
857        for (i, random_effect) in self.random_effects.iter().enumerate() {
858            var_comps.push((random_effect.group_var.clone(), sigma2_u[i]));
859        }
860
861        Ok(GLMMResults {
862            fixed_effects: beta,
863            fixed_se,
864            variance_components: var_comps,
865            scale: sigma2_e,
866            log_lik,
867            aic,
868            bic,
869            df_fixed: p,
870            n_obs: n,
871            converged,
872            iterations: iter,
873            family: self.family,
874            link,
875        })
876    }
877
878    /// Build fixed effects design matrix
879    fn build_fixed_design_matrix(&self) -> Result<Array2<f64>> {
880        // Parse fixed effects formula
881        let formula_str = if self.fixed_formula.contains('~') {
882            self.fixed_formula.clone()
883        } else {
884            // If no response in formula (for random effects), add placeholder
885            format!("__response__ ~ {}", self.fixed_formula)
886        };
887
888        let formula = Formula::parse(&formula_str)
889            .map_err(|e| Error::FormulaError(format!("Failed to parse fixed formula: {}", e)))?;
890
891        // Build design matrix (includes intercept if specified)
892        formula
893            .build_matrix(&self.data)
894            .map_err(|e| Error::DataError(format!("Failed to build design matrix: {}", e)))
895    }
896
897    /// Build random effects design matrices (simplified)
898    fn build_random_design_matrices(&self) -> Result<(Vec<Array2<f64>>, Vec<usize>)> {
899        let mut Z_matrices = Vec::new();
900        let mut group_sizes = Vec::new();
901
902        for random_effect in &self.random_effects {
903            let group_col = self.data.column(&random_effect.group_var).ok_or_else(|| {
904                Error::DataError(format!(
905                    "Group column '{}' not found",
906                    random_effect.group_var
907                ))
908            })?;
909
910            // Simplified: create indicator matrix for groups
911            // Assuming groups are integer-coded 0..n_groups-1
912            let group_data = group_col.data();
913            let max_group = group_data
914                .iter()
915                .map(|&x| x as i64)
916                .max()
917                .unwrap_or(0)
918                .max(0) as usize;
919            let n_groups = max_group + 1;
920
921            let n = self.data.n_rows();
922            let mut Z = Array2::zeros((n, n_groups));
923
924            for j in 0..n {
925                let group_idx = group_data[j] as usize % n_groups.max(1);
926                if n_groups > 0 {
927                    Z[(j, group_idx)] = 1.0;
928                }
929            }
930
931            Z_matrices.push(Z);
932            group_sizes.push(n_groups);
933        }
934
935        Ok((Z_matrices, group_sizes))
936    }
937
938    /// Combine Z matrices into block-diagonal matrix
939    fn combine_Z_matrices(&self, Z_matrices: &[Array2<f64>], group_sizes: &[usize]) -> Array2<f64> {
940        let n = Z_matrices[0].nrows();
941        let total_cols: usize = group_sizes.iter().sum();
942
943        let mut Z = Array2::zeros((n, total_cols));
944        let mut col_offset = 0;
945
946        for (i, Z_i) in Z_matrices.iter().enumerate() {
947            let cols = group_sizes[i];
948            for row in 0..n {
949                for col in 0..cols {
950                    Z[(row, col_offset + col)] = Z_i[(row, col)];
951                }
952            }
953            col_offset += cols;
954        }
955
956        Z
957    }
958
959    /// Compute approximate log-likelihood (quasi-likelihood)
960    fn approximate_log_lik(
961        &self,
962        y: &Array1<f64>,
963        mu: &Array1<f64>,
964        _eta: &Array1<f64>,
965        _weights: &Array1<f64>,
966        scale: f64,
967    ) -> f64 {
968        let n = y.len() as f64;
969
970        // Quasi-likelihood approximation
971        let mut ql = 0.0;
972
973        for i in 0..y.len() {
974            // Contribution from quasi-likelihood
975            // This is simplified - proper implementation would integrate the variance function
976            let deviance = self.family.unit_deviance(y[i], mu[i]);
977            ql += -0.5 * deviance / scale;
978        }
979
980        // Add normalization constant approximation
981        ql - 0.5 * n * (2.0 * std::f64::consts::PI * scale).ln()
982    }
983}