rust_ml/bench/core/profiler.rs
1use crate::bench::core::error::ProfilerError;
2use crate::bench::core::train_metrics::TrainMetrics;
3
4/// Trait for profiling and benchmarking machine learning models during training and evaluation.
5///
6/// This trait defines methods for collecting performance metrics during model training and evaluation.
7/// It allows for consistent measurement of training time and model performance metrics across
8/// different model types and optim strategies.
9///
10/// # Type Parameters
11///
12/// * `Model` - The machine learning model type being profiled
13/// * `Opt` - The optimizer type used for training
14/// * `Input` - The input data type (features)
15/// * `Output` - The output data type (targets/labels)
16pub trait Profiler<Model, Opt, Input, Output> {
17 /// The type of evaluation metrics returned by the profiler
18 type EvalMetrics;
19
20 /// Profiles the training process of a model, collecting training time and evaluation metrics.
21 ///
22 /// This method measures the time taken for training while also computing performance metrics
23 /// on the provided data.
24 ///
25 /// # Arguments
26 ///
27 /// * `model` - Mutable reference to the model being trained
28 /// * `optimizer` - Mutable reference to the optimizer used for training
29 /// * `x` - Reference to input features
30 /// * `y` - Reference to output targets
31 ///
32 /// # Returns
33 ///
34 /// A tuple containing training metrics (including training time) and evaluation metrics
35 /// specific to the model type, or a ProfilerError if an error occurs during profiling.
36 fn train(
37 &self,
38 model: &mut Model,
39 optimizer: &mut Opt,
40 x: &Input,
41 y: &Output,
42 ) -> Result<(TrainMetrics, Self::EvalMetrics), ProfilerError>;
43
44 /// Profiles the evaluation process of a model, computing performance metrics.
45 ///
46 /// This method evaluates the model on the provided data and returns metrics
47 /// specific to the model type.
48 ///
49 /// # Arguments
50 ///
51 /// * `model` - Mutable reference to the model being evaluated
52 /// * `x` - Reference to input features
53 /// * `y` - Reference to output targets
54 ///
55 /// # Returns
56 ///
57 /// Evaluation metrics specific to the model type, or a ProfilerError if an error occurs
58 /// during evaluation.
59 fn profile_evaluation(
60 &self,
61 model: &mut Model,
62 x: &Input,
63 y: &Output,
64 ) -> Result<Self::EvalMetrics, ProfilerError>;
65}