pub trait OptimizableModel<Input, Output>:
    BaseModel<Input, Output>
    + ParamCollection
    + GradientCollection {
    // Required methods
    fn forward(&self, input: &Input) -> Result<Output, ModelError>;
    fn backward(
        &mut self,
        input: &Input,
        output_grad: &Output,
    ) -> Result<(), ModelError>;
    fn compute_output_gradient(
        &self,
        x: &Input,
        y: &Output,
    ) -> Result<Output, ModelError>;
}
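To show what an implementor has to provide, here is a minimal sketch for a hypothetical one-weight model (`ScaleModel`, y_hat = w * x) over plain `Vec<f64>`. It assumes a trimmed copy of the trait: the supertraits `BaseModel`, `ParamCollection` and `GradientCollection` are left out because their definitions are not shown on this page, and `ModelError` is a hypothetical stand-in. The cost is assumed to be J = 1/(2n) * sum (y_hat - y)^2; the real crate may use a different convention.

```rust
/// Hypothetical stand-in for the crate's `ModelError`.
#[derive(Debug, PartialEq)]
pub enum ModelError {
    ShapeMismatch { expected: usize, got: usize },
}

/// Trimmed copy of `OptimizableModel` (supertraits omitted, see lead-in).
pub trait OptimizableModel<Input, Output> {
    fn forward(&self, input: &Input) -> Result<Output, ModelError>;
    fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>;
    fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>;
}

/// Hypothetical toy model: y_hat = w * x, one weight, no bias.
pub struct ScaleModel {
    pub w: f64,
    pub grad_w: f64, // written by `backward`
}

impl OptimizableModel<Vec<f64>, Vec<f64>> for ScaleModel {
    fn forward(&self, input: &Vec<f64>) -> Result<Vec<f64>, ModelError> {
        Ok(input.iter().map(|x| self.w * x).collect())
    }

    fn backward(&mut self, input: &Vec<f64>, output_grad: &Vec<f64>) -> Result<(), ModelError> {
        if input.len() != output_grad.len() {
            return Err(ModelError::ShapeMismatch { expected: input.len(), got: output_grad.len() });
        }
        // Chain rule: dJ/dw = sum_i dJ/dy_hat[i] * x[i]
        self.grad_w = input.iter().zip(output_grad).map(|(x, g)| x * g).sum();
        Ok(())
    }

    fn compute_output_gradient(&self, x: &Vec<f64>, y: &Vec<f64>) -> Result<Vec<f64>, ModelError> {
        let y_hat = self.forward(x)?;
        if y_hat.len() != y.len() {
            return Err(ModelError::ShapeMismatch { expected: y_hat.len(), got: y.len() });
        }
        let n = y.len() as f64;
        // Assumed cost J = 1/(2n) * sum (y_hat - y)^2, so dJ/dy_hat = (y_hat - y) / n
        Ok(y_hat.iter().zip(y).map(|(p, t)| (p - t) / n).collect())
    }
}
```

A caller typically chains the methods: `compute_output_gradient` yields dJ/dy_hat, which is then passed to `backward` as `output_grad`.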
Required Methods
fn forward(&self, input: &Input) -> Result<Output, ModelError>
Forward pass through the model.
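For a linear model such as the `LinearRegression` implementor, the forward pass is presumably y_hat = X w + b. The sketch below assumes that form and uses `Vec<Vec<f64>>` (one sample per row) in place of ndarray's `Array2<f64>`; `linear_forward` and `ModelError` are hypothetical names, not the crate's API.

```rust
/// Hypothetical stand-in for the crate's `ModelError`.
#[derive(Debug, PartialEq)]
pub enum ModelError {
    ShapeMismatch { expected: usize, got: usize },
}

/// Forward pass of a linear model: y_hat[i] = sum_j x[i][j] * w[j] + b.
/// `x` is n x d (one sample per row) and `w` has length d.
pub fn linear_forward(x: &[Vec<f64>], w: &[f64], b: f64) -> Result<Vec<f64>, ModelError> {
    x.iter()
        .map(|row| {
            // Every row must have one feature per weight.
            if row.len() != w.len() {
                return Err(ModelError::ShapeMismatch { expected: w.len(), got: row.len() });
            }
            Ok(row.iter().zip(w).map(|(xij, wj)| xij * wj).sum::<f64>() + b)
        })
        // Collecting into Result stops at the first shape error.
        .collect()
}
```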
fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>
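Backward pass: given input and output_grad (the gradient of the cost with respect to the model's output), computes the gradients of the model's parameters.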
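For y_hat = X w + b, the chain rule gives dJ/dw_j = sum_i x[i][j] * output_grad[i] and dJ/db = sum_i output_grad[i]. The sketch below assumes that linear form and stores the results in hypothetical `grad_w` / `grad_b` fields; how the real crate stores gradients (via `GradientCollection`) is not shown on this page.

```rust
/// Hypothetical stand-in for the crate's `ModelError`.
#[derive(Debug, PartialEq)]
pub enum ModelError {
    ShapeMismatch { expected: usize, got: usize },
}

/// Hypothetical linear model holding parameters and their gradients.
pub struct Linear {
    pub w: Vec<f64>,
    pub b: f64,
    pub grad_w: Vec<f64>,
    pub grad_b: f64,
}

impl Linear {
    /// Backward pass through y_hat = X w + b.
    /// `output_grad[i]` is dJ/dy_hat[i]; results go into `grad_w` and `grad_b`.
    pub fn backward(&mut self, x: &[Vec<f64>], output_grad: &[f64]) -> Result<(), ModelError> {
        if x.len() != output_grad.len() {
            return Err(ModelError::ShapeMismatch { expected: x.len(), got: output_grad.len() });
        }
        let mut grad_w = vec![0.0; self.w.len()];
        for (row, g) in x.iter().zip(output_grad) {
            if row.len() != grad_w.len() {
                return Err(ModelError::ShapeMismatch { expected: grad_w.len(), got: row.len() });
            }
            // dJ/dw_j += x[i][j] * dJ/dy_hat[i]
            for (gw, xij) in grad_w.iter_mut().zip(row) {
                *gw += xij * g;
            }
        }
        self.grad_w = grad_w;
        // dy_hat[i]/db = 1, so dJ/db = sum_i dJ/dy_hat[i]
        self.grad_b = output_grad.iter().sum();
        Ok(())
    }
}
```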
fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>
Computes the gradient of the cost with respect to the output predictions for inputs x and targets y.
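The page does not name the cost. Assuming mean squared error with the 1/(2n) convention, J = 1/(2n) * sum (y_hat - y)^2, the output gradient is (y_hat - y) / n. In the trait method, y_hat would come from calling forward on x; the hypothetical helpers below take y_hat directly so the formula stands alone.

```rust
/// Hypothetical stand-in for the crate's `ModelError`.
#[derive(Debug, PartialEq)]
pub enum ModelError {
    ShapeMismatch { expected: usize, got: usize },
}

fn check_len(y_hat: &[f64], y: &[f64]) -> Result<(), ModelError> {
    if y_hat.len() != y.len() {
        return Err(ModelError::ShapeMismatch { expected: y_hat.len(), got: y.len() });
    }
    Ok(())
}

/// Assumed cost: J = 1/(2n) * sum_i (y_hat[i] - y[i])^2.
pub fn mse_cost(y_hat: &[f64], y: &[f64]) -> Result<f64, ModelError> {
    check_len(y_hat, y)?;
    let n = y.len() as f64;
    Ok(y_hat.iter().zip(y).map(|(p, t)| (p - t).powi(2)).sum::<f64>() / (2.0 * n))
}

/// Gradient of the cost above: dJ/dy_hat[i] = (y_hat[i] - y[i]) / n.
pub fn mse_output_gradient(y_hat: &[f64], y: &[f64]) -> Result<Vec<f64>, ModelError> {
    check_len(y_hat, y)?;
    let n = y.len() as f64;
    Ok(y_hat.iter().zip(y).map(|(p, t)| (p - t) / n).collect())
}
```

A finite-difference check (bump one prediction by a small h and compare the change in cost) is a quick way to confirm the gradient matches whichever cost convention is actually used.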
Dyn Compatibility
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
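Because `dyn OptimizableModel<..>` is not allowed, code that drives a model should be generic over the model type (static dispatch). The sketch below shows that pattern with a hypothetical `compute_gradients` helper, a trimmed stand-in trait (supertraits omitted, so this copy itself happens to be dyn compatible), and a toy `Shift` model; the MSE cost is an assumption.

```rust
/// Hypothetical stand-in for the crate's `ModelError`.
#[derive(Debug)]
pub enum ModelError {
    ShapeMismatch,
}

/// Trimmed stand-in for `OptimizableModel` (supertraits omitted).
pub trait OptimizableModel<Input, Output> {
    fn forward(&self, input: &Input) -> Result<Output, ModelError>;
    fn backward(&mut self, input: &Input, output_grad: &Output) -> Result<(), ModelError>;
    fn compute_output_gradient(&self, x: &Input, y: &Output) -> Result<Output, ModelError>;
}

/// Static dispatch through a generic parameter: monomorphized per model type,
/// which is the pattern that still works for the real, non-dyn-compatible trait.
pub fn compute_gradients<M, I, O>(model: &mut M, x: &I, y: &O) -> Result<(), ModelError>
where
    M: OptimizableModel<I, O>,
{
    let output_grad = model.compute_output_gradient(x, y)?;
    model.backward(x, &output_grad)
}

/// Hypothetical toy model: y_hat = x + b.
pub struct Shift {
    pub b: f64,
    pub grad_b: f64,
}

impl OptimizableModel<Vec<f64>, Vec<f64>> for Shift {
    fn forward(&self, input: &Vec<f64>) -> Result<Vec<f64>, ModelError> {
        Ok(input.iter().map(|x| x + self.b).collect())
    }

    fn backward(&mut self, _input: &Vec<f64>, output_grad: &Vec<f64>) -> Result<(), ModelError> {
        // dy_hat[i]/db = 1, so dJ/db = sum_i dJ/dy_hat[i]
        self.grad_b = output_grad.iter().sum();
        Ok(())
    }

    fn compute_output_gradient(&self, x: &Vec<f64>, y: &Vec<f64>) -> Result<Vec<f64>, ModelError> {
        let y_hat = self.forward(x)?;
        if y_hat.len() != y.len() {
            return Err(ModelError::ShapeMismatch);
        }
        let n = y.len() as f64;
        // Assumed cost J = 1/(2n) * sum (y_hat - y)^2
        Ok(y_hat.iter().zip(y).map(|(p, t)| (p - t) / n).collect())
    }
}
```

If several model types must be handled at runtime, an enum wrapping each concrete model is a common substitute for a trait object.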
Implementors
impl OptimizableModel<ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>>> for LinearRegression
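Implementation of the OptimizableModel trait for LinearRegression, taking Array2<f64> inputs and producing Array1<f64> outputs (ndarray's aliases for the ArrayBase types above).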
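For the LinearRegression implementor, the three required methods plausibly map onto the formulas below. This assumes a weight vector w, a scalar bias b, inputs X with n rows, and a mean-squared-error cost with the 1/(2n) convention; the page itself does not state the cost. Under these assumptions, forward computes the first expression, compute_output_gradient returns the partial derivative of J with respect to the predictions, and backward produces the weight and bias gradients.

$$
\hat{y} = Xw + b\,\mathbf{1}, \qquad
J(w, b) = \frac{1}{2n}\,\lVert \hat{y} - y \rVert^2
$$

$$
\frac{\partial J}{\partial \hat{y}} = \frac{1}{n}(\hat{y} - y), \qquad
\nabla_w J = X^\top \frac{\partial J}{\partial \hat{y}}, \qquad
\frac{\partial J}{\partial b} = \mathbf{1}^\top \frac{\partial J}{\partial \hat{y}}
$$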