rust_ml/model/core/classification_model.rs
1use crate::bench::classification_metrics::ClassificationMetrics;
2use crate::core::error::ModelError;
3use crate::model::core::base::OptimizableModel;
4
5/// A trait for models that perform classification tasks.
6///
7/// This trait defines common evaluation metrics and functionality for classification models,
8/// allowing for standardized performance assessment across different implementations.
9///
10/// # Type Parameters
11/// * `Input`: The type of input data the model accepts
12/// * `Output`: The type of output data against which predictions are compared
13pub trait ClassificationModel<Input, Output>: OptimizableModel<Input, Output> {
14 /// Calculates the accuracy of the model on the given data.
15 ///
16 /// Accuracy is defined as the proportion of correct predictions among the total number of predictions.
17 ///
18 /// # Arguments
19 /// * `x` - The input data
20 /// * `y` - The expected output/ground truth
21 ///
22 /// # Returns
23 /// * `Result<f64, ModelError>` - The calculated accuracy score or an error
24 fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
25
26 /// Calculates the loss of the model on the given data.
27 ///
28 /// The specific loss function depends on the model implementation.
29 ///
30 /// # Arguments
31 /// * `x` - The input data
32 /// * `y` - The expected output/ground truth
33 ///
34 /// # Returns
35 /// * `Result<f64, ModelError>` - The calculated loss value or an error
36 fn loss(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
37
38 /// Calculates the recall score of the model on the given data.
39 ///
40 /// Recall (also known as sensitivity) measures the proportion of actual positives
41 /// that were correctly identified by the model.
42 ///
43 /// # Arguments
44 /// * `x` - The input data
45 /// * `y` - The expected output/ground truth
46 ///
47 /// # Returns
48 /// * `Result<f64, ModelError>` - The calculated recall score or an error
49 fn recall(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
50
51 /// Calculates the F1 score of the model on the given data.
52 ///
53 /// F1 score is the harmonic mean of precision and recall, providing a balance
54 /// between the two metrics.
55 ///
56 /// # Arguments
57 /// * `x` - The input data
58 /// * `y` - The expected output/ground truth
59 ///
60 /// # Returns
61 /// * `Result<f64, ModelError>` - The calculated F1 score or an error
62 fn f1_score(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
63
64 /// Computes multiple evaluation metrics for the model on the given data.
65 ///
66 /// This method allows for efficient calculation of multiple metrics in a single pass
67 /// through the data, potentially optimizing performance when multiple metrics are needed.
68 ///
69 /// # Arguments
70 /// * `x` - The input data
71 /// * `y` - The expected output/ground truth
72 ///
73 /// # Returns
74 /// * `Result<ModelParams, ModelError>` - A collection of calculated metrics or an error
75 fn compute_metrics(&self, x: &Input, y: &Output) -> Result<ClassificationMetrics, ModelError>;
76}