//! rustyml/traits.rs
//!
//! Common traits shared across the crate: getter functions for
//! gradient-descent regressors, and the core neural-network interfaces
//! for layers, loss functions, and optimizers.

1use crate::ModelError;
2use crate::machine_learning::RegularizationType;
3use crate::neural_network::Tensor;
4use crate::neural_network::layer::LayerWeight;
5
/// Common getter functions for gradient-descent-based regressor models.
///
/// Provides read-only access to the hyperparameters and fitting state
/// (intercept setting, learning rate, iteration limits, convergence
/// tolerance, and regularization) that these models share.
pub trait RegressorCommonGetterFunctions {
    /// Gets the current setting for fitting the intercept term
    ///
    /// # Returns
    ///
    /// * `bool` - Returns `true` if the model includes an intercept term, `false` otherwise
    fn get_fit_intercept(&self) -> bool;

    /// Gets the current learning rate
    ///
    /// The learning rate controls the step size in each iteration of gradient descent.
    ///
    /// # Returns
    ///
    /// * `f64` - The current learning rate value
    fn get_learning_rate(&self) -> f64;

    /// Gets the maximum number of iterations
    ///
    /// # Returns
    ///
    /// * `usize` - The maximum number of iterations for the gradient descent algorithm
    fn get_max_iterations(&self) -> usize;

    /// Gets the convergence tolerance threshold
    ///
    /// The convergence tolerance is used to determine when to stop the training process.
    /// Training stops when the change in the loss function between consecutive iterations
    /// is less than this value.
    ///
    /// # Returns
    ///
    /// * `f64` - The current convergence tolerance value
    fn get_tolerance(&self) -> f64;

    /// Returns the number of iterations actually performed during the last model fitting.
    ///
    /// # Returns
    ///
    /// - `Ok(usize)` - The number of iterations if the model has been fitted
    /// - `Err(ModelError::NotFitted)` - If the model has not been fitted yet
    fn get_actual_iterations(&self) -> Result<usize, ModelError>;

    /// Returns a reference to the regularization type of the model
    ///
    /// This method provides access to the regularization configuration of the model,
    /// which can be None (no regularization), L1 (LASSO), or L2 (Ridge).
    ///
    /// # Returns
    ///
    /// * `&Option<RegularizationType>` - A reference to the regularization type, which will be None if no regularization is applied
    fn get_regularization_type(&self) -> &Option<RegularizationType>;
}
59
60/// Defines the interface for neural network layers.
61///
62/// This trait provides the core functionality that all neural network layers must implement,
63/// including forward and backward propagation, as well as parameter updates for different
64/// optimization algorithms.
pub trait Layer {
    /// Performs forward propagation through the layer.
    ///
    /// Takes `&mut self` because implementations may cache intermediate
    /// values needed later by [`Layer::backward`].
    ///
    /// # Parameters
    ///
    /// * `input` - The input tensor to the layer
    ///
    /// # Returns
    ///
    /// The output tensor after forward computation
    fn forward(&mut self, input: &Tensor) -> Tensor;

    /// Performs backward propagation through the layer.
    ///
    /// # Parameters
    ///
    /// * `grad_output` - The gradient tensor from the next layer
    ///
    /// # Returns
    ///
    /// - `Ok(Tensor)` - The gradient tensor to be passed to the previous layer
    /// - `Err(ModelError::ProcessingError(String))` - If the layer encountered an error during processing
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError>;

    /// Returns the type name of the layer (e.g., "Dense").
    ///
    /// Defaults to `"Unknown"` for implementations that do not override it.
    ///
    /// # Returns
    ///
    /// A string slice representing the layer type
    fn layer_type(&self) -> &str {
        "Unknown"
    }

    /// Returns a description of the output shape of the layer.
    ///
    /// Defaults to `"Unknown"` for implementations that do not override it.
    ///
    /// # Returns
    ///
    /// A string describing the output dimensions
    fn output_shape(&self) -> String {
        "Unknown".to_string()
    }

