rust_ml/bench/
classification_profiler.rs

1use crate::bench::classification_metrics::ClassificationMetrics;
2use crate::bench::core::error::ProfilerError;
3use crate::bench::core::profiler::Profiler;
4use crate::bench::core::train_metrics::TrainMetrics;
5use crate::model::core::base::BaseModel;
6use crate::model::core::classification_model::ClassificationModel;
7use crate::optim::core::optimizer::Optimizer;
8use std::marker::PhantomData;
9use std::time::Instant;
10
/// A stateless profiler for classification models.
///
/// The type stores no data of its own: its only field is a zero-sized [`PhantomData`]
/// marker that ties an instance to the model, optimizer, input and output types it is
/// used with. Its [`Profiler`] implementation times the optimizer's fit call and reports
/// [`ClassificationMetrics`] for the model.
///
/// # Type Parameters
///
/// * `Model` - The classification model type being profiled
/// * `Opt` - The optimizer type used for training
/// * `Input` - The input data type (features)
/// * `Output` - The output data type (labels/classes)
pub struct ClassificationProfiler<Model, Opt, Input, Output> {
    // `fn() -> T` instead of `T`: the profiler never owns values of these types, so the
    // marker must not make `Send`/`Sync` (or drop-check ownership) depend on them.
    // Variance over the parameters is unchanged (still covariant).
    _phantom: PhantomData<fn() -> (Model, Opt, Input, Output)>,
}
26
27impl<Model, Opt, Input, Output> ClassificationProfiler<Model, Opt, Input, Output> {
28    /// Creates a new ClassificationProfiler instance.
29    ///
30    /// # Returns
31    ///
32    /// A new ClassificationProfiler instance.
33    pub fn new() -> Self {
34        Self {
35            _phantom: PhantomData,
36        }
37    }
38}
39
40impl<Model, Opt, Input, Output> Default for ClassificationProfiler<Model, Opt, Input, Output> {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl<Model, Opt, Input, Output> Profiler<Model, Opt, Input, Output>
47    for ClassificationProfiler<Model, Opt, Input, Output>
48where
49    Model: BaseModel<Input, Output> + ClassificationModel<Input, Output>,
50    Opt: Optimizer<Input, Output, Model>,
51{
52    type EvalMetrics = ClassificationMetrics;
53
54    /// Profiles the training process of a classification model.
55    ///
56    /// Measures the time taken to train the model and computes classification metrics
57    /// on the provided data.
58    ///
59    /// # Arguments
60    ///
61    /// * `model` - Mutable reference to the classification model being trained
62    /// * `optimizer` - Mutable reference to the optimizer used for training
63    /// * `x` - Reference to input features
64    /// * `y` - Reference to target labels
65    ///
66    /// # Returns
67    ///
68    /// A tuple containing training metrics (including training time) and classification metrics
69    /// (accuracy, precision, recall, F1 score), or a ProfilerError if an error occurs.
70    fn train(
71        &self,
72        model: &mut Model,
73        optimizer: &mut Opt,
74        x: &Input,
75        y: &Output,
76    ) -> Result<(TrainMetrics, Self::EvalMetrics), ProfilerError> {
77        let tick = Instant::now();
78        optimizer.fit(model, x, y)?;
79        let tock = Instant::now();
80
81        // Store elapsed time and create struct with training metrics.
82        let elapsed = tock.duration_since(tick).as_secs_f64();
83        let train_metrics = TrainMetrics::new(elapsed);
84
85        // Compute model evaluation metrics.
86        let eval_metrics = model.compute_metrics(x, y)?;
87
88        Ok((train_metrics, eval_metrics))
89    }
90
91    /// Profiles the evaluation process of a classification model.
92    ///
93    /// Computes classification metrics for the model on the provided data.
94    ///
95    /// # Arguments
96    ///
97    /// * `model` - Mutable reference to the classification model being evaluated
98    /// * `x` - Reference to input features
99    /// * `y` - Reference to target labels
100    ///
101    /// # Returns
102    ///
103    /// Classification metrics (accuracy, precision, recall, F1 score), or a ProfilerError if an error occurs.
104    fn profile_evaluation(
105        &self,
106        model: &mut Model,
107        x: &Input,
108        y: &Output,
109    ) -> Result<Self::EvalMetrics, ProfilerError> {
110        // Compute model evaluation metrics.
111        let eval_metrics = model.compute_metrics(x, y)?;
112        Ok(eval_metrics)
113    }
114}