sklearn_rs/linear_model/linear_regression.rs

use ndarray::{Array1, Array2};
use crate::base::{Estimator, Predictor, validate_features, validate_target};
use crate::error::{Result, SklearnError};

/// Linear regression model, analogous to scikit-learn's LinearRegression.
#[derive(Debug, Clone)]
pub struct LinearRegression {
    pub fit_intercept: bool,
}

impl Default for LinearRegression {
    fn default() -> Self {
        Self {
            fit_intercept: true,
        }
    }
}

impl LinearRegression {
    pub fn new(fit_intercept: bool) -> Self {
        Self { fit_intercept }
    }
}
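
// Usage sketch (illustrative only; assumes the `Estimator`/`Predictor` traits imported above
// are in scope, and `x_train`, `y_train`, `x_test` are placeholders for caller-supplied data):
//
//     let model = LinearRegression::new(true).fit(&x_train, &y_train)?;
//     let predictions = model.predict(&x_test)?;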

/// A fitted linear model produced by `LinearRegression::fit`.
#[derive(Debug, Clone)]
pub struct LinearModel {
    pub coefficients: Array1<f64>,
    pub intercept: f64,
}

impl LinearModel {
    pub fn new(coefficients: Array1<f64>, intercept: f64) -> Self {
        Self {
            coefficients,
            intercept,
        }
    }
}

impl Predictor for LinearModel {
    type Input = Array2<f64>;
    type Output = Array1<f64>;

    fn predict(&self, x: &Self::Input) -> Result<Self::Output> {
        validate_features(x)?;

        if x.ncols() != self.coefficients.len() {
            return Err(SklearnError::ShapeMismatch {
                expected: format!("{} features", self.coefficients.len()),
                actual: format!("{} features", x.ncols()),
            });
        }

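        // Computes ŷ_i = intercept + Σ_j coefficients[j] * x[(i, j)] for every row i.
        // (Equivalently `x.dot(&self.coefficients) + self.intercept` using ndarray's `dot`;
        // the explicit loop is kept to match the hand-rolled linear algebra used elsewhere
        // in this module.)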
        let mut predictions = Array1::zeros(x.nrows());

        for (i, row) in x.rows().into_iter().enumerate() {
            let mut prediction = self.intercept;
            for (j, &coef) in self.coefficients.iter().enumerate() {
                prediction += coef * row[j];
            }
            predictions[i] = prediction;
        }

        Ok(predictions)
    }
}

impl Estimator for LinearRegression {
    type Input = Array2<f64>;
    type Target = Array1<f64>;
    type Model = LinearModel;

    fn fit(&self, x: &Self::Input, y: &Self::Target) -> Result<Self::Model> {
        validate_features(x)?;
        validate_target(y)?;

        if x.nrows() != y.len() {
            return Err(SklearnError::ShapeMismatch {
                expected: format!("{} samples", x.nrows()),
                actual: format!("{} targets", y.len()),
            });
        }

        // Check whether there are enough samples for a fully determined system
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if self.fit_intercept {
            if n_samples <= n_features + 1 {
                self.fit_underdetermined_with_intercept(x, y)
            } else {
                self.fit_with_intercept(x, y)
            }
        } else if n_samples <= n_features {
            self.fit_underdetermined_without_intercept(x, y)
        } else {
            self.fit_without_intercept(x, y)
        }
    }
}

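// The private fitting helpers below all take the normal-equations route: for a design matrix X
// (prefixed with a column of ones when an intercept is fitted), the least-squares coefficients
// satisfy (X^T X) * beta = X^T y, which is solved with the hand-rolled Gaussian elimination in
// `solve_linear_system` rather than an external linear-algebra backend.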
impl LinearRegression {
    fn fit_with_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Build the design matrix [1, X]
        let mut design_matrix = Array2::zeros((n_samples, n_features + 1));
        for i in 0..n_samples {
            design_matrix[(i, 0)] = 1.0; // intercept column
            for j in 0..n_features {
                design_matrix[(i, j + 1)] = x[(i, j)];
            }
        }

        // Compute X^T X
        let xt_x = self.matrix_transpose_dot(&design_matrix, &design_matrix);

        // Compute X^T y
        let xt_y = self.matrix_transpose_dot_vector(&design_matrix, y);

        // Solve the linear system (X^T X) * coefficients = X^T y
        let all_coefficients = self.solve_linear_system(&xt_x, &xt_y)?;

        // Split off the intercept from the feature coefficients
        let intercept = all_coefficients[0];
        let coefficients = all_coefficients.slice(ndarray::s![1..]).to_owned();

        Ok(LinearModel::new(coefficients, intercept))
    }

    fn fit_without_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
        // Compute X^T X
        let xt_x = self.matrix_transpose_dot(x, x);

        // Compute X^T y
        let xt_y = self.matrix_transpose_dot_vector(x, y);

        // Solve the linear system
        let coefficients = self.solve_linear_system(&xt_x, &xt_y)?;

        Ok(LinearModel::new(coefficients, 0.0))
    }

    // Handle an underdetermined system (n_samples <= n_features + 1)
    fn fit_underdetermined_with_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        eprintln!("Warning: underdetermined system, n_samples ({}) <= n_features ({}) + 1", n_samples, n_features);
        eprintln!("Falling back to a minimum-norm style solution");

        // A pseudoinverse / minimum-norm solution would be ideal here.
        // Simplification: with a single sample, use its target value as the intercept.
        if n_samples == 1 {
            let intercept = y[0];
            let coefficients = Array1::zeros(n_features);
            return Ok(LinearModel::new(coefficients, intercept));
        }

        // For the remaining underdetermined cases, use the normal equations with a small
        // regularization term so the system is not singular.
        let mut design_matrix = Array2::zeros((n_samples, n_features + 1));
        for i in 0..n_samples {
            design_matrix[(i, 0)] = 1.0;
            for j in 0..n_features {
                design_matrix[(i, j + 1)] = x[(i, j)];
            }
        }

        let xt_x = self.matrix_transpose_dot(&design_matrix, &design_matrix);
        let xt_y = self.matrix_transpose_dot_vector(&design_matrix, y);

        // Add a small regularization term to the diagonal to avoid singularity
        let mut regularized_xt_x = xt_x;
        for i in 0..regularized_xt_x.nrows() {
            regularized_xt_x[(i, i)] += 1e-8;
        }

        let all_coefficients = self.solve_linear_system(&regularized_xt_x, &xt_y)?;

        let intercept = all_coefficients[0];
        let coefficients = all_coefficients.slice(ndarray::s![1..]).to_owned();

        Ok(LinearModel::new(coefficients, intercept))
    }

    // Handle an underdetermined system (no intercept)
    fn fit_underdetermined_without_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        eprintln!("Warning: underdetermined system, n_samples ({}) <= n_features ({})", n_samples, n_features);
        eprintln!("Falling back to a minimum-norm style solution");

        // With a single sample, return a solution consistent with that sample
        if n_samples == 1 {
            let mut coefficients = Array1::zeros(n_features);
            // Simplification: fit using the first feature only
            if n_features > 0 && x[(0, 0)].abs() > 1e-10 {
                coefficients[0] = y[0] / x[(0, 0)];
            }
            return Ok(LinearModel::new(coefficients, 0.0));
        }

        let xt_x = self.matrix_transpose_dot(x, x);
        let xt_y = self.matrix_transpose_dot_vector(x, y);

        // Add a small regularization term to the diagonal to avoid singularity
        let mut regularized_xt_x = xt_x;
        for i in 0..regularized_xt_x.nrows() {
            regularized_xt_x[(i, i)] += 1e-8;
        }

        let coefficients = self.solve_linear_system(&regularized_xt_x, &xt_y)?;

        Ok(LinearModel::new(coefficients, 0.0))
    }
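
    // Note on the fallbacks above: adding 1e-8 to the diagonal solves the ridge-style system
    // (X^T X + lambda*I) * beta = X^T y with a tiny lambda. That keeps the system non-singular,
    // but it only approximates the true minimum-norm least-squares solution that a
    // pseudoinverse-based solver (e.g. scikit-learn's lstsq path) would return.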

    // Manually compute the matrix product A^T * B
    fn matrix_transpose_dot(&self, a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
        let n = a.ncols(); // number of rows of A^T
        let m = b.ncols(); // number of columns of B
        let mut result = Array2::zeros((n, m));

        for i in 0..n {
            for j in 0..m {
                let mut sum = 0.0;
                for k in 0..a.nrows() {
                    sum += a[(k, i)] * b[(k, j)];
                }
                result[(i, j)] = sum;
            }
        }

        result
    }

    // Manually compute the matrix-vector product A^T * v
    fn matrix_transpose_dot_vector(&self, a: &Array2<f64>, v: &Array1<f64>) -> Array1<f64> {
        let n = a.ncols();
        let mut result = Array1::zeros(n);

        for i in 0..n {
            let mut sum = 0.0;
            for k in 0..a.nrows() {
                sum += a[(k, i)] * v[k];
            }
            result[i] = sum;
        }

        result
    }

    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
        let n = a.nrows();

        // Solve the linear system with Gaussian elimination
        let mut augmented = Array2::zeros((n, n + 1));

        // Build the augmented matrix [A | b]
        for i in 0..n {
            for j in 0..n {
                augmented[(i, j)] = a[(i, j)];
            }
            augmented[(i, n)] = b[i];
        }

        // Forward elimination
        for i in 0..n {
            // Find the pivot (partial pivoting)
            let mut max_row = i;
            for k in i + 1..n {
                if augmented[(k, i)].abs() > augmented[(max_row, i)].abs() {
                    max_row = k;
                }
            }

            // Swap rows to move the pivot into place
            if max_row != i {
                for j in 0..=n {
                    augmented.swap((i, j), (max_row, j));
                }
            }

            // A (near-)zero pivot means the matrix is singular
            if augmented[(i, i)].abs() < 1e-10 {
                return Err(SklearnError::FitFailed {
                    reason: "Matrix is singular; the system cannot be solved".to_string(),
                });
            }

            // Eliminate the entries below the pivot
            for k in i + 1..n {
                let factor = augmented[(k, i)] / augmented[(i, i)];
                for j in i..=n {
                    augmented[(k, j)] -= factor * augmented[(i, j)];
                }
            }
        }

        // Back-substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = augmented[(i, n)];
            for j in i + 1..n {
                x[i] -= augmented[(i, j)] * x[j];
            }
            x[i] /= augmented[(i, i)];
        }

        Ok(x)
    }
}
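
// A minimal test sketch (assumes the `Estimator`/`Predictor` traits and the validation helpers
// behave as used above; the data and tolerances are illustrative).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::{Estimator, Predictor};
    use ndarray::array;

    #[test]
    fn fits_a_simple_line_with_intercept() {
        // y = 2x + 1 on four samples: the system is well determined, so the normal
        // equations recover the slope and intercept essentially exactly.
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = LinearRegression::new(true).fit(&x, &y).unwrap();

        assert!((model.intercept - 1.0).abs() < 1e-8);
        assert!((model.coefficients[0] - 2.0).abs() < 1e-8);

        let predictions = model.predict(&array![[5.0], [6.0]]).unwrap();
        assert!((predictions[0] - 11.0).abs() < 1e-8);
        assert!((predictions[1] - 13.0).abs() < 1e-8);
    }
}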