1use ndarray::{Array1, Array2};
2use crate::base::{Estimator, Predictor, validate_features, validate_target};
3use crate::error::{Result, SklearnError};
4
/// Ordinary least squares linear regression estimator.
///
/// When `fit_intercept` is true, a bias term is estimated alongside the
/// per-feature coefficients during fitting.
#[derive(Debug, Clone)]
pub struct LinearRegression {
    /// Whether to estimate an intercept (bias) term when fitting.
    pub fit_intercept: bool,
}
10
11impl Default for LinearRegression {
12 fn default() -> Self {
13 Self {
14 fit_intercept: true,
15 }
16 }
17}
18
19impl LinearRegression {
20 pub fn new(fit_intercept: bool) -> Self {
21 Self { fit_intercept }
22 }
23}
24
25#[derive(Debug, Clone)]
27pub struct LinearModel {
28 pub coefficients: Array1<f64>,
29 pub intercept: f64,
30}
31
32impl LinearModel {
33 pub fn new(coefficients: Array1<f64>, intercept: f64) -> Self {
34 Self {
35 coefficients,
36 intercept,
37 }
38 }
39}
40
41impl Predictor for LinearModel {
42 type Input = Array2<f64>;
43 type Output = Array1<f64>;
44
45 fn predict(&self, x: &Self::Input) -> Result<Self::Output> {
46 validate_features(x)?;
47
48 if x.ncols() != self.coefficients.len() {
49 return Err(SklearnError::ShapeMismatch {
50 expected: format!("{} 个特征", self.coefficients.len()),
51 actual: format!("{} 个特征", x.ncols()),
52 });
53 }
54
55 let mut predictions = Array1::zeros(x.nrows());
56
57 for (i, row) in x.rows().into_iter().enumerate() {
58 let mut prediction = self.intercept;
59 for (j, &coef) in self.coefficients.iter().enumerate() {
60 prediction += coef * row[j];
61 }
62 predictions[i] = prediction;
63 }
64
65 Ok(predictions)
66 }
67}
68
69impl Estimator for LinearRegression {
70 type Input = Array2<f64>;
71 type Target = Array1<f64>;
72 type Model = LinearModel;
73
74 fn fit(&self, x: &Self::Input, y: &Self::Target) -> Result<Self::Model> {
75 validate_features(x)?;
76 validate_target(y)?;
77
78 if x.nrows() != y.len() {
79 return Err(SklearnError::ShapeMismatch {
80 expected: format!("{} 个样本", x.nrows()),
81 actual: format!("{} 个目标值", y.len()),
82 });
83 }
84
85 let n_samples = x.nrows();
87 let n_features = x.ncols();
88
89 if self.fit_intercept {
90 if n_samples <= n_features + 1 {
91 return self.fit_underdetermined_with_intercept(x, y);
92 }
93 self.fit_with_intercept(x, y)
94 } else {
95 if n_samples <= n_features {
96 return self.fit_underdetermined_without_intercept(x, y);
97 }
98 self.fit_without_intercept(x, y)
99 }
100 }
101}
102
103impl LinearRegression {
104 fn fit_with_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
105 let n_samples = x.nrows();
106 let n_features = x.ncols();
107
108 let mut design_matrix = Array2::zeros((n_samples, n_features + 1));
110 for i in 0..n_samples {
111 design_matrix[(i, 0)] = 1.0; for j in 0..n_features {
113 design_matrix[(i, j + 1)] = x[(i, j)];
114 }
115 }
116
117 let xt_x = self.matrix_transpose_dot(&design_matrix, &design_matrix);
119
120 let xt_y = self.matrix_transpose_dot_vector(&design_matrix, y);
122
123 let all_coefficients = self.solve_linear_system(&xt_x, &xt_y)?;
125
126 let intercept = all_coefficients[0];
128 let coefficients = all_coefficients.slice(ndarray::s![1..]).to_owned();
129
130 Ok(LinearModel::new(coefficients, intercept))
131 }
132
133 fn fit_without_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
134 let xt_x = self.matrix_transpose_dot(x, x);
136
137 let xt_y = self.matrix_transpose_dot_vector(x, y);
139
140 let coefficients = self.solve_linear_system(&xt_x, &xt_y)?;
142
143 Ok(LinearModel::new(coefficients, 0.0))
144 }
145
146 fn fit_underdetermined_with_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
148 let n_samples = x.nrows();
149 let n_features = x.ncols();
150
151 println!("警告: 欠定系统,样本数({}) <= 特征数({}) + 1", n_samples, n_features);
152 println!("使用最小范数解");
153
154 if n_samples == 1 {
157 let intercept = y[0];
158 let coefficients = Array1::zeros(n_features);
159 return Ok(LinearModel::new(coefficients, intercept));
160 }
161
162 let mut design_matrix = Array2::zeros((n_samples, n_features + 1));
164 for i in 0..n_samples {
165 design_matrix[(i, 0)] = 1.0;
166 for j in 0..n_features {
167 design_matrix[(i, j + 1)] = x[(i, j)];
168 }
169 }
170
171 let xt_x = self.matrix_transpose_dot(&design_matrix, &design_matrix);
172 let xt_y = self.matrix_transpose_dot_vector(&design_matrix, y);
173
174 let mut regularized_xt_x = xt_x.clone();
176 for i in 0..regularized_xt_x.nrows() {
177 regularized_xt_x[(i, i)] += 1e-8;
178 }
179
180 let all_coefficients = self.solve_linear_system(®ularized_xt_x, &xt_y)?;
181
182 let intercept = all_coefficients[0];
183 let coefficients = all_coefficients.slice(ndarray::s![1..]).to_owned();
184
185 Ok(LinearModel::new(coefficients, intercept))
186 }
187
188 fn fit_underdetermined_without_intercept(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<LinearModel> {
190 let n_samples = x.nrows();
191 let n_features = x.ncols();
192
193 println!("警告: 欠定系统,样本数({}) <= 特征数({})", n_samples, n_features);
194 println!("使用最小范数解");
195
196 if n_samples == 1 {
198 let mut coefficients = Array1::zeros(n_features);
199 if n_features > 0 && x[(0, 0)].abs() > 1e-10 {
201 coefficients[0] = y[0] / x[(0, 0)];
202 }
203 return Ok(LinearModel::new(coefficients, 0.0));
204 }
205
206 let xt_x = self.matrix_transpose_dot(x, x);
207 let xt_y = self.matrix_transpose_dot_vector(x, y);
208
209 let mut regularized_xt_x = xt_x.clone();
211 for i in 0..regularized_xt_x.nrows() {
212 regularized_xt_x[(i, i)] += 1e-8;
213 }
214
215 let coefficients = self.solve_linear_system(®ularized_xt_x, &xt_y)?;
216
217 Ok(LinearModel::new(coefficients, 0.0))
218 }
219
220 fn matrix_transpose_dot(&self, a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
222 let n = a.ncols(); let m = b.ncols(); let mut result = Array2::zeros((n, m));
225
226 for i in 0..n {
227 for j in 0..m {
228 let mut sum = 0.0;
229 for k in 0..a.nrows() {
230 sum += a[(k, i)] * b[(k, j)];
231 }
232 result[(i, j)] = sum;
233 }
234 }
235
236 result
237 }
238
239 fn matrix_transpose_dot_vector(&self, a: &Array2<f64>, v: &Array1<f64>) -> Array1<f64> {
241 let n = a.ncols();
242 let mut result = Array1::zeros(n);
243
244 for i in 0..n {
245 let mut sum = 0.0;
246 for k in 0..a.nrows() {
247 sum += a[(k, i)] * v[k];
248 }
249 result[i] = sum;
250 }
251
252 result
253 }
254
255 fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
256 let n = a.nrows();
257
258 let mut augmented = Array2::zeros((n, n + 1));
260
261 for i in 0..n {
263 for j in 0..n {
264 augmented[(i, j)] = a[(i, j)];
265 }
266 augmented[(i, n)] = b[i];
267 }
268
269 for i in 0..n {
271 let mut max_row = i;
273 for k in i + 1..n {
274 if augmented[(k, i)].abs() > augmented[(max_row, i)].abs() {
275 max_row = k;
276 }
277 }
278
279 if max_row != i {
281 for j in 0..=n {
282 let temp = augmented[(i, j)];
283 augmented[(i, j)] = augmented[(max_row, j)];
284 augmented[(max_row, j)] = temp;
285 }
286 }
287
288 if augmented[(i, i)].abs() < 1e-10 {
290 return Err(SklearnError::FitFailed {
291 reason: "矩阵奇异,无法求解".to_string(),
292 });
293 }
294
295 for k in i + 1..n {
297 let factor = augmented[(k, i)] / augmented[(i, i)];
298 for j in i..=n {
299 augmented[(k, j)] -= factor * augmented[(i, j)];
300 }
301 }
302 }
303
304 let mut x = Array1::zeros(n);
306 for i in (0..n).rev() {
307 x[i] = augmented[(i, n)];
308 for j in i + 1..n {
309 x[i] -= augmented[(i, j)] * x[j];
310 }
311 x[i] /= augmented[(i, i)];
312 }
313
314 Ok(x)
315 }
316}