rusty_ai/regression/logistic.rs

use std::{error::Error, marker::PhantomData};

use crate::{
    data::dataset::{Dataset, RealNumber, WholeNumber},
    metrics::confusion::ClassificationMetrics,
};
use nalgebra::{DMatrix, DVector};

/// Logistic regression model for binary classification.
///
/// The model maps input features to a probability in `[0, 1]` with the
/// sigmoid function `σ(z) = 1 / (1 + e^(-z))` and predicts class `1` when
/// that probability exceeds `0.5`.
///
/// # Type Parameters
///
/// * `XT`: The type of the input features.
/// * `YT`: The type of the target labels.
///
/// # Fields
///
/// * `weights`: The weights of the model, with the first entry being the bias weight.
/// * `_marker`: A marker for the target label type.
///
/// # Examples
///
/// ```
/// use rusty_ai::regression::logistic::LogisticRegression;
/// use rusty_ai::data::dataset::Dataset;
/// use nalgebra::{DMatrix, DVector};
///
/// // Create a new logistic regression model
/// let mut model: LogisticRegression<f64, u8> = LogisticRegression::new();
///
/// // Fit the model to a dataset
/// let x = DMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// let y = DVector::from_vec(vec![0, 1, 0]);
/// let dataset = Dataset::new(x, y);
/// let lr = 0.01;
/// let max_steps = 1000;
/// let epsilon = Some(0.001);
/// let progress = Some(100);
/// let result = model.fit(&dataset, lr, max_steps, epsilon, progress);
///
/// // Make predictions using the trained model
/// let x_pred = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
/// let predictions = model.predict(&x_pred);
/// ```

#[derive(Clone, Debug)]
pub struct LogisticRegression<XT: RealNumber, YT: WholeNumber> {
    weights: DVector<XT>,

    _marker: PhantomData<YT>,
}

impl<XT: RealNumber, YT: WholeNumber> ClassificationMetrics<YT> for LogisticRegression<XT, YT> {}

impl<XT: RealNumber, YT: WholeNumber> Default for LogisticRegression<XT, YT> {
    /// Creates a new instance of `LogisticRegression` with default values.
    ///
    /// # Returns
    ///
    /// A new `LogisticRegression` instance.
    fn default() -> Self {
        Self::new()
    }
}

impl<XT: RealNumber, YT: WholeNumber> LogisticRegression<XT, YT> {
    /// Creates a new instance of `LogisticRegression` with default weights.
    ///
    /// # Returns
    ///
    /// A new `LogisticRegression` instance with three unit weights: a bias
    /// weight followed by weights for two input features.
    pub fn new() -> Self {
        Self {
            weights: DVector::<XT>::from_element(3, XT::from_f64(1.0).unwrap()),
            _marker: PhantomData,
        }
    }

    /// Creates a new instance of `LogisticRegression` with custom parameters.
    ///
    /// # Parameters
    ///
    /// * `dimension`: The dimension of the input features. If `None`, it will be inferred from the starting weights.
    /// * `weights`: The starting weights for the logistic regression model. If `None`, default weights will be used.
    ///
    /// # Returns
    ///
    /// A new `LogisticRegression` instance.
    ///
    /// # Errors
    ///
    /// An error is returned if the dimension and weights are incompatible.
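    ///
    /// # Examples
    ///
    /// A minimal sketch of both construction paths:
    ///
    /// ```
    /// use rusty_ai::regression::logistic::LogisticRegression;
    /// use nalgebra::DVector;
    ///
    /// // Dimension 2 yields 3 weights: one bias weight plus one per feature.
    /// let model = LogisticRegression::<f64, u8>::with_params(Some(2), None).unwrap();
    /// assert_eq!(model.weights().len(), 3);
    ///
    /// // Explicit starting weights; the dimension is inferred from their length.
    /// let weights = DVector::from_vec(vec![0.0, 0.5, -0.5]);
    /// let model = LogisticRegression::<f64, u8>::with_params(None, Some(weights)).unwrap();
    /// ```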
    pub fn with_params(
        dimension: Option<usize>,
        weights: Option<DVector<XT>>,
    ) -> Result<Self, Box<dyn Error>> {
        match (dimension, &weights) {
            (None, None) => Err("Please input the dimension or starting weights.".into()),

            (Some(dim), Some(w)) if dim != w.len() - 1 => {
                Err("The weights should be longer by 1 than the dimension to account for the bias weight.".into())
            }
            _ => Ok(Self {
                weights: weights.unwrap_or_else(|| {
                    DVector::<XT>::from_element(dimension.unwrap() + 1, XT::from_f64(1.0).unwrap())
                }),
                _marker: PhantomData,
            }),
        }
    }

    /// Predicts the target labels for the given input features.
    ///
    /// # Parameters
    ///
    /// * `x_pred`: The input features to make predictions for.
    ///
    /// # Returns
    ///
    /// A `Result` containing the predicted target labels if successful, or an error message if an error occurs during prediction.
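    ///
    /// # Examples
    ///
    /// A minimal sketch using fixed weights (bias `0.0`, then `0.5` and `-0.5`):
    ///
    /// ```
    /// use rusty_ai::regression::logistic::LogisticRegression;
    /// use nalgebra::{DMatrix, DVector};
    ///
    /// let model = LogisticRegression::<f64, u8>::with_params(
    ///     None,
    ///     Some(DVector::from_vec(vec![0.0, 0.5, -0.5])),
    /// )
    /// .unwrap();
    ///
    /// let x_pred = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
    /// let predictions = model.predict(&x_pred).unwrap();
    /// assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    /// ```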
    pub fn predict(&self, x_pred: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
        // Prepend the bias column (x_0 = 1), matching the layout used in `fit`.
        let x_pred_with_bias = x_pred.clone().insert_column(0, XT::from_f64(1.0).unwrap());

        Ok(self.h(&x_pred_with_bias).map(|val| {
            if val > XT::from_f64(0.5).unwrap() {
                YT::from_usize(1).unwrap()
            } else {
                YT::from_usize(0).unwrap()
            }
        }))
    }

    /// Fits the logistic regression model to a dataset using gradient descent.
    ///
    /// # Parameters
    ///
    /// * `dataset`: The dataset to fit the model to.
    /// * `lr`: The learning rate for gradient descent.
    /// * `max_steps`: The maximum number of steps for gradient descent.
    /// * `epsilon`: The convergence threshold for gradient descent. If `None`, a default of `1e-6` is used.
    /// * `progress`: Display progress information every given number of steps. If `None`, no progress is displayed.
    ///
    /// # Returns
    ///
    /// A `Result` containing a summary string describing how training ended.
    ///
    /// # Errors
    ///
    /// An error is returned if the progress steps value is 0.
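    ///
    /// # Examples
    ///
    /// A minimal sketch on a small hand-made dataset:
    ///
    /// ```
    /// use rusty_ai::regression::logistic::LogisticRegression;
    /// use rusty_ai::data::dataset::Dataset;
    /// use nalgebra::{DMatrix, DVector};
    ///
    /// let mut model = LogisticRegression::<f64, u8>::new();
    /// let x = DMatrix::from_row_slice(4, 2, &[0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]);
    /// let y = DVector::from_vec(vec![0, 0, 1, 1]);
    /// let dataset = Dataset::new(x, y);
    ///
    /// // Train silently (no progress output) with the default tolerance.
    /// let summary = model.fit(&dataset, 0.1, 1000, None, None).unwrap();
    /// assert!(summary.contains("steps"));
    /// ```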
    pub fn fit(
        &mut self,
        dataset: &Dataset<XT, YT>,
        lr: XT,
        mut max_steps: usize,
        epsilon: Option<XT>,
        progress: Option<usize>,
    ) -> Result<String, Box<dyn Error>> {
        if progress.is_some_and(|steps| steps == 0) {
            return Err(
                "The number of steps for progress visualization must be greater than 0.".into(),
            );
        }
        let (x, y) = dataset.into_parts();

        let epsilon = epsilon.unwrap_or_else(|| XT::from_f64(1e-6).unwrap());
        let initial_max_steps = max_steps;
        // Prepend the bias column (x_0 = 1) once, before the descent loop.
        let x_with_bias = x.clone().insert_column(0, XT::from_f64(1.0).unwrap());
        while max_steps > 0 {
            let weights_prev = self.weights.clone();

            let gradient = self.gradient(&x_with_bias, y);

            self.weights -= gradient * lr;

            if progress.is_some_and(|steps| max_steps % steps == 0) {
                println!("Step: {:?}", initial_max_steps - max_steps);
                println!("Weights: {:?}", self.weights);
                println!(
                    "Cross entropy: {:?}",
                    self.cross_entropy(&x_with_bias, y, false)
                );
            }

            // Squared Euclidean norm of the weight update; converged once it drops below epsilon.
            let delta = self
                .weights
                .iter()
                .zip(weights_prev.iter())
                .map(|(&w, &w_prev)| (w - w_prev) * (w - w_prev))
                .fold(XT::from_f64(0.0).unwrap(), |acc, x| acc + x);

            if delta < epsilon {
                return Ok(format!(
                    "Finished training in {} steps.",
                    initial_max_steps - max_steps,
                ));
            }
            max_steps -= 1;
        }
        Ok("Reached maximum steps without converging.".into())
    }

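    /// Returns a reference to the model's weight vector, with the bias weight first.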
    pub fn weights(&self) -> &DVector<XT> {
        &self.weights
    }

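    /// Gradient of the mean cross-entropy loss with respect to the weights,
    /// `∇L = Xᵀ(σ(Xw) − y) / n`, where `x` already includes the bias column.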
    fn gradient(&self, x: &DMatrix<XT>, y: &DVector<YT>) -> DVector<XT> {
        let y_pred = self.h(x);

        let y_xt_vec = y
            .iter()
            .map(|&y_i| XT::from(y_i).unwrap())
            .collect::<Vec<_>>();

        let y_xt = DVector::from_vec(y_xt_vec);
        let errors = y_pred - y_xt;

        x.transpose() * errors / XT::from_usize(y.len()).unwrap()
    }

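    /// Computes the mean binary cross-entropy loss,
    /// `−(1/n) Σ [yᵢ ln(ŷᵢ + ε) + (1 − yᵢ) ln(1 − ŷᵢ + ε)]`,
    /// where `ε = f64::EPSILON` guards against taking the logarithm of zero.
    ///
    /// When `testing` is `true`, a bias column is prepended to `x`; pass
    /// `false` when `x` already contains the bias column (as during training).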
    pub fn cross_entropy(
        &self,
        x: &DMatrix<XT>,
        y: &DVector<YT>,
        testing: bool,
    ) -> Result<XT, Box<dyn Error>> {
        let x = match testing {
            // Prepend the bias column (x_0 = 1), matching the training layout.
            true => x.clone().insert_column(0, XT::from_f64(1.0).unwrap()),
            false => x.clone(),
        };
        let y_pred: DVector<XT> = self.h(&x);
        let one = XT::from_f64(1.0).unwrap();

        let cross_entropy = y
            .iter()
            .zip(y_pred.iter())
            .map(|(&y_i, &y_pred_i)| {
                let y_i_xt = XT::from(y_i).unwrap();
                -y_i_xt * (y_pred_i + XT::from_f64(f64::EPSILON).unwrap()).ln()
                    - (one - y_i_xt) * (one - y_pred_i + XT::from_f64(f64::EPSILON).unwrap()).ln()
            })
            .fold(XT::from_f64(0.0).unwrap(), |acc, x| acc + x)
            / XT::from_usize(y.len()).unwrap();

        Ok(cross_entropy)
    }

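    /// The hypothesis function: applies the sigmoid element-wise to `x * weights`.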
    fn h(&self, x: &DMatrix<XT>) -> DVector<XT> {
        let z = x * &self.weights;
        z.map(Self::sigmoid)
    }

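    /// Numerically guarded sigmoid `σ(z) = 1 / (1 + e^(−z))`, saturated to
    /// exactly 0 or 1 once `|z|` exceeds 10.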
    fn sigmoid(z: XT) -> XT {
        let one = XT::from_f64(1.0).unwrap();

        match z {
            z if z < XT::from_f64(-10.0).unwrap() => XT::from_f64(0.0).unwrap(),
            z if z > XT::from_f64(10.0).unwrap() => one,
            _ => one / (one + (-z).exp()),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let model = LogisticRegression::<f64, u8>::default();
        assert_eq!(model.weights().len(), 3);
        assert!(model.weights().iter().all(|&w| w == 1.0));
    }

    // Test when only the dimension is provided
    #[test]
    fn test_with_dimension() {
        let model = LogisticRegression::<f64, u8>::with_params(Some(3), None);
        assert!(model.is_ok());
        assert_eq!(model.as_ref().unwrap().weights().len(), 4);
        assert!(model.unwrap().weights().iter().all(|&w| w == 1.0));
    }

    // Test when only starting weights are provided
    #[test]
    fn test_with_weights() {
        let weights = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let model = LogisticRegression::<f64, u8>::with_params(None, Some(weights.clone()));
        assert!(model.is_ok());
        assert_eq!(model.unwrap().weights, weights);
    }

    #[test]
    fn test_with_params_nothing_provided() {
        let model = LogisticRegression::<f64, u8>::with_params(None, None);
        assert!(model.is_err());
    }

    // Test when both dimension and starting weights are provided correctly
    #[test]
    fn test_dimension_and_weights_provided_correct() {
        let weights = DVector::from_vec(vec![0.5, -0.5, 1.0]);
        let model = LogisticRegression::<f64, u8>::with_params(Some(2), Some(weights.clone()));
        assert!(model.is_ok());
        assert_eq!(model.unwrap().weights, weights);
    }

    // Test when both dimension and starting weights are provided incorrectly
    #[test]
    fn test_dimension_and_weights_provided_incorrect() {
        let weights = DVector::from_vec(vec![0.5, -0.5]);
        let model = LogisticRegression::<f64, u8>::with_params(Some(2), Some(weights));
        assert!(model.is_err());
    }

    #[test]
    fn test_h_function() {
        let mut model = LogisticRegression::<f64, u8>::with_params(Some(2), None).unwrap();

        // Set model weights to known values
        model.weights = DVector::from_vec(vec![0.0, 0.5, -0.5]);

        // Create features for testing
        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);

        // Expected sigmoid values for the given features and weights
        let expected_sigmoid_values = DVector::from_vec(vec![
            1.0 / (1.0 + f64::exp(0.5)), // Sigmoid(0.0*bias + 0.5*1.0 - 0.5*2.0)
            1.0 / (1.0 + f64::exp(0.5)), // Sigmoid(0.0*bias + 0.5*3.0 - 0.5*4.0)
        ]);
        let features_with_bias = features.clone().insert_column(0, 1.0);
        // Compute predictions using the `h` function
        let predictions = model.h(&features_with_bias);

        // Check if the computed predictions are close to the expected values
        for (predicted, expected) in predictions.iter().zip(expected_sigmoid_values.iter()) {
            assert!((predicted - expected).abs() < f64::EPSILON);
        }
    }

    // Test the prediction functionality
    #[test]
    fn test_predict() {
        let model = LogisticRegression::<f64, u8>::with_params(
            None,
            Some(DVector::from_vec(vec![0.0, 0.5, -0.5])),
        )
        .unwrap();

        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let predictions = model.predict(&features).unwrap();

        assert_eq!(predictions.len(), 2);
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }

    // Test the sigmoid function

    #[test]
    fn test_sigmoid_less_than_negative_ten() {
        let value = LogisticRegression::<f64, u8>::sigmoid(-10.1);
        assert_eq!(value, 0.0);
    }

    #[test]
    fn test_sigmoid_zero() {
        let value = LogisticRegression::<f64, u8>::sigmoid(0.0);
        assert!((value - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_sigmoid_one() {
        let value = LogisticRegression::<f64, u8>::sigmoid(1.0);
        assert!((value - 0.7310585786300049).abs() < f64::EPSILON);
    }

    #[test]
    fn test_sigmoid_over_ten() {
        let value = LogisticRegression::<f64, u8>::sigmoid(10.1);
        assert_eq!(value, 1.0);
    }

    #[test]
    fn test_h() {
        let model = LogisticRegression::<f64, u8>::with_params(
            None,
            Some(DVector::from_vec(vec![0.0, 0.5, -0.5])),
        )
        .unwrap();
        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 5.0]);
        let features_with_bias = features.clone().insert_column(0, 1.0);
        let value = model.h(&features_with_bias);

        assert!((value[0] - 0.3775406687981454).abs() < f64::EPSILON);
        assert!((value[1] - 0.2689414213699951).abs() < f64::EPSILON);
    }

    // Test cross-entropy calculation
    #[test]
    fn test_cross_entropy() {
        let model = LogisticRegression::<f64, u8>::with_params(
            None,
            Some(DVector::from_vec(vec![0.0, 0.5, -0.5])),
        )
        .unwrap();

        // Create features and labels for testing
        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let labels = DVector::from_vec(vec![1, 0]);

        // Compute cross-entropy loss
        let loss = model.cross_entropy(&features, &labels, true).unwrap();
        // Expected loss value
        let expected_loss = 0.7240769841801062;

        // Check if the computed loss is close to the expected value
        assert!((loss - expected_loss).abs() < f64::EPSILON);
    }

    #[test]
    fn test_gradient() {
        // Create a logistic regression model
        let model = LogisticRegression::new();

        // Create a test input matrix and labels
        let x = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let y = DVector::from_vec(vec![0, 1]);

        // Calculate the gradient
        let gradient = model.gradient(&x, &y);
        // Assert the expected gradient shape
        assert_eq!(gradient.shape(), (3, 1));
    }

    #[test]
    fn test_fit_with_progress_set_to_zero() {
        let mut model = LogisticRegression::<f64, u8>::new();

        // Create a dummy dataset
        let x = DMatrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let y = DVector::from_vec(vec![1, 2]);
        let dataset = Dataset::new(x, y);

        let lr = 0.1;
        let max_steps = 100;
        let epsilon = Some(0.0001);
        let progress = Some(0);

        let result = model.fit(&dataset, lr, max_steps, epsilon, progress);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "The number of steps for progress visualization must be greater than 0."
        );
    }

    #[test]
    fn test_fit() {
        let mut logistic_regression = LogisticRegression::<f64, u8>::new();
        let dataset = Dataset::new(
            DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]),
            DVector::from_vec(vec![0, 1]),
        );
        let result = logistic_regression.fit(&dataset, 0.1, 100, Some(1e-6), Some(50));
        assert!(result.is_ok());
    }
}