rust_ml/bench/regression_profiler.rs
1use crate::bench::core::error::ProfilerError;
2use crate::bench::core::profiler::Profiler;
3use crate::bench::core::train_metrics::TrainMetrics;
4use crate::bench::regression_metrics::RegressionMetrics;
5use crate::model::core::base::BaseModel;
6use crate::model::core::regression_model::RegressionModel;
7use crate::optim::core::optimizer::Optimizer;
8use std::marker::PhantomData;
9use std::time::Instant;
10
/// A stateless profiler for regression models.
///
/// The type stores no data of its own; it only pins down, at the type level, which
/// model, optimizer, input and output types its `Profiler` implementation works with.
/// That implementation times a training run and reports the regression metrics the
/// model computes (see `RegressionMetrics` for the exact values, e.g. MSE, RMSE, R²).
///
/// # Type Parameters
///
/// * `Model` - The regression model type being profiled
/// * `Opt` - The optimizer type used for training
/// * `Input` - The input data type (features)
/// * `Output` - The output data type (target values)
pub struct RegressionProfiler<Model, Opt, Input, Output> {
    // `fn() -> T` rather than `T`: the profiler never owns a value of any of these
    // types, so its auto traits (`Send`, `Sync`, ...) must not depend on them.
    _phantom: PhantomData<fn() -> (Model, Opt, Input, Output)>,
}
26
27impl<Model, Opt, Input, Output> RegressionProfiler<Model, Opt, Input, Output> {
28 /// Creates a new RegressionProfiler instance.
29 ///
30 /// # Returns
31 ///
32 /// A new RegressionProfiler instance.
33 pub fn new() -> Self {
34 Self {
35 _phantom: PhantomData,
36 }
37 }
38}
39
40impl<Model, Opt, Input, Output> Default for RegressionProfiler<Model, Opt, Input, Output> {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl<Model, Opt, Input, Output> Profiler<Model, Opt, Input, Output>
47 for RegressionProfiler<Model, Opt, Input, Output>
48where
49 Model: BaseModel<Input, Output> + RegressionModel<Input, Output>,
50 Opt: Optimizer<Input, Output, Model>,
51{
52 type EvalMetrics = RegressionMetrics;
53
54 /// Profiles the training process of a regression model.
55 ///
56 /// Measures the time taken to train the model and computes regression metrics
57 /// on the provided data.
58 ///
59 /// # Arguments
60 ///
61 /// * `model` - Mutable reference to the regression model being trained
62 /// * `optimizer` - Mutable reference to the optimizer used for training
63 /// * `x` - Reference to input features
64 /// * `y` - Reference to target values
65 ///
66 /// # Returns
67 ///
68 /// A tuple containing training metrics (including training time) and regression metrics
69 /// (MSE, RMSE, R²), or a ProfilerError if an error occurs.
70 fn train(
71 &self,
72 model: &mut Model,
73 optimizer: &mut Opt,
74 x: &Input,
75 y: &Output,
76 ) -> Result<(TrainMetrics, Self::EvalMetrics), ProfilerError> {
77 // Train model and measure training time.
78 let tick = Instant::now();
79 optimizer.fit(model, x, y)?;
80 let tock = Instant::now();
81
82 // Store elapsed time and create struct with training metrics.
83 let elapsed = tock.duration_since(tick).as_secs_f64();
84 let train_metrics = TrainMetrics::new(elapsed);
85
86 // Compute model evaluation metrics.
87 let eval_metrics = model.compute_metrics(x, y)?;
88
89 Ok((train_metrics, eval_metrics))
90 }
91
92 /// Profiles the evaluation process of a regression model.
93 ///
94 /// Computes regression metrics for the model on the provided data.
95 ///
96 /// # Arguments
97 ///
98 /// * `model` - Mutable reference to the regression model being evaluated
99 /// * `x` - Reference to input features
100 /// * `y` - Reference to target values
101 ///
102 /// # Returns
103 ///
104 /// Regression metrics (MSE, RMSE, R²), or a ProfilerError if an error occurs.
105 fn profile_evaluation(
106 &self,
107 model: &mut Model,
108 x: &Input,
109 y: &Output,
110 ) -> Result<Self::EvalMetrics, ProfilerError> {
111 Ok(model.compute_metrics(x, y)?)
112 }
113}