rust_ml/model/core/base.rs

1use crate::core::error::ModelError;
2
3use super::param_collection::{GradientCollection, ParamCollection};
4
5pub trait BaseModel<Input, Output> {
6    /// Predicts an output value based on the input data.
7    ///
8    /// # Returns
9    ///
10    /// The predicted output value
11    fn predict(&self, x: &Input) -> Result<Output, ModelError>;
12
13    /// Computes the cost (or loss) between the predicted output and the actual output.
14    ///
15    /// # Returns
16    ///
17    /// The computed cost as a floating point value
18    fn compute_cost(&self, x: &Input, y: &Output) -> Result<f64, ModelError>;
19}
20
21pub trait OptimizableModel<Input, Output>:
22    BaseModel<Input, Output> + ParamCollection + GradientCollection
23{
24    /// Forward pass through the model.
25    fn forward(&self, input: &Input) -> Result<Output, ModelError>;
26
27    /// Backward pass to compute gradients.
28    fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>;
29
30    /// Computes the gradient of the cost with respect to the output predictions
31    fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>;
32}