

Trait ClassificationModel 

pub trait ClassificationModel<Input, Output>: OptimizableModel<Input, Output> {
    // Required methods
    fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn loss(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn recall(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn f1_score(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
    fn compute_metrics(
        &self,
        x: &Input,
        y: &Output,
    ) -> Result<ClassificationMetrics, ModelError>;
}

A trait for models that perform classification tasks.

This trait defines a common set of evaluation metrics for classification models, so that different implementations can be assessed in the same way. Because it extends OptimizableModel<Input, Output>, every implementor must also implement that trait.

Type Parameters

  • Input: The type of input data the model accepts
  • Output: The type of output data against which predictions are compared
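The trait fixes only method signatures, so it may help to see how a type could satisfy a trait of this shape. The following is a minimal, self-contained sketch rather than the crate's API. To run without the crate, it assumes simplified stand-ins: a hypothetical ClassifierSketch trait that keeps the Input/Output type parameters but has only an accuracy method and no OptimizableModel supertrait, a hypothetical SketchError in place of ModelError, and a toy ThresholdClassifier.

```rust
/// Hypothetical stand-in for the crate's ModelError.
#[derive(Debug, PartialEq)]
pub enum SketchError {
    /// Inputs and labels have different lengths.
    LengthMismatch { inputs: usize, labels: usize },
    /// No samples were supplied.
    Empty,
}

/// Hypothetical, reduced mirror of ClassificationModel: same generic
/// Input/Output parameters, but a single metric and no supertrait.
pub trait ClassifierSketch<Input, Output> {
    fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, SketchError>;
}

/// Toy one-feature model: predicts 1.0 when a value exceeds `threshold`, else 0.0.
pub struct ThresholdClassifier {
    pub threshold: f64,
}

impl ThresholdClassifier {
    fn predict(&self, x: &[f64]) -> Vec<f64> {
        x.iter()
            .map(|&v| if v > self.threshold { 1.0 } else { 0.0 })
            .collect()
    }
}

// Here Input is one feature value per sample and Output is a 0.0/1.0 label per sample.
impl ClassifierSketch<Vec<f64>, Vec<f64>> for ThresholdClassifier {
    fn accuracy(&self, x: &Vec<f64>, y: &Vec<f64>) -> Result<f64, SketchError> {
        if x.len() != y.len() {
            return Err(SketchError::LengthMismatch { inputs: x.len(), labels: y.len() });
        }
        if x.is_empty() {
            return Err(SketchError::Empty);
        }
        // Predict once, then count positions where prediction and label agree.
        let predictions = self.predict(x);
        let correct = predictions.iter().zip(y).filter(|(p, t)| p == t).count();
        Ok(correct as f64 / y.len() as f64)
    }
}
```

Choosing concrete Input and Output types in the impl header is the same pattern the LogisticRegression implementor uses with ndarray types.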

Required Methods


fn accuracy(&self, x: &Input, y: &Output) -> Result<f64, ModelError>

Calculates the accuracy of the model on the given data.

Accuracy is defined as the proportion of correct predictions among the total number of predictions.

§Arguments
  • x - The input data
  • y - The expected output/ground truth
§Returns
  • Result<f64, ModelError> - The calculated accuracy score or an error
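Accuracy reduces to counting matching positions once the model's outputs have been turned into class labels. The following sketch assumes labels are encoded as f64, matching the LogisticRegression implementor, whose Output is a one-dimensional f64 array. The function name accuracy_from_labels is hypothetical, and returning None for empty or mismatched inputs stands in for the ModelError the real method would return.

```rust
/// Fraction of positions where the predicted label equals the true label.
/// Returns None for empty or length-mismatched inputs (an assumption made
/// for this sketch; the trait reports such cases through ModelError).
pub fn accuracy_from_labels(predicted: &[f64], actual: &[f64]) -> Option<f64> {
    if predicted.is_empty() || predicted.len() != actual.len() {
        return None;
    }
    let correct = predicted
        .iter()
        .zip(actual)
        .filter(|(p, a)| p == a)
        .count();
    Some(correct as f64 / actual.len() as f64)
}
```

For example, predictions [1, 0, 1, 1] against labels [1, 0, 0, 1] match in three of four positions, which gives 0.75.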

fn loss(&self, x: &Input, y: &Output) -> Result<f64, ModelError>

Calculates the loss of the model on the given data.

The specific loss function depends on the model implementation.

§Arguments
  • x - The input data
  • y - The expected output/ground truth
§Returns
  • Result<f64, ModelError> - The calculated loss value or an error
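The trait leaves the loss function to each implementation. As one illustration, not necessarily what any implementor uses, binary classifiers that output probabilities are commonly scored with mean binary cross-entropy. The sketch assumes the inputs are predicted probabilities of the positive class and labels encoded as 0.0/1.0. The function name and the 1e-15 clamping constant are assumptions.

```rust
/// Mean binary cross-entropy: -(1/n) * sum(y * ln(p) + (1 - y) * ln(1 - p)).
/// `probs` are predicted probabilities of the positive class; `labels` are 0.0 or 1.0.
/// Returns None for empty or length-mismatched inputs.
pub fn binary_cross_entropy(probs: &[f64], labels: &[f64]) -> Option<f64> {
    if probs.is_empty() || probs.len() != labels.len() {
        return None;
    }
    // Clamp probabilities away from 0 and 1 so ln() never receives 0.
    const EPS: f64 = 1e-15;
    let total: f64 = probs
        .iter()
        .zip(labels)
        .map(|(&p, &y)| {
            let p = p.clamp(EPS, 1.0 - EPS);
            -(y * p.ln() + (1.0 - y) * (1.0 - p).ln())
        })
        .sum();
    Some(total / probs.len() as f64)
}
```

A model that outputs 0.5 for every sample incurs a loss of ln 2 (about 0.693) regardless of the labels. Confident correct predictions drive the loss toward zero.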

fn recall(&self, x: &Input, y: &Output) -> Result<f64, ModelError>

Calculates the recall score of the model on the given data.

Recall (also known as sensitivity or true positive rate) measures the proportion of actual positives that the model correctly identifies: TP / (TP + FN), where TP counts true positives and FN counts false negatives.

§Arguments
  • x - The input data
  • y - The expected output/ground truth
§Returns
  • Result<f64, ModelError> - The calculated recall score or an error
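The recall definition translates directly into two counters over precomputed labels. The following is a minimal sketch. It assumes f64-encoded labels and a caller-chosen positive label (the hypothetical positive parameter). It returns None when there are no actual positives, since recall is undefined in that case.

```rust
/// Recall = TP / (TP + FN), treating `positive` as the positive class label.
/// Returns None when the slices differ in length or contain no actual positives.
pub fn recall(predicted: &[f64], actual: &[f64], positive: f64) -> Option<f64> {
    if predicted.len() != actual.len() {
        return None;
    }
    let mut tp = 0usize; // actual positive, predicted positive
    let mut fn_count = 0usize; // actual positive, predicted negative
    for (&p, &a) in predicted.iter().zip(actual) {
        if a == positive {
            if p == positive {
                tp += 1;
            } else {
                fn_count += 1;
            }
        }
    }
    let actual_positives = tp + fn_count;
    if actual_positives == 0 {
        None
    } else {
        Some(tp as f64 / actual_positives as f64)
    }
}
```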

fn f1_score(&self, x: &Input, y: &Output) -> Result<f64, ModelError>

Calculates the F1 score of the model on the given data.

The F1 score is the harmonic mean of precision and recall, 2 · precision · recall / (precision + recall), so it is high only when both metrics are high.

§Arguments
  • x - The input data
  • y - The expected output/ground truth
§Returns
  • Result<f64, ModelError> - The calculated F1 score or an error
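The trait exposes no separate precision method, so an F1 computation must derive precision itself from the confusion counts. The following is a minimal sketch that assumes f64-encoded binary labels and a hypothetical positive parameter. It returns 0.0 when there are errors but no true positives, and None when the data contains no positives at all. Both are common conventions, but they are assumptions here.

```rust
/// Counts (true positives, false positives, false negatives) for one class.
fn confusion_counts(predicted: &[f64], actual: &[f64], positive: f64) -> (usize, usize, usize) {
    let (mut tp, mut fp, mut fn_count) = (0, 0, 0);
    for (&p, &a) in predicted.iter().zip(actual) {
        match (p == positive, a == positive) {
            (true, true) => tp += 1,
            (true, false) => fp += 1,
            (false, true) => fn_count += 1,
            (false, false) => {}
        }
    }
    (tp, fp, fn_count)
}

/// F1 as the harmonic mean of precision and recall for the `positive` class.
pub fn f1_score(predicted: &[f64], actual: &[f64], positive: f64) -> Option<f64> {
    if predicted.len() != actual.len() {
        return None;
    }
    let (tp, fp, fn_count) = confusion_counts(predicted, actual, positive);
    if tp == 0 {
        // Precision or recall is zero or undefined: report 0.0 if the model made
        // any positive-class errors, None if no positives appear at all.
        return if fp + fn_count == 0 { None } else { Some(0.0) };
    }
    // With tp > 0 both denominators are nonzero.
    let precision = tp as f64 / (tp + fp) as f64;
    let recall = tp as f64 / (tp + fn_count) as f64;
    Some(2.0 * precision * recall / (precision + recall))
}
```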

fn compute_metrics(&self, x: &Input, y: &Output) -> Result<ClassificationMetrics, ModelError>

Computes multiple evaluation metrics for the model on the given data.

This method can compute several metrics in a single pass through the data, which may be cheaper than calling accuracy, loss, recall and f1_score separately when more than one metric is needed.

§Arguments
  • x - The input data
  • y - The expected output/ground truth
§Returns
  • Result<ModelParams, ModelError> - A collection of calculated metrics or an error
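One way to implement a single pass is to fill a binary confusion matrix once and then derive every metric from its four counts. The following sketch assumes this design. MetricsSketch is a hypothetical struct, because this page does not list the fields of the crate's ClassificationMetrics. F1 is computed as 2TP / (2TP + FP + FN), which equals the harmonic mean of precision and recall whenever both are nonzero.

```rust
/// Hypothetical metrics container (not the crate's ClassificationMetrics).
#[derive(Debug, PartialEq)]
pub struct MetricsSketch {
    pub accuracy: f64,
    pub precision: Option<f64>,
    pub recall: Option<f64>,
    pub f1: Option<f64>,
}

/// Walks the (predicted, actual) pairs once to build a binary confusion
/// matrix, then derives all metrics from TP, FP, FN and TN.
pub fn compute_metrics_sketch(predicted: &[f64], actual: &[f64], positive: f64) -> Option<MetricsSketch> {
    if predicted.is_empty() || predicted.len() != actual.len() {
        return None;
    }
    let (mut tp, mut fp, mut fn_count, mut tn) = (0usize, 0usize, 0usize, 0usize);
    for (&p, &a) in predicted.iter().zip(actual) {
        match (p == positive, a == positive) {
            (true, true) => tp += 1,
            (true, false) => fp += 1,
            (false, true) => fn_count += 1,
            (false, false) => tn += 1,
        }
    }
    // Ratio that is None when its denominator is zero (metric undefined).
    let ratio = |num: usize, den: usize| -> Option<f64> {
        if den == 0 { None } else { Some(num as f64 / den as f64) }
    };
    Some(MetricsSketch {
        accuracy: (tp + tn) as f64 / predicted.len() as f64,
        precision: ratio(tp, tp + fp),
        recall: ratio(tp, tp + fn_count),
        f1: ratio(2 * tp, 2 * tp + fp + fn_count),
    })
}
```

The per-sample work happens only in the loop, so adding another count-based metric costs one extra division rather than another pass over the data.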

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementors


impl ClassificationModel<ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>>> for LogisticRegression

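Implementation of the ClassificationModel trait for LogisticRegression, with Input = Array2<f64> (a two-dimensional ndarray array of features) and Output = Array1<f64> (a one-dimensional array of targets).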