use crate::{
    data::dataset::{Dataset, RealNumber},
    metrics::errors::RegressionMetrics,
};
use nalgebra::{DMatrix, DVector};
use std::error::Error;

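/// Linear regression trained with batch gradient descent.
///
/// The first entry of `weights` is the bias term; the remaining entries are
/// the per-feature coefficients, so a model over `d` features stores `d + 1`
/// weights.
///
/// # Examples
///
/// A minimal sketch of the intended workflow (the `Dataset::new` constructor
/// and data shapes follow the usage in the tests below):
///
/// ```ignore
/// let mut model = LinearRegression::<f64>::new();
/// let dataset = Dataset::new(
///     DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]),
///     DVector::from_vec(vec![0.0, 1.0]),
/// );
/// model.fit(&dataset, 0.01, 100, Some(1e-2), None)?;
/// let predictions = model.predict(&DMatrix::from_row_slice(1, 2, &[1.0, 2.0]))?;
/// ```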
#[derive(Clone, Debug)]
pub struct LinearRegression<T: RealNumber> {
    weights: DVector<T>,
}

impl<T: RealNumber> RegressionMetrics<T> for LinearRegression<T> {}

impl<T: RealNumber> Default for LinearRegression<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: RealNumber> LinearRegression<T> {
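    /// Creates a model for two features (three weights including the bias),
    /// with every weight initialized to `1.0`.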
    pub fn new() -> Self {
        Self {
            weights: DVector::<T>::from_element(3, T::from_f64(1.0).unwrap()),
        }
    }

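    /// Creates a model from an explicit dimension, starting weights, or both.
    ///
    /// The weights vector must be one entry longer than `dimension` to hold
    /// the bias weight. When only `dimension` is given, all `dimension + 1`
    /// weights are initialized to `1.0`.
    ///
    /// # Errors
    ///
    /// Returns an error if neither argument is provided, or if the provided
    /// weights and dimension disagree.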
    pub fn with_params(
        dimension: Option<usize>,
        weights: Option<DVector<T>>,
    ) -> Result<Self, Box<dyn Error>> {
        match (dimension, &weights) {
            (None, None) => Err("Please input the dimension or starting weights.".into()),
            (Some(dim), Some(w)) if dim != w.len() - 1 => {
                Err("The weights should be longer by 1 than the dimension to account for the bias weight.".into())
            }
            _ => Ok(Self {
                weights: weights.unwrap_or_else(|| {
                    DVector::<T>::from_element(dimension.unwrap() + 1, T::from_f64(1.0).unwrap())
                }),
            }),
        }
    }

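    /// Returns a reference to the model's weights, bias first.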
    pub fn weights(&self) -> &DVector<T> {
        &self.weights
    }

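    /// Predicts targets for `x_pred` (one row per sample, one column per
    /// feature). A bias column of ones is prepended before applying the
    /// hypothesis.
    ///
    /// # Errors
    ///
    /// Returns an error if the number of feature columns does not match the
    /// model's dimension.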
    pub fn predict(&self, x_pred: &DMatrix<T>) -> Result<DVector<T>, Box<dyn Error>> {
        if x_pred.ncols() + 1 != self.weights.len() {
            return Err("The number of features does not match the model's dimension.".into());
        }
        let x_pred_with_bias = x_pred.clone().insert_column(0, T::from_f64(1.0).unwrap());
        Ok(self.h(&x_pred_with_bias))
    }

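    /// Fits the model to `dataset` with batch gradient descent.
    ///
    /// Runs at most `max_steps` update steps with learning rate `lr`,
    /// stopping early once the squared L2 norm of the weight update falls
    /// below `epsilon` (default `1e-6`). When `progress` is `Some(n)`, the
    /// step number, weights, and training MSE are printed every `n` steps.
    ///
    /// # Errors
    ///
    /// Returns an error if `progress` is `Some(0)` or if the gradient
    /// becomes NaN during training.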
    pub fn fit(
        &mut self,
        dataset: &Dataset<T, T>,
        lr: T,
        mut max_steps: usize,
        epsilon: Option<T>,
        progress: Option<usize>,
    ) -> Result<String, Box<dyn Error>> {
        if progress.is_some_and(|steps| steps == 0) {
            return Err(
                "The number of steps for progress visualization must be greater than 0.".into(),
            );
        }

        let (x, y) = dataset.into_parts();

        let epsilon = epsilon.unwrap_or_else(|| T::from_f64(1e-6).unwrap());
        let initial_max_steps = max_steps;
        let x_with_bias = x.clone().insert_column(0, T::from_f64(1.0).unwrap());
        while max_steps > 0 {
            let weights_prev = self.weights.clone();

            let gradient = self.gradient(&x_with_bias, y);

            if gradient.iter().any(|&g| g.is_nan()) {
                return Err("Gradient turned to NaN during training.".into());
            }

            self.weights -= gradient * lr;

            if progress.is_some_and(|steps| max_steps % steps == 0) {
                println!("Step: {}", initial_max_steps - max_steps);
                println!("Weights: {:?}", self.weights);
                println!("MSE: {:?}", self.mse_training(&x_with_bias, y));
            }

            // Squared L2 norm of the weight update, used as the convergence criterion.
            let delta = self
                .weights
                .iter()
                .zip(weights_prev.iter())
                .map(|(&w, &w_prev)| (w - w_prev) * (w - w_prev))
                .fold(T::from_f64(0.0).unwrap(), |acc, x| acc + x);

            if delta < epsilon {
                return Ok(format!(
                    "Finished training in {} steps.",
                    initial_max_steps - max_steps,
                ));
            }
            max_steps -= 1;
        }
        Ok("Reached maximum steps without converging.".into())
    }

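    // Gradient of the MSE cost with respect to the weights:
    // grad J(w) = (2 / m) * X^T (Xw - y), where m is the number of samples
    // and X already carries the bias column.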
    fn gradient(&self, x: &DMatrix<T>, y: &DVector<T>) -> DVector<T> {
        let y_pred = self.h(x);

        let errors = y_pred - y;

        x.transpose() * errors * T::from_f64(2.0).unwrap() / T::from_usize(y.len()).unwrap()
    }

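    // Hypothesis function: y_hat = Xw (X includes the bias column).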
    fn h(&self, x: &DMatrix<T>) -> DVector<T> {
        x * &self.weights
    }

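    // Training MSE: J(w) = (1 / m) * sum_i (y_hat_i - y_i)^2, computed on the
    // bias-augmented design matrix.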
    fn mse_training(&self, x: &DMatrix<T>, y: &DVector<T>) -> T {
        let m = T::from_usize(y.len()).unwrap();
        let y_pred = self.h(x);

        let errors = y_pred - y;

        let errors_sq = errors.component_mul(&errors);
        errors_sq.sum() / m
    }
}

#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_new() {
        let model = LinearRegression::<f32>::new();
        assert_eq!(model.weights().len(), 3);
        assert!(model.weights().iter().all(|&w| w == 1.0));
    }

    #[test]
    fn test_with_params() {
        let weights = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let model = LinearRegression::with_params(Some(2), Some(weights.clone()));
        assert!(model.is_ok());
        let model = model.unwrap();
        assert_eq!(model.weights, weights);
    }

    #[test]
    fn test_with_params_incorrect() {
        let weights = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let model = LinearRegression::with_params(Some(4), Some(weights));
        assert!(model.is_err());
    }

    #[test]
    fn test_with_dimension() {
        let model = LinearRegression::<f64>::with_params(Some(3), None);
        assert!(model.is_ok());
        assert_eq!(model.as_ref().unwrap().weights().len(), 4);
        assert!(model.unwrap().weights().iter().all(|&w| w == 1.0));
    }

    #[test]
    fn test_with_weights() {
        let weights = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let model = LinearRegression::with_params(None, Some(weights.clone()));
        assert!(model.is_ok());
        assert_eq!(model.unwrap().weights, weights);
    }

    #[test]
    fn test_with_nothing_provided() {
        let model = LinearRegression::<f64>::with_params(None, None);
        assert!(model.is_err());
    }

    #[test]
    fn test_weights() {
        let weights = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let model = LinearRegression::with_params(Some(2), Some(weights.clone())).unwrap();
        let model_weights = model.weights();
        assert_eq!(model_weights, &weights);
    }

    #[test]
    fn test_predict() {
        let weights = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let model = LinearRegression::with_params(None, Some(weights)).unwrap();
        let x_pred = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let prediction = model.predict(&x_pred);
        assert!(prediction.is_ok());

        let expected = DVector::from_vec(vec![9.0, 19.0]);
        assert_eq!(prediction.unwrap(), expected);
    }

    #[test]
    fn test_gradient() {
        let model =
            LinearRegression::<f64>::with_params(None, Some(DVector::from(vec![1.0, 2.0, 3.0])))
                .unwrap();

        let x = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let y = DVector::from_vec(vec![7.0, 8.0]);
        let x_with_bias = x.clone().insert_column(0, 1.0);

        let gradient = model.gradient(&x_with_bias, &y);

        let expected_gradient = DVector::from_vec(vec![13.0, 35.0, 48.0]);

        assert_eq!(gradient, expected_gradient);
    }

    #[test]
    fn test_mse_training() {
        let model =
            LinearRegression::<f64>::with_params(None, Some(DVector::from(vec![1.0, 2.0, 3.0])))
                .unwrap();
        let x = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let y = DVector::from_vec(vec![7.0, 8.0]);

        let x_with_bias = x.clone().insert_column(0, 1.0);

        let mse = model.mse_training(&x_with_bias, &y);

        assert_relative_eq!(mse, 62.5, epsilon = 1e-6);
    }

    #[test]
    fn test_fit_with_progress_set_to_zero() {
        let mut model = LinearRegression::<f64>::new();

        let x = DMatrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let y = DVector::from_vec(vec![1.0, 2.0]);
        let dataset = Dataset::new(x, y);

        let lr = 0.1;
        let max_steps = 100;
        let epsilon = Some(0.0001);
        let progress = Some(0);

        let result = model.fit(&dataset, lr, max_steps, epsilon, progress);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "The number of steps for progress visualization must be greater than 0."
        );
    }

    #[test]
    fn test_fit_no_convergence() {
        let mut linear_regression = LinearRegression::<f64>::new();
        let dataset = Dataset::new(
            DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]),
            DVector::from_vec(vec![0.0, 1.0]),
        );
        let result = linear_regression.fit(&dataset, 0.1, 100, Some(1e-6), None);
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            "Reached maximum steps without converging.".to_string()
        );
    }

    #[test]
    fn test_fit_with_convergence() {
        let mut linear_regression = LinearRegression::<f64>::new();
        let dataset = Dataset::new(
            DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]),
            DVector::from_vec(vec![0.0, 1.0]),
        );
        let result = linear_regression.fit(&dataset, 0.01, 100, Some(1e-2), Some(1));
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Finished training in 4 steps.".to_string());
    }
}