    /// Returns the total number of trainable parameters in the layer.
    ///
    /// Defaults to `0`, which is correct for parameterless layers
    /// (e.g., activation-only layers).
    ///
    /// # Returns
    ///
    /// The count of parameters as a `usize`
    fn param_count(&self) -> usize {
        0
    }

    /// Updates the layer parameters using Stochastic Gradient Descent.
    ///
    /// # Parameters
    ///
    /// * `_lr` - Learning rate for parameter updates
    fn update_parameters_sgd(&mut self, _lr: f32);

    /// Updates the layer parameters using Adam optimizer.
    ///
    /// # Parameters
    ///
    /// - `_lr` - Learning rate for parameter updates
    /// - `_beta1` - Exponential decay rate for the first moment estimates
    /// - `_beta2` - Exponential decay rate for the second moment estimates
    /// - `_epsilon` - Small constant for numerical stability
    /// - `_t` - Current training iteration (used for bias correction)
    fn update_parameters_adam(
        &mut self,
        _lr: f32,
        _beta1: f32,
        _beta2: f32,
        _epsilon: f32,
        _t: u64,
    );

    /// Updates the layer parameters using RMSprop optimizer.
    ///
    /// # Parameters
    ///
    /// - `_lr` - Learning rate for parameter updates
    /// - `_rho` - Decay rate for moving average of squared gradients
    /// - `_epsilon` - Small constant for numerical stability
    fn update_parameters_rmsprop(&mut self, _lr: f32, _rho: f32, _epsilon: f32);

    /// Returns the weights of the layer.
    ///
    /// This method provides access to the weight matrices and bias vectors
    /// used by the layer. The exact contents depend on the layer kind: for
    /// recurrent layers such as LSTM the weights are organized by gate
    /// (input, forget, cell, output) and by their role (kernel,
    /// recurrent_kernel, bias) within each gate.
    ///
    /// # Returns
    ///
    /// * A `LayerWeight` enum containing:
    ///   - `LayerWeight::Dense` for Dense layers with weight and bias
    ///   - `LayerWeight::SimpleRNN` for SimpleRNN layers with kernel, recurrent_kernel, and bias
    ///   - `LayerWeight::LSTM` for LSTM layers with weights for input, forget, cell, and output gates
    fn get_weights(&self) -> LayerWeight;
}
164
165/// Defines the interface for loss functions used in neural network training.
166///
167/// This trait provides methods to compute both the loss value and its gradient
168/// with respect to the predicted values.
pub trait LossFunction {
    /// Computes the loss between true and predicted values.
    ///
    /// # Parameters
    ///
    /// - `y_true` - Tensor containing the ground truth values
    /// - `y_pred` - Tensor containing the predicted values
    ///
    /// Both tensors are presumably expected to have matching shapes;
    /// behavior on mismatched shapes is implementation-defined.
    ///
    /// # Returns
    ///
    /// The scalar loss value
    fn compute_loss(&self, y_true: &Tensor, y_pred: &Tensor) -> f32;

    /// Computes the gradient of the loss with respect to the predictions.
    ///
    /// Used during backpropagation to seed the gradient that flows
    /// backward through the network's layers.
    ///
    /// # Parameters
    ///
    /// - `y_true` - Tensor containing the ground truth values
    /// - `y_pred` - Tensor containing the predicted values
    ///
    /// # Returns
    ///
    /// Tensor containing the gradient of the loss
    fn compute_grad(&self, y_true: &Tensor, y_pred: &Tensor) -> Tensor;
}
194
195/// Defines the interface for optimization algorithms.
196///
197/// This trait provides methods to update layer parameters during
198/// the training process.
pub trait Optimizer {
    /// Updates the parameters of a layer according to the optimization algorithm.
    ///
    /// Takes a `&mut dyn Layer` trait object so a single optimizer instance
    /// can drive heterogeneous layer types in the same network.
    ///
    /// # Parameters
    ///
    /// * `layer` - The layer whose parameters should be updated
    fn update(&mut self, layer: &mut dyn Layer);
}