rust_ml/bench/
regression_profiler.rs

1use crate::bench::core::error::ProfilerError;
2use crate::bench::core::profiler::Profiler;
3use crate::bench::core::train_metrics::TrainMetrics;
4use crate::bench::regression_metrics::RegressionMetrics;
5use crate::model::core::base::BaseModel;
6use crate::model::core::regression_model::RegressionModel;
7use crate::optim::core::optimizer::Optimizer;
8use std::marker::PhantomData;
9use std::time::Instant;
10
/// A stateless profiler for regression models.
///
/// This struct implements the `Profiler` trait specifically for regression models:
/// it measures how long the optimizer takes to fit the model and reports the
/// model's evaluation results as `RegressionMetrics` (e.g. MSE, RMSE, R²).
///
/// The struct holds no data. Its type parameters exist only to select the
/// matching `Profiler` implementation.
///
/// # Type Parameters
///
/// * `Model` - The regression model type being profiled
/// * `Opt` - The optimizer type used for training
/// * `Input` - The input data type (features)
/// * `Output` - The output data type (target values)
pub struct RegressionProfiler<Model, Opt, Input, Output> {
    // `fn() -> T` marks the parameters as used without claiming to own values
    // of those types, so the profiler is `Send`/`Sync` no matter what the
    // model, optimizer or data types are.
    _phantom: PhantomData<fn() -> (Model, Opt, Input, Output)>,
}
26
27impl<Model, Opt, Input, Output> RegressionProfiler<Model, Opt, Input, Output> {
28    /// Creates a new RegressionProfiler instance.
29    ///
30    /// # Returns
31    ///
32    /// A new RegressionProfiler instance.
33    pub fn new() -> Self {
34        Self {
35            _phantom: PhantomData,
36        }
37    }
38}
39
40impl<Model, Opt, Input, Output> Default for RegressionProfiler<Model, Opt, Input, Output> {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl<Model, Opt, Input, Output> Profiler<Model, Opt, Input, Output>
47    for RegressionProfiler<Model, Opt, Input, Output>
48where
49    Model: BaseModel<Input, Output> + RegressionModel<Input, Output>,
50    Opt: Optimizer<Input, Output, Model>,
51{
52    type EvalMetrics = RegressionMetrics;
53
54    /// Profiles the training process of a regression model.
55    ///
56    /// Measures the time taken to train the model and computes regression metrics
57    /// on the provided data.
58    ///
59    /// # Arguments
60    ///
61    /// * `model` - Mutable reference to the regression model being trained
62    /// * `optimizer` - Mutable reference to the optimizer used for training
63    /// * `x` - Reference to input features
64    /// * `y` - Reference to target values
65    ///
66    /// # Returns
67    ///
68    /// A tuple containing training metrics (including training time) and regression metrics
69    /// (MSE, RMSE, R²), or a ProfilerError if an error occurs.
70    fn train(
71        &self,
72        model: &mut Model,
73        optimizer: &mut Opt,
74        x: &Input,
75        y: &Output,
76    ) -> Result<(TrainMetrics, Self::EvalMetrics), ProfilerError> {
77        // Train model and measure training time.
78        let tick = Instant::now();
79        optimizer.fit(model, x, y)?;
80        let tock = Instant::now();
81
82        // Store elapsed time and create struct with training metrics.
83        let elapsed = tock.duration_since(tick).as_secs_f64();
84        let train_metrics = TrainMetrics::new(elapsed);
85
86        // Compute model evaluation metrics.
87        let eval_metrics = model.compute_metrics(x, y)?;
88
89        Ok((train_metrics, eval_metrics))
90    }
91
92    /// Profiles the evaluation process of a regression model.
93    ///
94    /// Computes regression metrics for the model on the provided data.
95    ///
96    /// # Arguments
97    ///
98    /// * `model` - Mutable reference to the regression model being evaluated
99    /// * `x` - Reference to input features
100    /// * `y` - Reference to target values
101    ///
102    /// # Returns
103    ///
104    /// Regression metrics (MSE, RMSE, R²), or a ProfilerError if an error occurs.
105    fn profile_evaluation(
106        &self,
107        model: &mut Model,
108        x: &Input,
109        y: &Output,
110    ) -> Result<Self::EvalMetrics, ProfilerError> {
111        Ok(model.compute_metrics(x, y)?)
112    }
113}