use std::{error::Error, marker::PhantomData};

use crate::{
    data::dataset::{Dataset, RealNumber, WholeNumber},
    metrics::confusion::ClassificationMetrics,
};
use nalgebra::{DMatrix, DVector};

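/// Binary logistic regression classifier trained with batch gradient descent.
///
/// `weights[0]` is the bias weight; the remaining entries correspond to the
/// feature columns. `YT` is the integer label type, tracked via `PhantomData`.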
#[derive(Clone, Debug)]
pub struct LogisticRegression<XT: RealNumber, YT: WholeNumber> {
    weights: DVector<XT>,
    _marker: PhantomData<YT>,
}

impl<XT: RealNumber, YT: WholeNumber> ClassificationMetrics<YT> for LogisticRegression<XT, YT> {}

impl<XT: RealNumber, YT: WholeNumber> Default for LogisticRegression<XT, YT> {
    fn default() -> Self {
        Self::new()
    }
}

impl<XT: RealNumber, YT: WholeNumber> LogisticRegression<XT, YT> {
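    /// Creates a model with the default capacity of two features: three
    /// weights (the bias plus two feature weights), all initialized to 1.0.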
    pub fn new() -> Self {
        Self {
            weights: DVector::<XT>::from_element(3, XT::from_f64(1.0).unwrap()),
            _marker: PhantomData,
        }
    }

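    /// Builds a model from a feature `dimension`, starting `weights`, or
    /// both. At least one must be provided; when both are given, the weight
    /// vector must be one entry longer than the dimension to hold the bias
    /// weight.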
    pub fn with_params(
        dimension: Option<usize>,
        weights: Option<DVector<XT>>,
    ) -> Result<Self, Box<dyn Error>> {
        match (dimension, &weights) {
            (None, None) => Err("Please input the dimension or starting weights.".into()),

            (Some(dim), Some(w)) if dim != w.len() - 1 => {
                Err("The weights should be longer by 1 than the dimension to account for the bias weight.".into())
            }
            _ => Ok(Self {
                weights: weights.unwrap_or_else(|| {
                    DVector::<XT>::from_element(dimension.unwrap() + 1, XT::from_f64(1.0).unwrap())
                }),
                _marker: PhantomData,
            }),
        }
    }

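    /// Predicts class labels (0 or 1) for each row of `x_pred`, thresholding
    /// the sigmoid output at 0.5. A bias column of ones is prepended
    /// internally, so `x_pred` should contain only the feature columns.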
    pub fn predict(&self, x_pred: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
        // Prepend a bias column of ones, matching the layout used in `fit`.
        let x_pred_with_bias = x_pred.clone().insert_column(0, XT::from_f64(1.0).unwrap());

        Ok(self.h(&x_pred_with_bias).map(|val| {
            if val > XT::from_f64(0.5).unwrap() {
                YT::from_usize(1).unwrap()
            } else {
                YT::from_usize(0).unwrap()
            }
        }))
    }

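    /// Fits the model with batch gradient descent on `dataset`.
    ///
    /// Training stops early once the squared Euclidean norm of the weight
    /// update drops below `epsilon` (default `1e-6`), otherwise after
    /// `max_steps` updates. If `progress` is `Some(n)`, the step count,
    /// weights, and cross-entropy are printed every `n` steps; `Some(0)` is
    /// rejected as an error.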
    pub fn fit(
        &mut self,
        dataset: &Dataset<XT, YT>,
        lr: XT,
        mut max_steps: usize,
        epsilon: Option<XT>,
        progress: Option<usize>,
    ) -> Result<String, Box<dyn Error>> {
        if progress.is_some_and(|steps| steps == 0) {
            return Err(
                "The number of steps for progress visualization must be greater than 0.".into(),
            );
        }
        let (x, y) = dataset.into_parts();

        let epsilon = epsilon.unwrap_or_else(|| XT::from_f64(1e-6).unwrap());
        let initial_max_steps = max_steps;
        let x_with_bias = x.clone().insert_column(0, XT::from_f64(1.0).unwrap());
        while max_steps > 0 {
            let weights_prev = self.weights.clone();

            let gradient = self.gradient(&x_with_bias, y);

            self.weights -= gradient * lr;

            if progress.is_some_and(|steps| max_steps % steps == 0) {
                println!("Step: {:?}", initial_max_steps - max_steps);
                println!("Weights: {:?}", self.weights);
                println!(
                    "Cross entropy: {:?}",
                    self.cross_entropy(&x_with_bias, y, false)
                );
            }

            // Squared Euclidean distance between consecutive weight vectors.
            let delta = self
                .weights
                .iter()
                .zip(weights_prev.iter())
                .map(|(&w, &w_prev)| (w - w_prev) * (w - w_prev))
                .fold(XT::from_f64(0.0).unwrap(), |acc, x| acc + x);

            if delta < epsilon {
                return Ok(format!(
                    "Finished training in {} steps.",
                    initial_max_steps - max_steps,
                ));
            }
            max_steps -= 1;
        }
        Ok("Reached maximum steps without converging.".into())
    }


    pub fn weights(&self) -> &DVector<XT> {
        &self.weights
    }

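    /// Computes the batch gradient of the mean cross-entropy loss:
    /// `x^T (h(x) - y) / n`, where `x` already includes the bias column.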
    fn gradient(&self, x: &DMatrix<XT>, y: &DVector<YT>) -> DVector<XT> {
        let y_pred = self.h(x);

        let y_xt_vec = y
            .iter()
            .map(|&y_i| XT::from(y_i).unwrap())
            .collect::<Vec<_>>();

        let y_xt = DVector::from_vec(y_xt_vec);
        let errors = y_pred - y_xt;

        x.transpose() * errors / XT::from_usize(y.len()).unwrap()
    }

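    /// Computes the mean binary cross-entropy loss over `x` and labels `y`.
    /// Set `testing` to `true` when `x` holds raw features without a bias
    /// column; the column is inserted before evaluating the hypothesis.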
    pub fn cross_entropy(
        &self,
        x: &DMatrix<XT>,
        y: &DVector<YT>,
        testing: bool,
    ) -> Result<XT, Box<dyn Error>> {
        // When called on raw feature data (`testing == true`), prepend a bias
        // column of ones, as `fit` does for the training data.
        let x = match testing {
            true => x.clone().insert_column(0, XT::from_f64(1.0).unwrap()),
            false => x.clone(),
        };
        let y_pred: DVector<XT> = self.h(&x);
        let one = XT::from_f64(1.0).unwrap();

        // Mean binary cross-entropy; EPSILON guards against ln(0).
        let cross_entropy = y
            .iter()
            .zip(y_pred.iter())
            .map(|(&y_i, &y_pred_i)| {
                let y_i_xt = XT::from(y_i).unwrap();
                -y_i_xt * (y_pred_i + XT::from_f64(f64::EPSILON).unwrap()).ln()
                    - (one - y_i_xt) * (one - y_pred_i + XT::from_f64(f64::EPSILON).unwrap()).ln()
            })
            .fold(XT::from_f64(0.0).unwrap(), |acc, x| acc + x)
            / XT::from_usize(y.len()).unwrap();

        Ok(cross_entropy)
    }

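    /// The hypothesis function: element-wise sigmoid of `x * weights`.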
    fn h(&self, x: &DMatrix<XT>) -> DVector<XT> {
        let z = x * &self.weights;
        z.map(Self::sigmoid)
    }

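    /// Numerically guarded sigmoid: saturates to 0 below -10 and to 1 above
    /// 10 instead of evaluating `exp` in the flat tails.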
    fn sigmoid(z: XT) -> XT {
        let one = XT::from_f64(1.0).unwrap();

        match z {
            z if z < XT::from_f64(-10.0).unwrap() => XT::from_f64(0.0).unwrap(),
            z if z > XT::from_f64(10.0).unwrap() => one,
            _ => one / (one + (-z).exp()),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let model = LogisticRegression::<f64, u8>::default();
        assert_eq!(model.weights().len(), 3);
        assert!(model.weights().iter().all(|&w| w == 1.0));
    }

    #[test]
    fn test_with_dimension() {
        let model = LogisticRegression::<f64, u8>::with_params(Some(3), None);
        assert!(model.is_ok());
        assert_eq!(model.as_ref().unwrap().weights().len(), 4);
        assert!(model.unwrap().weights().iter().all(|&w| w == 1.0));
    }

    #[test]
    fn test_with_weights() {
        let weights = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let model = LogisticRegression::<f64, u8>::with_params(None, Some(weights.clone()));
        assert!(model.is_ok());
        assert_eq!(model.unwrap().weights, weights);
    }

    #[test]
    fn test_with_params_nothing_provided() {
        let model = LogisticRegression::<f64, u8>::with_params(None, None);
        assert!(model.is_err());
    }

    #[test]
    fn test_dimension_and_weights_provided_correct() {
        let weights = DVector::from_vec(vec![0.5, -0.5, 1.0]);
        let model = LogisticRegression::<f64, u8>::with_params(Some(2), Some(weights.clone()));
        assert!(model.is_ok());
        assert_eq!(model.unwrap().weights, weights);
    }

    #[test]
    fn test_dimension_and_weights_provided_incorrect() {
        let weights = DVector::from_vec(vec![0.5, -0.5]);
        let model = LogisticRegression::<f64, u8>::with_params(Some(2), Some(weights));
        assert!(model.is_err());
    }

    #[test]
    fn test_h_function() {
        let mut model = LogisticRegression::<f64, u8>::with_params(Some(2), None).unwrap();

        model.weights = DVector::from_vec(vec![0.0, 0.5, -0.5]);

        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);

        // Both rows give z = -0.5, so both outputs are 1 / (1 + e^0.5).
        let expected_sigmoid_values = DVector::from_vec(vec![
            1.0 / (1.0 + f64::exp(0.5)),
            1.0 / (1.0 + f64::exp(0.5)),
        ]);
        let features_with_bias = features.clone().insert_column(0, 1.0);
        let predictions = model.h(&features_with_bias);

        for (predicted, expected) in predictions.iter().zip(expected_sigmoid_values.iter()) {
            assert!((predicted - expected).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_predict() {
        let model = LogisticRegression::<f64, u8>::with_params(
            None,
            Some(DVector::from_vec(vec![0.0, 0.5, -0.5])),
        )
        .unwrap();

        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let predictions = model.predict(&features).unwrap();

        assert_eq!(predictions.len(), 2);
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }
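
    // Sanity check: `predict` prepends a bias column of ones, so with a
    // nonzero bias weight each row below gives z = 2.0 + 0.5 * x1 - 0.5 * x2
    // = 1.5 and sigmoid(1.5) > 0.5. If the bias column were zeros instead,
    // z would be -0.5 and both rows would fall back to class 0.
    #[test]
    fn test_predict_uses_bias_column_of_ones() {
        let model = LogisticRegression::<f64, u8>::with_params(
            None,
            Some(DVector::from_vec(vec![2.0, 0.5, -0.5])),
        )
        .unwrap();

        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let predictions = model.predict(&features).unwrap();

        assert!(predictions.iter().all(|&p| p == 1));
    }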

    #[test]
    fn test_sigmoid_less_than_negative_ten() {
        let value = LogisticRegression::<f64, u8>::sigmoid(-10.1);
        assert_eq!(value, 0.0);
    }

    #[test]
    fn test_sigmoid_zero() {
        let value = LogisticRegression::<f64, u8>::sigmoid(0.0);
        assert!((value - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_sigmoid_one() {
        let value = LogisticRegression::<f64, u8>::sigmoid(1.0);
        assert!((value - 0.7310585786300049).abs() < f64::EPSILON);
    }

    #[test]
    fn test_sigmoid_over_ten() {
        let value = LogisticRegression::<f64, u8>::sigmoid(10.1);
        assert_eq!(value, 1.0);
    }

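    #[test]
    fn test_sigmoid_symmetry() {
        // sigmoid(-z) == 1 - sigmoid(z) should hold up to rounding error.
        let pos = LogisticRegression::<f64, u8>::sigmoid(1.0);
        let neg = LogisticRegression::<f64, u8>::sigmoid(-1.0);
        assert!((pos + neg - 1.0).abs() < 1e-12);
    }
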
    #[test]
    fn test_h() {
        let model = LogisticRegression::<f64, u8>::with_params(
            None,
            Some(DVector::from_vec(vec![0.0, 0.5, -0.5])),
        )
        .unwrap();
        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 5.0]);
        let features_with_bias = features.clone().insert_column(0, 1.0);
        let value = model.h(&features_with_bias);

        assert!((value[0] - 0.3775406687981454).abs() < f64::EPSILON);
        assert!((value[1] - 0.2689414213699951).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cross_entropy() {
        let model = LogisticRegression::<f64, u8>::with_params(
            None,
            Some(DVector::from_vec(vec![0.0, 0.5, -0.5])),
        )
        .unwrap();

        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let labels = DVector::from_vec(vec![1, 0]);

        let loss = model.cross_entropy(&features, &labels, true).unwrap();
        let expected_loss = 0.7240769841801062;

        assert!((loss - expected_loss).abs() < f64::EPSILON);
    }

    #[test]
    fn test_gradient() {
        let model = LogisticRegression::new();

        let x = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let y = DVector::from_vec(vec![0, 1]);

        let gradient = model.gradient(&x, &y);
        assert_eq!(gradient.shape(), (3, 1));
    }

    #[test]
    fn test_fit_with_progress_set_to_zero() {
        let mut model = LogisticRegression::<f64, u8>::new();

        let x = DMatrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let y = DVector::from_vec(vec![1, 2]);
        let dataset = Dataset::new(x, y);

        let lr = 0.1;
        let max_steps = 100;
        let epsilon = Some(0.0001);
        let progress = Some(0);

        let result = model.fit(&dataset, lr, max_steps, epsilon, progress);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "The number of steps for progress visualization must be greater than 0."
        );
    }

    #[test]
    fn test_fit() {
        let mut logistic_regression = LogisticRegression::<f64, u8>::new();
        let dataset = Dataset::new(
            DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]),
            DVector::from_vec(vec![0, 1]),
        );
        let result = logistic_regression.fit(&dataset, 0.1, 100, Some(1e-6), Some(50));
        assert!(result.is_ok());
    }
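
    // End-to-end sketch: fit on a small two-class dataset, then predict on
    // the same features; a smoke test of the full pipeline rather than an
    // accuracy claim.
    #[test]
    fn test_fit_then_predict() {
        let mut model = LogisticRegression::<f64, u8>::new();
        let features = DMatrix::from_row_slice(4, 2, &[0.0, 0.1, 0.2, 0.0, 5.0, 5.1, 4.9, 5.0]);
        let dataset = Dataset::new(features.clone(), DVector::from_vec(vec![0, 0, 1, 1]));

        model.fit(&dataset, 0.1, 200, Some(1e-6), None).unwrap();

        let predictions = model.predict(&features).unwrap();
        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }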
}