pub trait ClassificationModel<Input, Output>: OptimizableModel<Input, Output> {
    // Required methods
    fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn loss(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn recall(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn f1_score(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn compute_metrics(
        &self,
        x: &Input,
        y: &Output,
    ) -> Result<ClassificationMetrics, ModelError>;
}
A trait for models that perform classification tasks.
This trait defines common evaluation metrics and functionality for classification models, allowing for standardized performance assessment across different implementations.
Type Parameters
- Input: The type of input data the model accepts
- Output: The type of output data against which predictions are compared
Required Methods
fn recall(&self, x: &Input, y: &Output) -> Result<f64, ModelError>
Calculates the recall score of the model on the given data.
Recall (also known as sensitivity) measures the proportion of actual positives that the model correctly identifies, i.e. TP / (TP + FN), where TP is the number of true positives and FN the number of false negatives.
Arguments
- x: The input data
- y: The expected output (ground truth)
Returns
- Result<f64, ModelError>: The calculated recall score, or an error
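The definition TP / (TP + FN) can be illustrated with a standalone helper that does not depend on the trait. This is a minimal sketch under stated assumptions. The function `recall` here is hypothetical and works on boolean slices (true = positive). Returning `None` when there are no actual positives is a convention chosen for this sketch. A real implementation of the trait method might report a `ModelError` instead.

```rust
/// Hypothetical helper: recall = TP / (TP + FN), the fraction of actual
/// positives that were predicted positive.
/// Returns None when there are no actual positives (recall is undefined then).
/// Panics if the two slices have different lengths.
fn recall(predicted: &[bool], actual: &[bool]) -> Option<f64> {
    assert_eq!(predicted.len(), actual.len(), "length mismatch");
    let mut tp = 0usize; // actual positive, predicted positive
    let mut fn_ = 0usize; // actual positive, predicted negative
    for (&p, &a) in predicted.iter().zip(actual) {
        if a {
            if p {
                tp += 1;
            } else {
                fn_ += 1;
            }
        }
    }
    if tp + fn_ == 0 {
        None
    } else {
        Some(tp as f64 / (tp + fn_) as f64)
    }
}
```

Only samples whose ground truth is positive affect the result. False positives do not lower recall; they are penalized by precision and by the F1 score instead.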
fn compute_metrics(&self, x: &Input, y: &Output) -> Result<ClassificationMetrics, ModelError>
Computes several evaluation metrics for the model on the given data.
Implementations can calculate all of the metrics in a single pass through the data. When more than one metric is needed, this may be cheaper than calling accuracy, recall, and f1_score separately.
Arguments
- x: The input data
- y: The expected output (ground truth)
Returns
- Result<ClassificationMetrics, ModelError>: The calculated metrics, or an error
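The following sketch shows one way an implementor might satisfy the trait with concrete type parameters and derive every metric from a single pass over the data. It makes the following assumptions:

- Input = Vec<f64> (one feature per sample) and Output = Vec<u8> (0/1 labels).
- ModelError, the fields of ClassificationMetrics, and the empty OptimizableModel are hypothetical stand-ins for crate types not shown on this page.
- ThresholdModel is a hypothetical toy model.
- loss is taken to be the 0-1 loss.

```rust
/// Hypothetical stand-in for the crate's error type.
#[derive(Debug)]
pub struct ModelError(pub String);

/// Hypothetical stand-in; the real ClassificationMetrics fields may differ.
#[derive(Debug)]
pub struct ClassificationMetrics {
    pub accuracy: f64,
    pub loss: f64,
    pub recall: f64,
    pub f1_score: f64,
}

/// Empty stand-in; the real OptimizableModel has its own required items.
pub trait OptimizableModel<Input, Output> {}

/// Signature as documented on this page.
pub trait ClassificationModel<Input, Output>: OptimizableModel<Input, Output> {
    fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn loss(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn recall(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn f1_score(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn compute_metrics(&self, x: &Input, y: &Output)
        -> Result<ClassificationMetrics, ModelError>;
}

/// Hypothetical toy model: predicts class 1 when the feature exceeds `threshold`.
pub struct ThresholdModel {
    pub threshold: f64,
}

/// Confusion-matrix counts for binary labels (1 = positive, 0 = negative).
struct Counts { tp: f64, fp: f64, tn: f64, fn_: f64 }

/// num / den, or 0.0 when den is zero (a convention chosen for this sketch).
fn ratio(num: f64, den: f64) -> f64 {
    if den == 0.0 { 0.0 } else { num / den }
}

impl Counts {
    fn accuracy(&self) -> f64 { ratio(self.tp + self.tn, self.tp + self.fp + self.tn + self.fn_) }
    fn recall(&self) -> f64 { ratio(self.tp, self.tp + self.fn_) }
    // F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall.
    fn f1(&self) -> f64 { ratio(2.0 * self.tp, 2.0 * self.tp + self.fp + self.fn_) }
}

impl ThresholdModel {
    /// One pass over the data: predict each sample and tally the outcome.
    fn counts(&self, x: &[f64], y: &[u8]) -> Result<Counts, ModelError> {
        if x.is_empty() || x.len() != y.len() {
            return Err(ModelError("x and y must be non-empty and of equal length".into()));
        }
        let mut c = Counts { tp: 0.0, fp: 0.0, tn: 0.0, fn_: 0.0 };
        for (&xi, &yi) in x.iter().zip(y) {
            match (xi > self.threshold, yi == 1) {
                (true, true) => c.tp += 1.0,
                (true, false) => c.fp += 1.0,
                (false, false) => c.tn += 1.0,
                (false, true) => c.fn_ += 1.0,
            }
        }
        Ok(c)
    }
}

impl OptimizableModel<Vec<f64>, Vec<u8>> for ThresholdModel {}

impl ClassificationModel<Vec<f64>, Vec<u8>> for ThresholdModel {
    fn accuracy(&self, x: &Vec<f64>, y: &Vec<u8>) -> Result<f64, ModelError> {
        Ok(self.counts(x, y)?.accuracy())
    }
    // Assumption: 0-1 loss (fraction misclassified); a real model might use cross-entropy.
    fn loss(&self, x: &Vec<f64>, y: &Vec<u8>) -> Result<f64, ModelError> {
        Ok(1.0 - self.counts(x, y)?.accuracy())
    }
    fn recall(&self, x: &Vec<f64>, y: &Vec<u8>) -> Result<f64, ModelError> {
        Ok(self.counts(x, y)?.recall())
    }
    fn f1_score(&self, x: &Vec<f64>, y: &Vec<u8>) -> Result<f64, ModelError> {
        Ok(self.counts(x, y)?.f1())
    }
    // The counts are built once and every metric is derived from them.
    fn compute_metrics(&self, x: &Vec<f64>, y: &Vec<u8>)
        -> Result<ClassificationMetrics, ModelError> {
        let c = self.counts(x, y)?;
        Ok(ClassificationMetrics {
            accuracy: c.accuracy(),
            loss: 1.0 - c.accuracy(),
            recall: c.recall(),
            f1_score: c.f1(),
        })
    }
}
```

In this sketch, calling accuracy, recall, and f1_score separately scans the data three times. A single call to compute_metrics scans it once.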
Dyn Compatibility
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".
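Because the trait cannot be used as a `dyn` trait object, code that evaluates several models can rely on generic bounds (static dispatch) instead. To mix different model types in one collection, it can wrap them in an enum that implements the trait. The sketch below illustrates both patterns under stated assumptions:

- ModelError is a hypothetical stand-in.
- The trait is reduced to its accuracy method for brevity.
- AlwaysPositive, Threshold, AnyModel, and evaluate_all are hypothetical names.

```rust
/// Hypothetical stand-in for the crate's error type.
#[derive(Debug)]
pub struct ModelError(pub String);

/// Reduced stand-in for ClassificationModel: only `accuracy` is kept here.
pub trait ClassificationModel<Input, Output> {
    fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
}

/// Toy model that always predicts the positive class.
pub struct AlwaysPositive;

/// Toy model that predicts positive when the feature exceeds the threshold.
pub struct Threshold(pub f64);

impl ClassificationModel<Vec<f64>, Vec<bool>> for AlwaysPositive {
    fn accuracy(&self, _x: &Vec<f64>, y: &Vec<bool>) -> Result<f64, ModelError> {
        if y.is_empty() {
            return Err(ModelError("empty labels".into()));
        }
        // Every positive label is a correct prediction.
        Ok(y.iter().filter(|&&label| label).count() as f64 / y.len() as f64)
    }
}

impl ClassificationModel<Vec<f64>, Vec<bool>> for Threshold {
    fn accuracy(&self, x: &Vec<f64>, y: &Vec<bool>) -> Result<f64, ModelError> {
        if x.is_empty() || x.len() != y.len() {
            return Err(ModelError("x and y must be non-empty and of equal length".into()));
        }
        let correct = x.iter().zip(y).filter(|&(&xi, &yi)| (xi > self.0) == yi).count();
        Ok(correct as f64 / y.len() as f64)
    }
}

/// Enum over a closed set of model types: one collection, no trait object.
pub enum AnyModel {
    Always(AlwaysPositive),
    Thresh(Threshold),
}

impl ClassificationModel<Vec<f64>, Vec<bool>> for AnyModel {
    fn accuracy(&self, x: &Vec<f64>, y: &Vec<bool>) -> Result<f64, ModelError> {
        match self {
            AnyModel::Always(m) => m.accuracy(x, y),
            AnyModel::Thresh(m) => m.accuracy(x, y),
        }
    }
}

/// Generic, statically dispatched evaluation over any model type M.
pub fn evaluate_all<M, I, O>(models: &[M], x: &I, y: &O) -> Result<Vec<f64>, ModelError>
where
    M: ClassificationModel<I, O>,
{
    models.iter().map(|m| m.accuracy(x, y)).collect()
}
```

The enum trades open extensibility for a fixed set of variants. Adding a new model type means adding a variant and a match arm.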