Skip to main content

so_models/glm/
family.rs

1//! Distribution families for Generalized Linear Models
2
3#![allow(non_snake_case)] // Allow mathematical notation
4
5use ndarray::Array1;
6use serde::{Deserialize, Serialize};
7use so_core::error::Result;
8use statrs::distribution::{Continuous, Normal};
9
10/// Distribution families for GLM
11#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
12pub enum Family {
13    /// Gaussian (normal) distribution with identity link
14    Gaussian,
15    /// Binomial distribution for binary/count data
16    Binomial,
17    /// Poisson distribution for count data
18    Poisson,
19    /// Gamma distribution for positive continuous data
20    Gamma,
21    /// Inverse Gaussian distribution
22    InverseGaussian,
23}
24
25impl Family {
26    /// Get the default link function for this family
27    pub fn default_link(&self) -> Link {
28        match self {
29            Family::Gaussian => Link::Identity,
30            Family::Binomial => Link::Logit,
31            Family::Poisson => Link::Log,
32            Family::Gamma => Link::Inverse,
33            Family::InverseGaussian => Link::InverseSquare,
34        }
35    }
36
37    /// Compute variance function V(μ) for the mean μ
38    pub fn variance(&self, mu: f64) -> f64 {
39        match self {
40            Family::Gaussian => 1.0,
41            Family::Binomial => mu * (1.0 - mu),
42            Family::Poisson => mu,
43            Family::Gamma => mu.powi(2),
44            Family::InverseGaussian => mu.powi(3),
45        }
46    }
47
48    /// Compute the unit deviance d(y, μ) for a single observation
49    pub fn unit_deviance(&self, y: f64, mu: f64) -> f64 {
50        match self {
51            Family::Gaussian => (y - mu).powi(2),
52            Family::Binomial => {
53                if y == 0.0 {
54                    2.0 * (1.0 - mu).ln().max(-100.0)
55                } else if y == 1.0 {
56                    2.0 * mu.ln().max(-100.0)
57                } else {
58                    // For proportion data (0 < y < 1)
59                    2.0 * (y * (y / mu).ln().max(-100.0)
60                        + (1.0 - y) * ((1.0 - y) / (1.0 - mu)).ln().max(-100.0))
61                }
62            }
63            Family::Poisson => {
64                if mu == 0.0 {
65                    if y == 0.0 { 0.0 } else { 2.0 * y }
66                } else {
67                    2.0 * (y * (y / mu).ln().max(-100.0) - (y - mu))
68                }
69            }
70            Family::Gamma => 2.0 * ((y - mu) / mu - (y / mu).ln()),
71            Family::InverseGaussian => (y - mu).powi(2) / (mu.powi(2) * y),
72        }
73    }
74
75    /// Compute total deviance for a set of observations
76    pub fn deviance(&self, y: &Array1<f64>, mu: &Array1<f64>) -> f64 {
77        y.iter()
78            .zip(mu.iter())
79            .map(|(&y_val, &mu_val)| self.unit_deviance(y_val, mu_val))
80            .sum()
81    }
82
83    /// Compute initial values for the response variable
84    pub fn initialize(&self, y: &Array1<f64>) -> Array1<f64> {
85        match self {
86            Family::Gaussian => y.clone(),
87            Family::Binomial => {
88                // For binary data, apply logit transform with clipping
89                y.mapv(|y_val| {
90                    let clipped = y_val.max(0.0001).min(0.9999);
91                    (clipped / (1.0 - clipped)).ln()
92                })
93            }
94            Family::Poisson => {
95                // For Poisson, log transform with offset for zeros
96                y.mapv(|y_val| (y_val + 0.5).ln())
97            }
98            Family::Gamma => {
99                // For Gamma, log transform
100                y.mapv(|y_val| y_val.max(1e-8).ln())
101            }
102            Family::InverseGaussian => {
103                // For Inverse Gaussian, log transform
104                y.mapv(|y_val| y_val.max(1e-8).ln())
105            }
106        }
107    }
108
109    /// Check if response values are valid for this family
110    pub fn validate_response(&self, y: &Array1<f64>) -> Result<()> {
111        match self {
112            Family::Gaussian => Ok(()), // Any real value
113            Family::Binomial => {
114                // Check that values are in [0, 1]
115                for &val in y {
116                    if !(0.0..=1.0).contains(&val) {
117                        return Err(so_core::error::Error::DataError(format!(
118                            "Binomial response must be in [0, 1], got {}",
119                            val
120                        )));
121                    }
122                }
123                Ok(())
124            }
125            Family::Poisson => {
126                // Check that values are non-negative integers (or counts)
127                for &val in y {
128                    if val < 0.0 {
129                        return Err(so_core::error::Error::DataError(format!(
130                            "Poisson response must be non-negative, got {}",
131                            val
132                        )));
133                    }
134                }
135                Ok(())
136            }
137            Family::Gamma | Family::InverseGaussian => {
138                // Check that values are positive
139                for &val in y {
140                    if val <= 0.0 {
141                        return Err(so_core::error::Error::DataError(format!(
142                            "{} response must be positive, got {}",
143                            match self {
144                                Family::Gamma => "Gamma",
145                                Family::InverseGaussian => "Inverse Gaussian",
146                                _ => unreachable!(),
147                            },
148                            val
149                        )));
150                    }
151                }
152                Ok(())
153            }
154        }
155    }
156
157    /// Get the name of the family as a string
158    pub fn name(&self) -> &'static str {
159        match self {
160            Family::Gaussian => "Gaussian",
161            Family::Binomial => "Binomial",
162            Family::Poisson => "Poisson",
163            Family::Gamma => "Gamma",
164            Family::InverseGaussian => "Inverse Gaussian",
165        }
166    }
167
168    /// Compute the dispersion parameter (scale) from Pearson residuals
169    pub fn estimate_dispersion(
170        &self,
171        y: &Array1<f64>,
172        mu: &Array1<f64>,
173        n: usize,
174        p: usize,
175    ) -> f64 {
176        let pearson_residuals: f64 = y
177            .iter()
178            .zip(mu.iter())
179            .map(|(&y_val, &mu_val)| {
180                let variance = self.variance(mu_val);
181                if variance > 0.0 {
182                    (y_val - mu_val).powi(2) / variance
183                } else {
184                    0.0
185                }
186            })
187            .sum();
188
189        pearson_residuals / (n - p) as f64
190    }
191}
192
193/// Link functions for GLM
194#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
195pub enum Link {
196    /// Identity link: η = μ
197    Identity,
198    /// Logit link: η = log(μ / (1 - μ))
199    Logit,
200    /// Probit link: η = Φ⁻¹(μ)
201    Probit,
202    /// Complementary log-log link: η = log(-log(1 - μ))
203    Cloglog,
204    /// Log link: η = log(μ)
205    Log,
206    /// Inverse link: η = 1/μ
207    Inverse,
208    /// Inverse square link: η = 1/μ²
209    InverseSquare,
210    /// Square root link: η = √μ
211    Sqrt,
212}
213
214impl Link {
215    /// Apply link function: η = g(μ)
216    pub fn link(&self, mu: f64) -> f64 {
217        match self {
218            Link::Identity => mu,
219            Link::Logit => (mu / (1.0 - mu)).ln(),
220            Link::Probit => {
221                // Approximate inverse normal CDF
222                if mu <= 0.0 || mu >= 1.0 {
223                    f64::NAN
224                } else {
225                    statrs::function::erf::erf_inv(2.0 * mu - 1.0) * 2.0f64.sqrt()
226                }
227            }
228            Link::Cloglog => (-(1.0 - mu).ln()).ln(),
229            Link::Log => mu.ln(),
230            Link::Inverse => 1.0 / mu,
231            Link::InverseSquare => 1.0 / mu.powi(2),
232            Link::Sqrt => mu.sqrt(),
233        }
234    }
235
236    /// Apply inverse link: μ = g⁻¹(η)
237    pub fn inverse_link(&self, eta: f64) -> f64 {
238        match self {
239            Link::Identity => eta,
240            Link::Logit => 1.0 / (1.0 + (-eta).exp()),
241            Link::Probit => 0.5 * (1.0 + statrs::function::erf::erf(eta / 2.0f64.sqrt())),
242            Link::Cloglog => 1.0 - (-eta.exp()).exp(),
243            Link::Log => eta.exp(),
244            Link::Inverse => 1.0 / eta,
245            Link::InverseSquare => 1.0 / eta.sqrt(),
246            Link::Sqrt => eta.powi(2),
247        }
248    }
249
250    /// Derivative of inverse link: dμ/dη
251    pub fn derivative(&self, eta: f64) -> f64 {
252        match self {
253            Link::Identity => 1.0,
254            Link::Logit => {
255                let mu = self.inverse_link(eta);
256                mu * (1.0 - mu)
257            }
258            Link::Probit => {
259                // Derivative of inverse normal CDF is normal PDF
260                Normal::new(0.0, 1.0).unwrap().pdf(eta)
261            }
262            Link::Cloglog => {
263                let mu = self.inverse_link(eta);
264                (1.0 - mu) * (-(1.0 - mu).ln())
265            }
266            Link::Log => eta.exp(), // Same as inverse link for log
267            Link::Inverse => -1.0 / eta.powi(2),
268            Link::InverseSquare => -0.5 / eta.powf(-1.5),
269            Link::Sqrt => 2.0 * eta,
270        }
271    }
272
273    /// Get the name of the link function as a string
274    pub fn name(&self) -> &'static str {
275        match self {
276            Link::Identity => "identity",
277            Link::Logit => "logit",
278            Link::Probit => "probit",
279            Link::Cloglog => "cloglog",
280            Link::Log => "log",
281            Link::Inverse => "inverse",
282            Link::InverseSquare => "inverse square",
283            Link::Sqrt => "sqrt",
284        }
285    }
286}
287
288/// Check if a link-function combination is valid
289pub fn is_valid_link(family: Family, link: Link) -> bool {
290    match family {
291        Family::Gaussian => matches!(link, Link::Identity | Link::Log | Link::Inverse),
292        Family::Binomial => matches!(link, Link::Logit | Link::Probit | Link::Cloglog | Link::Log),
293        Family::Poisson => matches!(link, Link::Log | Link::Identity | Link::Sqrt),
294        Family::Gamma => matches!(link, Link::Inverse | Link::Log | Link::Identity),
295        Family::InverseGaussian => matches!(
296            link,
297            Link::InverseSquare | Link::Inverse | Link::Log | Link::Identity
298        ),
299    }
300}