rust_ml/model/core/classification_model.rs

use crate::bench::classification_metrics::ClassificationMetrics;
use crate::core::error::ModelError;
use crate::model::core::base::OptimizableModel;

/// A trait for models that perform classification tasks.
///
/// This trait defines common evaluation metrics and functionality for classification models,
/// allowing for standardized performance assessment across different implementations.
///
/// # Type Parameters
/// * `Input`: The type of input data the model accepts
/// * `Output`: The type of output data against which predictions are compared
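///
/// # Examples
///
/// A hypothetical usage sketch; `model`, `x_test`, and `y_test` are illustrative
/// names rather than items from this crate, hence the `ignore` fence:
///
/// ```ignore
/// let accuracy = model.accuracy(&x_test, &y_test)?;
/// let f1 = model.f1_score(&x_test, &y_test)?;
/// ```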
pub trait ClassificationModel<Input, Output>: OptimizableModel<Input, Output> {
    /// Calculates the accuracy of the model on the given data.
    ///
    /// Accuracy is the proportion of predictions that match the ground truth.
    ///
    /// # Arguments
    /// * `x` - The input data
    /// * `y` - The expected output/ground truth
    ///
    /// # Returns
    /// * `Result<f64, ModelError>` - The calculated accuracy score or an error
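    ///
    /// # Examples
    ///
    /// A worked example of the definition itself, on hard labels and independent of
    /// any concrete model:
    ///
    /// ```
    /// let y_true = [1, 0, 1, 1];
    /// let y_pred = [1, 0, 0, 1];
    /// let correct = y_true.iter().zip(y_pred.iter()).filter(|(t, p)| t == p).count();
    /// // 3 of the 4 predictions match the ground truth:
    /// assert_eq!(correct as f64 / y_true.len() as f64, 0.75);
    /// ```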
    fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;

    /// Calculates the loss of the model on the given data.
    ///
    /// The specific loss function depends on the model implementation.
    ///
    /// # Arguments
    /// * `x` - The input data
    /// * `y` - The expected output/ground truth
    ///
    /// # Returns
    /// * `Result<f64, ModelError>` - The calculated loss value or an error
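    ///
    /// # Examples
    ///
    /// Binary cross-entropy is one common choice; it is shown here only to illustrate
    /// what a loss value looks like, and implementors of this trait may use any loss:
    ///
    /// ```
    /// // Loss for a single positive example predicted with probability 0.8:
    /// let (y, p) = (1.0_f64, 0.8_f64);
    /// let bce = -(y * p.ln() + (1.0 - y) * (1.0 - p).ln());
    /// assert!((bce - 0.2231).abs() < 1e-3);
    /// ```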
    fn loss(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;

    /// Calculates the recall score of the model on the given data.
    ///
    /// Recall (also known as sensitivity) measures the proportion of actual positives
    /// that were correctly identified by the model.
    ///
    /// # Arguments
    /// * `x` - The input data
    /// * `y` - The expected output/ground truth
    ///
    /// # Returns
    /// * `Result<f64, ModelError>` - The calculated recall score or an error
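    ///
    /// # Examples
    ///
    /// Recall is `TP / (TP + FN)`; a worked example of the definition on hard labels,
    /// independent of any concrete model:
    ///
    /// ```
    /// let y_true = [1, 1, 0, 1];
    /// let y_pred = [1, 0, 0, 1];
    /// // 2 true positives and 1 false negative among the 3 actual positives:
    /// let tp = y_true.iter().zip(y_pred.iter()).filter(|(t, p)| **t == 1 && **p == 1).count();
    /// let fn_ = y_true.iter().zip(y_pred.iter()).filter(|(t, p)| **t == 1 && **p == 0).count();
    /// assert_eq!(tp as f64 / (tp + fn_) as f64, 2.0 / 3.0);
    /// ```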
    fn recall(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;

    /// Calculates the F1 score of the model on the given data.
    ///
    /// F1 score is the harmonic mean of precision and recall, providing a balance
    /// between the two metrics.
    ///
    /// # Arguments
    /// * `x` - The input data
    /// * `y` - The expected output/ground truth
    ///
    /// # Returns
    /// * `Result<f64, ModelError>` - The calculated F1 score or an error
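    ///
    /// # Examples
    ///
    /// A worked example of the formula `2 * p * r / (p + r)`; the harmonic mean sits
    /// below the arithmetic mean whenever precision and recall differ:
    ///
    /// ```
    /// let (precision, recall) = (0.5_f64, 1.0_f64);
    /// let f1 = 2.0 * precision * recall / (precision + recall);
    /// // The arithmetic mean would be 0.75; the harmonic mean is lower:
    /// assert!((f1 - 2.0 / 3.0).abs() < 1e-12);
    /// ```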
    fn f1_score(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;

    /// Computes multiple evaluation metrics for the model on the given data.
    ///
    /// This method calculates several metrics in a single pass over the data, allowing
    /// implementations to avoid redundant prediction work when more than one metric is
    /// needed.
    ///
    /// # Arguments
    /// * `x` - The input data
    /// * `y` - The expected output/ground truth
    ///
    /// # Returns
    /// * `Result<ClassificationMetrics, ModelError>` - A collection of calculated metrics or an error
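    ///
    /// # Examples
    ///
    /// A hypothetical usage sketch; `model`, `x_test`, and `y_test` are illustrative
    /// names, hence the `ignore` fence:
    ///
    /// ```ignore
    /// // One pass instead of separate accuracy/recall/f1 calls:
    /// let metrics = model.compute_metrics(&x_test, &y_test)?;
    /// ```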
    fn compute_metrics(&self, x: &Input, y: &Output) -> Result<ClassificationMetrics, ModelError>;
}