strafe_trait/
model.rs

1use std::error::Error;
2
3use nalgebra::DMatrix;
4use strafe_type::{Alpha64, ModelMatrix};
5
6use crate::{manipulator::Manipulator, Statistic, StatisticalEstimate};
7
8pub trait ModelBuilder {
9    type Model;
10    fn with_x(self, x: &ModelMatrix) -> Self;
11    fn with_y(self, y: &ModelMatrix) -> Self;
12    fn with_weights(self, weights: &ModelMatrix) -> Self;
13    fn with_alpha<A: Into<Alpha64>>(self, alpha: A) -> Self;
14    fn build(self) -> Self::Model;
15}
16
17pub trait Model {
18    fn get_x(&self) -> ModelMatrix;
19    fn get_x1(&self) -> ModelMatrix;
20    fn get_y(&self) -> ModelMatrix;
21    fn get_weights(&self) -> ModelMatrix;
22    fn get_intercept(&self) -> bool;
23    fn manipulator(
24        &self,
25    ) -> Result<
26        Box<dyn StatisticalEstimate<Manipulator<ModelMatrix, ModelMatrix>, ModelMatrix>>,
27        Box<dyn Error>,
28    >;
29    fn determination(&self) -> Result<Box<dyn Statistic>, Box<dyn Error>>;
30    fn parameters(
31        &self,
32    ) -> Result<Vec<Box<dyn StatisticalEstimate<f64, (f64, f64)>>>, Box<dyn Error>>;
33    fn predictions(
34        &self,
35    ) -> Result<Box<dyn StatisticalEstimate<ModelMatrix, ModelMatrix>>, Box<dyn Error>>;
36    fn residuals(&self) -> Result<ModelMatrix, Box<dyn Error>>;
37    fn studentized_residuals(&self) -> Result<ModelMatrix, Box<dyn Error>>;
38    fn standardized_residuals(&self) -> Result<ModelMatrix, Box<dyn Error>>;
39    fn variance(&self) -> Result<DMatrix<f64>, Box<dyn Error>>;
40    fn leverage(&self) -> Vec<f64> {
41        let mut x = self.get_x1().matrix();
42        let w = self.get_weights().matrix();
43        for (mut row, w) in x.row_iter_mut().zip(w.iter()) {
44            for x in row.iter_mut() {
45                *x *= w.sqrt();
46            }
47        }
48
49        (x.clone()
50            * (x.transpose() * x.clone())
51                .pseudo_inverse(f64::EPSILON)
52                .unwrap()
53            * x.transpose())
54        .diagonal()
55        .as_slice()
56        .to_vec()
57    }
58}