// rust_ml/bench/classification_profiler.rs
1use crate::bench::classification_metrics::ClassificationMetrics;
2use crate::bench::core::error::ProfilerError;
3use crate::bench::core::profiler::Profiler;
4use crate::bench::core::train_metrics::TrainMetrics;
5use crate::model::core::base::BaseModel;
6use crate::model::core::classification_model::ClassificationModel;
7use crate::optim::core::optimizer::Optimizer;
8use std::marker::PhantomData;
9use std::time::Instant;
10
/// A profiler for classification models that measures training time and computes classification metrics.
///
/// This struct implements the `Profiler` trait specifically for classification models,
/// providing performance assessment through metrics such as accuracy, precision, recall,
/// and F1 score.
///
/// # Type Parameters
///
/// * `Model` - The classification model type being profiled
/// * `Opt` - The optimizer type used for training
/// * `Input` - The input data type (features)
/// * `Output` - The output data type (labels/classes)
pub struct ClassificationProfiler<Model, Opt, Input, Output> {
    // Zero-sized marker: the profiler holds no runtime state. The type
    // parameters only pin down which model/optimizer/data combination
    // this profiler is instantiated for.
    _phantom: PhantomData<(Model, Opt, Input, Output)>,
}
26
27impl<Model, Opt, Input, Output> ClassificationProfiler<Model, Opt, Input, Output> {
28 /// Creates a new ClassificationProfiler instance.
29 ///
30 /// # Returns
31 ///
32 /// A new ClassificationProfiler instance.
33 pub fn new() -> Self {
34 Self {
35 _phantom: PhantomData,
36 }
37 }
38}
39
40impl<Model, Opt, Input, Output> Default for ClassificationProfiler<Model, Opt, Input, Output> {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl<Model, Opt, Input, Output> Profiler<Model, Opt, Input, Output>
47 for ClassificationProfiler<Model, Opt, Input, Output>
48where
49 Model: BaseModel<Input, Output> + ClassificationModel<Input, Output>,
50 Opt: Optimizer<Input, Output, Model>,
51{
52 type EvalMetrics = ClassificationMetrics;
53
54 /// Profiles the training process of a classification model.
55 ///
56 /// Measures the time taken to train the model and computes classification metrics
57 /// on the provided data.
58 ///
59 /// # Arguments
60 ///
61 /// * `model` - Mutable reference to the classification model being trained
62 /// * `optimizer` - Mutable reference to the optimizer used for training
63 /// * `x` - Reference to input features
64 /// * `y` - Reference to target labels
65 ///
66 /// # Returns
67 ///
68 /// A tuple containing training metrics (including training time) and classification metrics
69 /// (accuracy, precision, recall, F1 score), or a ProfilerError if an error occurs.
70 fn train(
71 &self,
72 model: &mut Model,
73 optimizer: &mut Opt,
74 x: &Input,
75 y: &Output,
76 ) -> Result<(TrainMetrics, Self::EvalMetrics), ProfilerError> {
77 let tick = Instant::now();
78 optimizer.fit(model, x, y)?;
79 let tock = Instant::now();
80
81 // Store elapsed time and create struct with training metrics.
82 let elapsed = tock.duration_since(tick).as_secs_f64();
83 let train_metrics = TrainMetrics::new(elapsed);
84
85 // Compute model evaluation metrics.
86 let eval_metrics = model.compute_metrics(x, y)?;
87
88 Ok((train_metrics, eval_metrics))
89 }
90
91 /// Profiles the evaluation process of a classification model.
92 ///
93 /// Computes classification metrics for the model on the provided data.
94 ///
95 /// # Arguments
96 ///
97 /// * `model` - Mutable reference to the classification model being evaluated
98 /// * `x` - Reference to input features
99 /// * `y` - Reference to target labels
100 ///
101 /// # Returns
102 ///
103 /// Classification metrics (accuracy, precision, recall, F1 score), or a ProfilerError if an error occurs.
104 fn profile_evaluation(
105 &self,
106 model: &mut Model,
107 x: &Input,
108 y: &Output,
109 ) -> Result<Self::EvalMetrics, ProfilerError> {
110 // Compute model evaluation metrics.
111 let eval_metrics = model.compute_metrics(x, y)?;
112 Ok(eval_metrics)
113 }
114}