/// A model that exposes a forward pass, a backward pass, and the gradient of
/// its cost with respect to its output predictions.
///
/// Implementors must also implement [`BaseModel`], [`ParamCollection`], and
/// [`GradientCollection`].
pub trait OptimizableModel<Input, Output>
where
    Self: BaseModel<Input, Output> + ParamCollection + GradientCollection,
{
    /// Forward pass through the model.
    ///
    /// # Errors
    /// Returns a [`ModelError`] if the pass cannot be completed.
    fn forward(&self, input: &Input) -> Result<Output, ModelError>;

    /// Backward pass to compute gradients.
    ///
    /// `output_grad` is the gradient flowing back into the model's output.
    /// This method takes `&mut self`, so implementations may update internal
    /// state (presumably the gradients exposed through
    /// [`GradientCollection`] — confirm against implementors).
    ///
    /// # Errors
    /// Returns a [`ModelError`] if the pass cannot be completed.
    fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>;

    /// Computes the gradient of the cost with respect to the output
    /// predictions.
    ///
    /// `x` is the model input; `y` is presumably the target values that the
    /// predictions are compared against (TODO confirm).
    ///
    /// # Errors
    /// Returns a [`ModelError`] if the gradient cannot be computed.
    fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>;
}
Required Methods§
fn forward(&self, input: &Input) -> Result<Output, ModelError>
Forward pass through the model.
fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>
Backward pass to compute gradients.
fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>
Computes the gradient of the cost with respect to the output predictions.
Dyn Compatibility
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".
Implementors
impl OptimizableModel<ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>>> for LinearRegression
Implementation of the OptimizableModel trait for LinearRegression.