//! Profiling trait for benchmarking ML model training and evaluation.
//!
//! Path: `rust_ml/bench/core/profiler.rs`

1use crate::bench::core::error::ProfilerError;
2use crate::bench::core::train_metrics::TrainMetrics;
3
4/// Trait for profiling and benchmarking machine learning models during training and evaluation.
5///
6/// This trait defines methods for collecting performance metrics during model training and evaluation.
7/// It allows for consistent measurement of training time and model performance metrics across
8/// different model types and optim strategies.
9///
10/// # Type Parameters
11///
12/// * `Model` - The machine learning model type being profiled
13/// * `Opt` - The optimizer type used for training
14/// * `Input` - The input data type (features)
15/// * `Output` - The output data type (targets/labels)
pub trait Profiler<Model, Opt, Input, Output> {
    /// The type of evaluation metrics returned by the profiler.
    ///
    /// Left to the implementor so each model family can report the metrics
    /// appropriate to it (e.g. accuracy for classifiers, MSE for regressors).
    type EvalMetrics;

    /// Profiles the training process of a model, collecting training time and evaluation metrics.
    ///
    /// This method measures the time taken for training while also computing performance metrics
    /// on the provided data.
    ///
    /// # Arguments
    ///
    /// * `model` - Mutable reference to the model being trained
    /// * `optimizer` - Mutable reference to the optimizer used for training
    /// * `x` - Reference to input features
    /// * `y` - Reference to output targets
    ///
    /// # Returns
    ///
    /// A tuple containing [`TrainMetrics`] (including training time) and evaluation metrics
    /// specific to the model type.
    ///
    /// # Errors
    ///
    /// Returns a [`ProfilerError`] if an error occurs during training or metric collection.
    fn train(
        &self,
        model: &mut Model,
        optimizer: &mut Opt,
        x: &Input,
        y: &Output,
    ) -> Result<(TrainMetrics, Self::EvalMetrics), ProfilerError>;

    /// Profiles the evaluation process of a model, computing performance metrics.
    ///
    /// This method evaluates the model on the provided data and returns metrics
    /// specific to the model type.
    ///
    /// # Arguments
    ///
    /// * `model` - Mutable reference to the model being evaluated
    ///   (mutable so implementations can, e.g., switch the model into eval mode)
    /// * `x` - Reference to input features
    /// * `y` - Reference to output targets
    ///
    /// # Returns
    ///
    /// Evaluation metrics ([`Self::EvalMetrics`]) specific to the model type.
    ///
    /// # Errors
    ///
    /// Returns a [`ProfilerError`] if an error occurs during evaluation.
    fn profile_evaluation(
        &self,
        model: &mut Model,
        x: &Input,
        y: &Output,
    ) -> Result<Self::EvalMetrics, ProfilerError>;
}
65}