rusty_ai/regression/
linear.rs

1use crate::{
2    data::dataset::{Dataset, RealNumber},
3    metrics::errors::RegressionMetrics,
4};
5use nalgebra::{DMatrix, DVector};
6use std::error::Error;
7
8/// Represents a linear regression model.
9///
10/// The `LinearRegression` struct implements a linear regression model for predicting a target variable based on one or more input features.
11/// It uses the least squares method to estimate the weights of the linear model.
12///
13/// # Type Parameters
14///
15/// * `T`: The numeric type used for calculations. Must implement the `RealNumber` trait.
16///
17/// # Fields
18///
19/// * `weights`: The weights of the logistic regression model, with the first being the bias weight.
20///
21/// # Examples
22///
23/// ```
24/// use rusty_ai::regression::linear::LinearRegression;
25/// use rusty_ai::data::dataset::Dataset;
26/// use nalgebra::{DMatrix, DVector};
27///
28/// // Create a new linear regression model
29/// let mut model = LinearRegression::<f64>::new();
30///
31/// // Fit the model to a dataset
32/// let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
33/// let y = DVector::from_vec(vec![1.5, 2.5, 3.5]);
34/// let dataset = Dataset::new(x, y);
35/// let learning_rate = 0.01;
36/// let max_steps = 1000;
37/// let epsilon = Some(0.001);
38/// let progress = Some(100);
39/// let result = model.fit(&dataset, learning_rate, max_steps, epsilon, progress);
40///
41/// // Make predictions using the trained model
42/// let x_test = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
43/// let predictions = model.predict(&x_test);
44/// assert!(predictions.is_ok());
45/// ```
46
47#[derive(Clone, Debug)]
48pub struct LinearRegression<T: RealNumber> {
49    weights: DVector<T>,
50}
51
52impl<T: RealNumber> RegressionMetrics<T> for LinearRegression<T> {}
53
54impl<T: RealNumber> Default for LinearRegression<T> {
55    /// Creates a new `LinearRegression` model with default weights.
56    ///
57    /// The default weights are initialized to 1.0 for each feature, including the bias weight.
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl<T: RealNumber> LinearRegression<T> {
64    /// Creates a new `LinearRegression` model with default weights.
65    ///
66    /// The default weights are initialized to 1.0 for each feature, including the bias weight.
67    pub fn new() -> Self {
68        Self {
69            weights: DVector::<T>::from_element(3, T::from_f64(1.0).unwrap()),
70        }
71    }
72
73    /// Creates a new `LinearRegression` model with custom parameters.
74    ///
75    /// # Arguments
76    ///
77    /// * `dimension`: The dimension of the input features. If `None`, the dimension will be inferred from the provided weights.
78    /// * `weights`: The initial weights for the linear regression model. If `None`, default weights will be used.
79    ///
80    /// # Returns
81    ///
82    /// A `Result` containing the `LinearRegression` model if the parameters are valid, or an error message if the parameters are invalid.
83    ///
84    /// # Errors
85    ///
86    /// An error will be returned if:
87    /// * Both `dimension` and `weights` are `None`.
88    /// * The length of `weights` is not equal to `dimension + 1` to account for the bias weight.
89    pub fn with_params(
90        dimension: Option<usize>,
91        weights: Option<DVector<T>>,
92    ) -> Result<Self, Box<dyn Error>> {
93        match (dimension, &weights) {
94            (None, None) => Err("Please input the dimension or starting weights.".into()),
95
96            (Some(dim), Some(w)) if dim != w.len() - 1 => {
97                Err("The weights should be longer by 1 than the dimension to account for the bias weight.".into())
98            }
99            _ => Ok(Self {
100                weights: weights.unwrap_or_else(|| {
101                    DVector::<T>::from_element(dimension.unwrap() + 1, T::from_f64(1.0).unwrap())
102                }),
103            }),
104        }
105    }
106
107    /// A reference to the weights of the linear regression model.
108    pub fn weights(&self) -> &DVector<T> {
109        &self.weights
110    }
111
112    /// Makes predictions using the trained linear regression model.
113    ///
114    /// # Arguments
115    ///
116    /// * `x_pred`: The input features for which to make predictions.
117    ///
118    /// # Returns
119    ///
120    /// A `Result` containing the predicted target values if successful, or an error message if an error occurs during prediction.
121    pub fn predict(&self, x_pred: &DMatrix<T>) -> Result<DVector<T>, Box<dyn Error>> {
122        let x_pred_with_bias = x_pred.clone().insert_column(0, T::from_f64(1.0).unwrap());
123        Ok(self.h(&x_pred_with_bias))
124    }
125
126    /// Fits the linear regression model to a dataset.
127    ///
128    /// # Arguments
129    ///
130    /// * `dataset`: The dataset containing the input features and target values.
131    /// * `lr`: The learning rate for gradient descent.
132    /// * `max_steps`: The maximum number of steps to perform during training.
133    /// * `epsilon`: The convergence threshold. If the change in weights is below this threshold, training will stop.
134    /// * `progress`: The number of steps at which to display progress information. If `None`, no progress information will be displayed.
135    ///
136    /// # Returns
137    ///
138    /// A `Result` containing a success message if training is successful, or an error message if an error occurs during training.
139    ///
140    /// # Errors
141    ///
142    /// An error will be returned if:
143    /// * The number of steps for progress visualization is 0.
144    /// * The gradient turns to NaN during training.
145    pub fn fit(
146        &mut self,
147        dataset: &Dataset<T, T>,
148        lr: T,
149        mut max_steps: usize,
150        epsilon: Option<T>,
151        progress: Option<usize>,
152    ) -> Result<String, Box<dyn Error>> {
153        if progress.is_some_and(|steps| steps == 0) {
154            return Err(
155                "The number of steps for progress visualization must be greater than 0.".into(),
156            );
157        }
158
159        let (x, y) = dataset.into_parts();
160
161        let epsilon = epsilon.unwrap_or_else(|| T::from_f64(1e-6).unwrap());
162        let initial_max_steps = max_steps;
163        let x_with_bias = x.clone().insert_column(0, T::from_f64(1.0).unwrap());
164        while max_steps > 0 {
165            let weights_prev = self.weights.clone();
166
167            let gradient = self.gradient(&x_with_bias, y);
168
169            if gradient.iter().any(|&g| g.is_nan()) {
170                return Err("Gradient turned to NaN during training.".into());
171            }
172
173            self.weights -= gradient * lr;
174
175            if progress.is_some_and(|steps| max_steps % steps == 0) {
176                println!("Step: {}", initial_max_steps - max_steps);
177                println!("Weights: {:?}", self.weights);
178                println!("MSE: {:?}", self.mse_training(&x_with_bias, y));
179            }
180
181            let delta = self
182                .weights
183                .iter()
184                .zip(weights_prev.iter())
185                .map(|(&w, &w_prev)| (w - w_prev) * (w - w_prev))
186                .fold(T::from_f64(0.0).unwrap(), |acc, x| acc + x);
187
188            if delta < epsilon {
189                return Ok(format!(
190                    "Finished training in {} steps.",
191                    initial_max_steps - max_steps,
192                ));
193            }
194            max_steps -= 1;
195        }
196        Ok("Reached maximum steps without converging.".into())
197    }
198
199    fn gradient(&self, x: &DMatrix<T>, y: &DVector<T>) -> DVector<T> {
200        let y_pred = self.h(x);
201
202        let errors = y_pred - y;
203
204        x.transpose() * errors * T::from_f64(2.0).unwrap() / T::from_usize(y.len()).unwrap()
205    }
206
207    fn h(&self, x: &DMatrix<T>) -> DVector<T> {
208        x * &self.weights
209    }
210
211    fn mse_training(&self, x: &DMatrix<T>, y: &DVector<T>) -> T {
212        let m = T::from_usize(y.len()).unwrap();
213        let y_pred = self.h(x);
214
215        let errors = y_pred - y;
216
217        let errors_sq = errors.component_mul(&errors);
218        errors_sq.sum() / m
219    }
220}
221
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Weight vector `[1, 2, 3]` (bias weight first) shared by several tests.
    fn sample_weights() -> DVector<f64> {
        DVector::from_vec(vec![1.0, 2.0, 3.0])
    }

    /// 2x2 feature matrix with the rows `[1, 2]` and `[3, 4]`.
    fn sample_features() -> DMatrix<f64> {
        DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0])
    }

    /// Dataset pairing `sample_features()` with the targets `[0, 1]`.
    fn sample_dataset() -> Dataset<f64, f64> {
        Dataset::new(sample_features(), DVector::from_vec(vec![0.0, 1.0]))
    }

    #[test]
    fn test_new() {
        let model = LinearRegression::<f32>::new();
        let weights = model.weights();
        assert_eq!(weights.len(), 3);
        assert!(weights.iter().all(|&w| w == 1.0));
    }

    #[test]
    fn test_with_params() {
        // A dimension of 2 matches three weights (bias + two features).
        let expected = sample_weights();
        let result = LinearRegression::with_params(Some(2), Some(expected.clone()));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().weights, expected);
    }

    #[test]
    fn test_with_params_incorrect() {
        // Three weights cannot belong to a four-dimensional model.
        let result = LinearRegression::with_params(Some(4), Some(sample_weights()));
        assert!(result.is_err());
    }

    #[test]
    fn test_with_dimension() {
        let result = LinearRegression::<f64>::with_params(Some(3), None);
        assert!(result.is_ok());

        let model = result.unwrap();
        assert_eq!(model.weights().len(), 4);
        assert!(model.weights().iter().all(|&w| w == 1.0));
    }

    #[test]
    fn test_with_weights() {
        let expected = sample_weights();
        let result = LinearRegression::with_params(None, Some(expected.clone()));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().weights, expected);
    }

    #[test]
    fn test_with_nothing_provided() {
        // Neither a dimension nor weights: the constructor has nothing to work with.
        let result = LinearRegression::<f64>::with_params(None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_weights() {
        // The accessor must hand back exactly the weights the model was built with.
        let expected = sample_weights();
        let model = LinearRegression::with_params(Some(2), Some(expected.clone())).unwrap();
        assert_eq!(model.weights(), &expected);
    }

    #[test]
    fn test_predict() {
        let model = LinearRegression::with_params(None, Some(sample_weights())).unwrap();

        let prediction = model.predict(&sample_features());
        assert!(prediction.is_ok());

        // 1 + 2*1 + 3*2 = 9 and 1 + 2*3 + 3*4 = 19.
        let expected = DVector::from_vec(vec![9.0, 19.0]);
        assert_eq!(prediction.unwrap(), expected);
    }

    #[test]
    fn test_gradient() {
        let model = LinearRegression::<f64>::with_params(None, Some(sample_weights())).unwrap();

        // Inputs with the bias column already in place, plus the targets.
        let x_with_bias = sample_features().insert_column(0, 1.0);
        let y = DVector::from_vec(vec![7.0, 8.0]);

        let gradient = model.gradient(&x_with_bias, &y);

        let expected_gradient = DVector::from_vec(vec![13.0, 35.0, 48.0]);
        assert_eq!(gradient, expected_gradient);
    }

    #[test]
    fn test_mse_training() {
        let model = LinearRegression::<f64>::with_params(None, Some(sample_weights())).unwrap();

        let x_with_bias = sample_features().insert_column(0, 1.0);
        let y = DVector::from_vec(vec![7.0, 8.0]);

        let mse = model.mse_training(&x_with_bias, &y);

        assert_relative_eq!(mse, 62.5, epsilon = 1e-6);
    }

    #[test]
    fn test_fit_with_progress_set_to_zero() {
        let mut model = LinearRegression::<f64>::new();

        // The data is irrelevant here; the progress interval is rejected first.
        let dataset = Dataset::new(
            DMatrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]),
            DVector::from_vec(vec![1.0, 2.0]),
        );

        let result = model.fit(&dataset, 0.1, 100, Some(0.0001), Some(0));

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "The number of steps for progress visualization must be greater than 0."
        );
    }

    #[test]
    fn test_fit_no_convergence() {
        let mut model = LinearRegression::<f64>::new();
        let dataset = sample_dataset();

        let result = model.fit(&dataset, 0.1, 100, Some(1e-6), None);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            "Reached maximum steps without converging.".to_string()
        );
    }

    #[test]
    fn test_fit_with_convergence() {
        let mut model = LinearRegression::<f64>::new();
        let dataset = sample_dataset();

        let result = model.fit(&dataset, 0.01, 100, Some(1e-2), Some(1));

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Finished training in 4 steps.".to_string());
    }
}